""" Cloud storage shared utilities for DocuVault. Security design: SSRF prevention (D-17): validate_cloud_url() resolves DNS via socket.getaddrinfo *before* checking the resolved IP against blocked networks. This prevents DNS-rebinding attacks where a hostname passes a string check but resolves to an internal IP. It also explicitly blocks the string "localhost" before any DNS resolution. HKDF credential encryption (D-18, CLOUD-02): _derive_fernet_key() creates a FRESH HKDF instance on every call. The cryptography library raises AlreadyFinalized if .derive() is called twice on the same instance (Pitfall 3 in RESEARCH.md). This function avoids that by constructing a new HKDF(...) object each time. References: RESEARCH.md Pattern 2 — HKDF+Fernet RESEARCH.md Pattern 6 — SSRF validation via ipaddress + socket.getaddrinfo CLAUDE.md — cloud credentials encrypted with HKDF per-user key derivation """ from __future__ import annotations import base64 import ipaddress import json import socket from urllib.parse import urlparse from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.hkdf import HKDF # Networks that must never be the target of outbound cloud HTTP requests (D-17) _BLOCKED_NETS = [ ipaddress.ip_network("127.0.0.0/8"), # IPv4 loopback ipaddress.ip_network("169.254.0.0/16"), # Link-local (AWS/GCP metadata endpoint) ipaddress.ip_network("10.0.0.0/8"), # RFC-1918 class A ipaddress.ip_network("172.16.0.0/12"), # RFC-1918 class B ipaddress.ip_network("192.168.0.0/16"), # RFC-1918 class C ipaddress.ip_network("::1/128"), # IPv6 loopback ipaddress.ip_network("fc00::/7"), # IPv6 Unique Local Address (ULA) ] def validate_cloud_url(url: str) -> None: """Raise ValueError if the URL targets a private, internal, or restricted address. Security contract (D-17): 1. Reject non-http/https schemes. 2. Reject URLs with no hostname. 3. Explicitly reject the string "localhost" before DNS resolution to avoid cases where getaddrinfo behaviour varies by OS. 4. If the hostname is a raw IP address, check it directly. Otherwise, resolve via socket.getaddrinfo (DNS lookup) and check the resolved IP. This closes the DNS-rebinding window: the hostname must resolve to a non-private IP *at validation time*. 5. Raise ValueError for any IP that falls inside a BLOCKED_NETS entry. Called immediately before every outbound WebDAV/Nextcloud HTTP request, not only at connect-time (RESEARCH.md Pitfall 5 — DNS rebinding mitigation). Args: url: The user-supplied WebDAV, Nextcloud, or cloud server URL. Raises: ValueError: If the URL uses a blocked scheme, has no hostname, or resolves to a private/internal address. """ parsed = urlparse(url) # Step 1: scheme check if parsed.scheme not in ("http", "https"): raise ValueError( f"Unsupported URL scheme '{parsed.scheme}': only http and https are allowed." ) # Step 2: hostname presence hostname = parsed.hostname # lowercased, brackets stripped for IPv6 if not hostname: raise ValueError("URL has no hostname.") # Step 3: explicit string block for 'localhost' (before DNS resolution) if hostname == "localhost": raise ValueError( "URL targets localhost — this is a private/internal address." ) # Step 4: resolve hostname to IP (or parse if already an IP literal) try: addr = ipaddress.ip_address(hostname) except ValueError: # Not a raw IP literal — resolve via DNS try: resolved = socket.getaddrinfo(hostname, None)[0][4][0] addr = ipaddress.ip_address(resolved) except socket.gaierror as exc: raise ValueError(f"Cannot resolve hostname '{hostname}': {exc}") from exc except (ValueError, IndexError) as exc: raise ValueError(f"Unexpected error resolving '{hostname}': {exc}") from exc # Step 5: check resolved IP against each blocked network for net in _BLOCKED_NETS: # Use try/except to handle IPv4/IPv6 family mismatch gracefully try: if addr in net: raise ValueError( f"URL targets a private/internal address: {addr} is in {net}" ) except TypeError: # Different address families (e.g. IPv4 addr in an IPv6 network) — skip continue def _derive_fernet_key(master_key: bytes, user_id: str) -> Fernet: """Derive a per-user Fernet encryption key using HKDF-SHA256. Security notes: - A FRESH HKDF instance is created on every call. The cryptography library raises AlreadyFinalized if .derive() is called twice on the same instance. Never cache or reuse the HKDF object (RESEARCH.md Pitfall 3). - salt = user_id.encode() is deterministic (same user → same key), which is required so that encrypt and decrypt produce consistent results. - info = b"cloud-credentials" provides domain separation so the same master_key cannot be used for unrelated HKDF derivations. Args: master_key: The CLOUD_CREDS_KEY env var as bytes. user_id: The authenticated user's UUID string (used as HKDF salt). Returns: A Fernet instance ready for encrypt/decrypt operations. """ hkdf = HKDF( algorithm=hashes.SHA256(), length=32, salt=user_id.encode("utf-8"), info=b"cloud-credentials", ) raw_key: bytes = hkdf.derive(master_key) fernet_key = base64.urlsafe_b64encode(raw_key) return Fernet(fernet_key) def encrypt_credentials(master_key: bytes, user_id: str, credentials: dict) -> str: """Encrypt a credentials dict to a Fernet token string. The returned string is safe to store in the database credentials_enc column. It is opaque base64 ciphertext — no plaintext fields are present. Args: master_key: The CLOUD_CREDS_KEY env var as bytes. user_id: The authenticated user's UUID string (HKDF salt). credentials: A JSON-serialisable dict (access_token, refresh_token, etc.). Returns: A URL-safe base64 Fernet token (str). """ f = _derive_fernet_key(master_key, user_id) plaintext = json.dumps(credentials).encode("utf-8") return f.encrypt(plaintext).decode("utf-8") def decrypt_credentials(master_key: bytes, user_id: str, credentials_enc: str) -> dict: """Decrypt a Fernet token back to the original credentials dict. Args: master_key: The CLOUD_CREDS_KEY env var as bytes. user_id: The authenticated user's UUID string (HKDF salt). credentials_enc: The Fernet token string from the database. Returns: The original credentials dict. Raises: cryptography.fernet.InvalidToken: If the token is tampered with or the wrong user_id (and thus wrong key) is used. """ f = _derive_fernet_key(master_key, user_id) plaintext = f.decrypt(credentials_enc.encode("utf-8")) return json.loads(plaintext)