""" Auth service — pure Python, no FastAPI coupling. Handles: - Password hashing (Argon2 via pwdlib) and constant-time verification (SEC-06) - JWT access token creation/decode (PyJWT) - Refresh token lifecycle with family revocation on reuse (AUTH-07, RFC 9700) - TOTP provisioning and verification with replay prevention (AUTH-08) - Backup code generation, storage, and constant-time verification (AUTH-02) - HaveIBeenPwned k-anonymity check (SEC-03) - Admin account bootstrap (D-04, D-05, D-06) Security invariants: - All token/code comparisons use hmac.compare_digest (constant-time, SEC-06) - No function raises HTTPException — callers (api/) map ValueError to HTTP errors - refresh token family revocation enqueues send_security_alert_email.delay (AUTH-07) """ from __future__ import annotations import hashlib import hmac import logging import secrets import uuid from datetime import datetime, timezone, timedelta from typing import Optional import httpx import jwt import pyotp from pwdlib import PasswordHash from pwdlib.hashers.argon2 import Argon2Hasher from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from config import settings from db.models import BackupCode, Quota, RefreshToken, User logger = logging.getLogger(__name__) # ── Password hashing ──────────────────────────────────────────────────────────── # Single shared PasswordHash instance; Argon2 is the only enabled hasher. _pwd = PasswordHash([Argon2Hasher()]) def hash_password(plain: str) -> str: """Return the Argon2 hash of *plain*.""" return _pwd.hash(plain) def verify_password(plain: str, hashed: str) -> bool: """Return True if *plain* matches *hashed* (constant-time, SEC-06). pwdlib.verify returns True/False and internally uses constant-time comparison. """ try: return _pwd.verify(plain, hashed) except Exception: return False # ── JWT helpers ───────────────────────────────────────────────────────────────── def create_access_token(user_id: str, role: str) -> str: """Return a signed JWT access token. Claims: sub=user_id, role=role, typ='access', exp=now+access_token_expire_minutes. """ now = datetime.now(timezone.utc) payload = { "sub": str(user_id), "role": role, "typ": "access", "iat": now, "exp": now + timedelta(minutes=settings.access_token_expire_minutes), } return jwt.encode(payload, settings.secret_key, algorithm="HS256") def decode_access_token(token: str) -> dict: """Decode and validate an access token; raises ValueError on any failure. Verifies: signature, expiry, and typ='access' (T-02-01 — prevents password-reset tokens from being used as access tokens). """ try: payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"]) except jwt.ExpiredSignatureError as exc: raise ValueError("Token has expired") from exc except jwt.PyJWTError as exc: raise ValueError(f"Invalid token: {exc}") from exc if payload.get("typ") != "access": raise ValueError("Token type mismatch: expected 'access'") return payload def create_password_reset_token(user_id: str) -> str: """Return a short-lived signed JWT for password reset. Claims: sub=user_id, typ='password-reset', exp=now+3600s. """ now = datetime.now(timezone.utc) payload = { "sub": str(user_id), "typ": "password-reset", "iat": now, "exp": now + timedelta(seconds=3600), } return jwt.encode(payload, settings.secret_key, algorithm="HS256") def decode_password_reset_token(token: str) -> str: """Decode a password-reset token; raises ValueError on failure. Returns the user_id string. """ try: payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"]) except jwt.ExpiredSignatureError as exc: raise ValueError("Reset token has expired") from exc except jwt.PyJWTError as exc: raise ValueError(f"Invalid reset token: {exc}") from exc if payload.get("typ") != "password-reset": raise ValueError("Token type mismatch: expected 'password-reset'") return payload["sub"] # ── Refresh token lifecycle ───────────────────────────────────────────────────── async def create_refresh_token(session: AsyncSession, user_id: uuid.UUID) -> str: """Insert a new RefreshToken row and return the raw (unhashed) token string. The raw token is returned to the caller and set as an httpOnly cookie. Only the SHA-256 hash is stored in the database. """ raw = secrets.token_urlsafe(32) token_hash = hashlib.sha256(raw.encode()).hexdigest() now = datetime.now(timezone.utc) row = RefreshToken( id=uuid.uuid4(), user_id=user_id, token_hash=token_hash, expires_at=now + timedelta(days=settings.refresh_token_expire_days), revoked=False, ) session.add(row) await session.commit() return raw async def rotate_refresh_token( session: AsyncSession, raw_token: str ) -> tuple[str, str]: """Rotate a refresh token: revoke the old one and issue a new one. Returns (new_raw_token, user_id_str). Family revocation (AUTH-07 / RFC 9700): If the presented token has revoked=True, ALL tokens for that user are revoked and send_security_alert_email.delay is enqueued. Then raises ValueError("token_family_revoked"). Raises ValueError for any invalid or expired token. """ token_hash = hashlib.sha256(raw_token.encode()).hexdigest() result = await session.execute( select(RefreshToken).where(RefreshToken.token_hash == token_hash) ) row: Optional[RefreshToken] = result.scalar_one_or_none() if row is None: raise ValueError("Refresh token not found") if row.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): raise ValueError("Refresh token has expired") if row.revoked: # T-02-02: reuse of revoked token — family revocation await revoke_all_refresh_tokens(session, row.user_id) # Enqueue security alert email (deferred import to avoid circular dependency) from tasks.email_tasks import send_security_alert_email # noqa: PLC0415 send_security_alert_email.delay(str(row.user_id)) raise ValueError("token_family_revoked") # Valid token: revoke the old one row.revoked = True await session.flush() # Issue a new token in the same family new_raw = await create_refresh_token(session, row.user_id) return new_raw, str(row.user_id) async def revoke_all_refresh_tokens( session: AsyncSession, user_id: uuid.UUID ) -> int: """Mark all active refresh tokens for user_id as revoked. Returns the count of revoked tokens (supports sign-out-all-devices). """ result = await session.execute( select(RefreshToken).where( RefreshToken.user_id == user_id, RefreshToken.revoked.is_(False), ) ) rows = result.scalars().all() count = 0 for row in rows: row.revoked = True count += 1 await session.flush() return count # ── TOTP provisioning ─────────────────────────────────────────────────────────── async def provision_totp( session: AsyncSession, user_id: uuid.UUID ) -> tuple[str, str]: """Generate a new TOTP secret for the user and store it (not yet enabled). Returns (secret, provisioning_uri). The secret is base32-encoded. provisioning_uri is suitable for QR code generation. """ user = await session.get(User, user_id) if user is None: raise ValueError("User not found") secret = pyotp.random_base32() user.totp_secret = secret await session.commit() totp = pyotp.totp.TOTP(secret) uri = totp.provisioning_uri(user.email, issuer_name="DocuVault") return secret, uri async def verify_totp( session: AsyncSession, user_id: uuid.UUID, code: str, redis_client, ) -> bool: """Verify a TOTP code with replay prevention (AUTH-08). valid_window=1 allows ±30s clock drift (per STATE.md recommendation). Replay prevention: stores used codes in Redis with key 'totp_used:{user_id}:{code}' and TTL=90s (covers the full ±30s validity window). Returns False if the code is invalid, already used, or user has no TOTP secret. """ user = await session.get(User, user_id) if user is None or not user.totp_secret: return False totp = pyotp.TOTP(user.totp_secret) # Check replay prevention before verifying replay_key = f"totp_used:{user_id}:{code}" if await redis_client.get(replay_key): return False # Code already used within the validity window if not totp.verify(code, valid_window=1): return False # Mark as used for 90s (covers valid_window=1: ±30s = 90s total) await redis_client.set(replay_key, "1", ex=90) return True # ── Backup codes ──────────────────────────────────────────────────────────────── def generate_backup_codes(n: int = 10) -> list[str]: """Return *n* random 8-character uppercase alphanumeric backup codes.""" codes = [] for _ in range(n): # secrets.token_hex(4) returns 8 hex chars; uppercase for readability code = secrets.token_hex(4).upper() codes.append(code) return codes async def store_backup_codes( session: AsyncSession, user_id: uuid.UUID, codes: list[str] ) -> None: """Store backup codes for a user, replacing any existing unused codes. Each code is stored as an Argon2 hash (never plaintext, T-02-03). Existing unused codes are deleted first to prevent accumulation. """ # Delete existing unused codes result = await session.execute( select(BackupCode).where( BackupCode.user_id == user_id, BackupCode.used_at.is_(None), ) ) for row in result.scalars().all(): await session.delete(row) await session.flush() # Insert new hashed codes for code in codes: row = BackupCode( id=uuid.uuid4(), user_id=user_id, code_hash=hash_password(code), used_at=None, ) session.add(row) await session.commit() async def verify_backup_code( session: AsyncSession, user_id: uuid.UUID, code: str ) -> bool: """Verify a backup code using constant-time comparison. Always iterates ALL unused codes (prevents timing-based enumeration, SEC-06). On match: sets used_at=now() and commits. Returns True on first match. Returns False if no code matches. """ result = await session.execute( select(BackupCode).where( BackupCode.user_id == user_id, BackupCode.used_at.is_(None), ) ) rows = result.scalars().all() matched_row: Optional[BackupCode] = None for row in rows: # Always call verify_password for ALL rows (constant-time: no early exit) if verify_password(code, row.code_hash): matched_row = row # record match but keep iterating if matched_row is None: return False # Mark the matched code as used matched_row.used_at = datetime.now(timezone.utc) await session.commit() return True # ── HaveIBeenPwned check ──────────────────────────────────────────────────────── async def check_hibp(password: str) -> bool: """Check if password appears in HaveIBeenPwned using the k-anonymity model. Sends only the first 5 chars of the SHA-1 hash to the HIBP API (T-02-05). Returns True if the password has been breached, False otherwise. On network error: logs a warning and returns False (fail-open, T-02-06). """ sha1 = hashlib.sha1(password.encode("utf-8")).hexdigest().upper() prefix, suffix = sha1[:5], sha1[5:] try: async with httpx.AsyncClient(timeout=5.0) as client: resp = await client.get( f"https://api.pwnedpasswords.com/range/{prefix}", headers={"Add-Padding": "true"}, ) resp.raise_for_status() except Exception as exc: logger.warning("HIBP check failed (fail-open): %s", exc) return False for line in resp.text.splitlines(): parts = line.split(":") if len(parts) == 2: candidate_suffix, count_str = parts if hmac.compare_digest(candidate_suffix.upper(), suffix): try: count = int(count_str.strip()) except ValueError: count = 1 if count > 0: return True return False # ── Admin bootstrap ───────────────────────────────────────────────────────────── async def bootstrap_admin(session: AsyncSession) -> None: """Idempotent admin account bootstrap (D-04, D-05, D-06). If the users table is empty AND settings.admin_email and settings.admin_password are both non-empty, creates an admin User row with a Quota row. Logs a WARNING if env vars are missing (D-05) but never raises. """ if not settings.admin_email or not settings.admin_password: logger.warning( "Admin bootstrap skipped: ADMIN_EMAIL and/or ADMIN_PASSWORD not set (D-05). " "Set both env vars to seed the first admin account on startup." ) return # Check if any users exist result = await session.execute(select(User).limit(1)) if result.scalar_one_or_none() is not None: # Users already exist — idempotent, skip (D-04) return admin_id = uuid.uuid4() admin_user = User( id=admin_id, handle="admin", email=settings.admin_email, password_hash=hash_password(settings.admin_password), role="admin", is_active=True, password_must_change=False, ) quota = Quota( user_id=admin_id, limit_bytes=104857600, # 100 MB default (D-06) used_bytes=0, ) session.add(admin_user) await session.flush() # persist User first so Quota FK is satisfied session.add(quota) await session.commit() logger.info("Admin account bootstrapped for %s", settings.admin_email)