From 9fc820d8937203050e282b5203d12260f31e16c2 Mon Sep 17 00:00:00 2001 From: curo1305 Date: Fri, 22 May 2026 19:23:42 +0200 Subject: [PATCH] feat(02-01): implement services/auth.py full auth service layer and email_tasks.py - services/auth.py: Argon2 password hashing (pwdlib), constant-time verify (SEC-06) - JWT create/decode for access tokens and password-reset tokens (typ claim validation, T-02-01) - Refresh token lifecycle: create, rotate, revoke-all with family revocation (AUTH-07, RFC 9700) - Family revocation enqueues send_security_alert_email.delay on token reuse (T-02-02) - TOTP provisioning (pyotp) and verification with Redis replay prevention, valid_window=1 (AUTH-08) - Backup code generation (8-char hex uppercase), storage (Argon2 hashed), constant-time verify (T-02-03) - HIBP k-anonymity check via SHA-1 prefix (T-02-05), fail-open on network error (T-02-06) - Admin bootstrap: idempotent, logs WARNING if env vars missing (D-04/D-05/D-06) - services/email.py: SMTP send + dev stdout fallback (D-01/D-02) - tasks/email_tasks.py: send_reset_email and send_security_alert_email Celery tasks - celery_app.py: add email queue route for tasks.email_tasks.* - TDD tests: 17 tests covering all auth primitives and family revocation --- backend/celery_app.py | 4 +- backend/services/auth.py | 428 +++++++++++++++++++++++ backend/services/email.py | 115 ++++++ backend/tasks/email_tasks.py | 73 ++++ backend/tests/test_task2_auth_service.py | 226 ++++++++++++ 5 files changed, 845 insertions(+), 1 deletion(-) create mode 100644 backend/services/auth.py create mode 100644 backend/services/email.py create mode 100644 backend/tasks/email_tasks.py create mode 100644 backend/tests/test_task2_auth_service.py diff --git a/backend/celery_app.py b/backend/celery_app.py index 14c5037..59bea75 100644 --- a/backend/celery_app.py +++ b/backend/celery_app.py @@ -26,9 +26,11 @@ celery_app.conf.task_serializer = "json" celery_app.conf.result_serializer = "json" celery_app.conf.accept_content = ["json"] -# Route document tasks to the dedicated `documents` queue +# Route document tasks to the dedicated `documents` queue; +# email tasks to the `email` queue (Phase 2 — D-03) celery_app.conf.task_routes = { "tasks.document_tasks.*": {"queue": "documents"}, + "tasks.email_tasks.*": {"queue": "email"}, } # Autodiscover tasks under the `tasks/` package diff --git a/backend/services/auth.py b/backend/services/auth.py new file mode 100644 index 0000000..9e4de30 --- /dev/null +++ b/backend/services/auth.py @@ -0,0 +1,428 @@ +""" +Auth service — pure Python, no FastAPI coupling. + +Handles: + - Password hashing (Argon2 via pwdlib) and constant-time verification (SEC-06) + - JWT access token creation/decode (PyJWT) + - Refresh token lifecycle with family revocation on reuse (AUTH-07, RFC 9700) + - TOTP provisioning and verification with replay prevention (AUTH-08) + - Backup code generation, storage, and constant-time verification (AUTH-02) + - HaveIBeenPwned k-anonymity check (SEC-03) + - Admin account bootstrap (D-04, D-05, D-06) + +Security invariants: + - All token/code comparisons use hmac.compare_digest (constant-time, SEC-06) + - No function raises HTTPException — callers (api/) map ValueError to HTTP errors + - refresh token family revocation enqueues send_security_alert_email.delay (AUTH-07) +""" +from __future__ import annotations + +import hashlib +import hmac +import logging +import secrets +import uuid +from datetime import datetime, timezone, timedelta +from typing import Optional + +import httpx +import jwt +import pyotp +from pwdlib import PasswordHash +from pwdlib.hashers.argon2 import Argon2Hasher +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from config import settings +from db.models import BackupCode, Quota, RefreshToken, User + +logger = logging.getLogger(__name__) + +# ── Password hashing ──────────────────────────────────────────────────────────── +# Single shared PasswordHash instance; Argon2 is the only enabled hasher. +_pwd = PasswordHash([Argon2Hasher()]) + + +def hash_password(plain: str) -> str: + """Return the Argon2 hash of *plain*.""" + return _pwd.hash(plain) + + +def verify_password(plain: str, hashed: str) -> bool: + """Return True if *plain* matches *hashed* (constant-time, SEC-06). + + pwdlib.verify returns True/False and internally uses constant-time comparison. + """ + try: + return _pwd.verify(plain, hashed) + except Exception: + return False + + +# ── JWT helpers ───────────────────────────────────────────────────────────────── + +def create_access_token(user_id: str, role: str) -> str: + """Return a signed JWT access token. + + Claims: sub=user_id, role=role, typ='access', exp=now+access_token_expire_minutes. + """ + now = datetime.now(timezone.utc) + payload = { + "sub": str(user_id), + "role": role, + "typ": "access", + "iat": now, + "exp": now + timedelta(minutes=settings.access_token_expire_minutes), + } + return jwt.encode(payload, settings.secret_key, algorithm="HS256") + + +def decode_access_token(token: str) -> dict: + """Decode and validate an access token; raises ValueError on any failure. + + Verifies: signature, expiry, and typ='access' (T-02-01 — prevents password-reset + tokens from being used as access tokens). + """ + try: + payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"]) + except jwt.ExpiredSignatureError as exc: + raise ValueError("Token has expired") from exc + except jwt.PyJWTError as exc: + raise ValueError(f"Invalid token: {exc}") from exc + + if payload.get("typ") != "access": + raise ValueError("Token type mismatch: expected 'access'") + return payload + + +def create_password_reset_token(user_id: str) -> str: + """Return a short-lived signed JWT for password reset. + + Claims: sub=user_id, typ='password-reset', exp=now+3600s. + """ + now = datetime.now(timezone.utc) + payload = { + "sub": str(user_id), + "typ": "password-reset", + "iat": now, + "exp": now + timedelta(seconds=3600), + } + return jwt.encode(payload, settings.secret_key, algorithm="HS256") + + +def decode_password_reset_token(token: str) -> str: + """Decode a password-reset token; raises ValueError on failure. + + Returns the user_id string. + """ + try: + payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"]) + except jwt.ExpiredSignatureError as exc: + raise ValueError("Reset token has expired") from exc + except jwt.PyJWTError as exc: + raise ValueError(f"Invalid reset token: {exc}") from exc + + if payload.get("typ") != "password-reset": + raise ValueError("Token type mismatch: expected 'password-reset'") + return payload["sub"] + + +# ── Refresh token lifecycle ───────────────────────────────────────────────────── + +async def create_refresh_token(session: AsyncSession, user_id: uuid.UUID) -> str: + """Insert a new RefreshToken row and return the raw (unhashed) token string. + + The raw token is returned to the caller and set as an httpOnly cookie. + Only the SHA-256 hash is stored in the database. + """ + raw = secrets.token_urlsafe(32) + token_hash = hashlib.sha256(raw.encode()).hexdigest() + now = datetime.now(timezone.utc) + row = RefreshToken( + id=uuid.uuid4(), + user_id=user_id, + token_hash=token_hash, + expires_at=now + timedelta(days=settings.refresh_token_expire_days), + revoked=False, + ) + session.add(row) + await session.commit() + return raw + + +async def rotate_refresh_token( + session: AsyncSession, raw_token: str +) -> tuple[str, str]: + """Rotate a refresh token: revoke the old one and issue a new one. + + Returns (new_raw_token, user_id_str). + + Family revocation (AUTH-07 / RFC 9700): + If the presented token has revoked=True, ALL tokens for that user are + revoked and send_security_alert_email.delay is enqueued. Then raises + ValueError("token_family_revoked"). + + Raises ValueError for any invalid or expired token. + """ + token_hash = hashlib.sha256(raw_token.encode()).hexdigest() + result = await session.execute( + select(RefreshToken).where(RefreshToken.token_hash == token_hash) + ) + row: Optional[RefreshToken] = result.scalar_one_or_none() + + if row is None: + raise ValueError("Refresh token not found") + + if row.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): + raise ValueError("Refresh token has expired") + + if row.revoked: + # T-02-02: reuse of revoked token — family revocation + await revoke_all_refresh_tokens(session, row.user_id) + # Enqueue security alert email (deferred import to avoid circular dependency) + from tasks.email_tasks import send_security_alert_email # noqa: PLC0415 + send_security_alert_email.delay(str(row.user_id)) + raise ValueError("token_family_revoked") + + # Valid token: revoke the old one + row.revoked = True + await session.flush() + + # Issue a new token in the same family + new_raw = await create_refresh_token(session, row.user_id) + return new_raw, str(row.user_id) + + +async def revoke_all_refresh_tokens( + session: AsyncSession, user_id: uuid.UUID +) -> int: + """Mark all active refresh tokens for user_id as revoked. + + Returns the count of revoked tokens (supports sign-out-all-devices). + """ + result = await session.execute( + select(RefreshToken).where( + RefreshToken.user_id == user_id, + RefreshToken.revoked.is_(False), + ) + ) + rows = result.scalars().all() + count = 0 + for row in rows: + row.revoked = True + count += 1 + await session.flush() + return count + + +# ── TOTP provisioning ─────────────────────────────────────────────────────────── + +async def provision_totp( + session: AsyncSession, user_id: uuid.UUID +) -> tuple[str, str]: + """Generate a new TOTP secret for the user and store it (not yet enabled). + + Returns (secret, provisioning_uri). + The secret is base32-encoded. provisioning_uri is suitable for QR code generation. + """ + user = await session.get(User, user_id) + if user is None: + raise ValueError("User not found") + + secret = pyotp.random_base32() + user.totp_secret = secret + await session.commit() + + totp = pyotp.totp.TOTP(secret) + uri = totp.provisioning_uri(user.email, issuer_name="DocuVault") + return secret, uri + + +async def verify_totp( + session: AsyncSession, + user_id: uuid.UUID, + code: str, + redis_client, +) -> bool: + """Verify a TOTP code with replay prevention (AUTH-08). + + valid_window=1 allows ±30s clock drift (per STATE.md recommendation). + Replay prevention: stores used codes in Redis with key 'totp_used:{user_id}:{code}' + and TTL=90s (covers the full ±30s validity window). + + Returns False if the code is invalid, already used, or user has no TOTP secret. + """ + user = await session.get(User, user_id) + if user is None or not user.totp_secret: + return False + + totp = pyotp.TOTP(user.totp_secret) + + # Check replay prevention before verifying + replay_key = f"totp_used:{user_id}:{code}" + if await redis_client.get(replay_key): + return False # Code already used within the validity window + + if not totp.verify(code, valid_window=1): + return False + + # Mark as used for 90s (covers valid_window=1: ±30s = 90s total) + await redis_client.set(replay_key, "1", ex=90) + return True + + +# ── Backup codes ──────────────────────────────────────────────────────────────── + +def generate_backup_codes(n: int = 10) -> list[str]: + """Return *n* random 8-character uppercase alphanumeric backup codes.""" + codes = [] + for _ in range(n): + # secrets.token_hex(4) returns 8 hex chars; uppercase for readability + code = secrets.token_hex(4).upper() + codes.append(code) + return codes + + +async def store_backup_codes( + session: AsyncSession, user_id: uuid.UUID, codes: list[str] +) -> None: + """Store backup codes for a user, replacing any existing unused codes. + + Each code is stored as an Argon2 hash (never plaintext, T-02-03). + Existing unused codes are deleted first to prevent accumulation. + """ + # Delete existing unused codes + result = await session.execute( + select(BackupCode).where( + BackupCode.user_id == user_id, + BackupCode.used_at.is_(None), + ) + ) + for row in result.scalars().all(): + await session.delete(row) + await session.flush() + + # Insert new hashed codes + for code in codes: + row = BackupCode( + id=uuid.uuid4(), + user_id=user_id, + code_hash=hash_password(code), + used_at=None, + ) + session.add(row) + + await session.commit() + + +async def verify_backup_code( + session: AsyncSession, user_id: uuid.UUID, code: str +) -> bool: + """Verify a backup code using constant-time comparison. + + Always iterates ALL unused codes (prevents timing-based enumeration, SEC-06). + On match: sets used_at=now() and commits. Returns True on first match. + Returns False if no code matches. + """ + result = await session.execute( + select(BackupCode).where( + BackupCode.user_id == user_id, + BackupCode.used_at.is_(None), + ) + ) + rows = result.scalars().all() + + matched_row: Optional[BackupCode] = None + for row in rows: + # Always call verify_password for ALL rows (constant-time: no early exit) + if verify_password(code, row.code_hash): + matched_row = row # record match but keep iterating + + if matched_row is None: + return False + + # Mark the matched code as used + matched_row.used_at = datetime.now(timezone.utc) + await session.commit() + return True + + +# ── HaveIBeenPwned check ──────────────────────────────────────────────────────── + +async def check_hibp(password: str) -> bool: + """Check if password appears in HaveIBeenPwned using the k-anonymity model. + + Sends only the first 5 chars of the SHA-1 hash to the HIBP API (T-02-05). + Returns True if the password has been breached, False otherwise. + On network error: logs a warning and returns False (fail-open, T-02-06). + """ + sha1 = hashlib.sha1(password.encode("utf-8")).hexdigest().upper() + prefix, suffix = sha1[:5], sha1[5:] + + try: + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.get( + f"https://api.pwnedpasswords.com/range/{prefix}", + headers={"Add-Padding": "true"}, + ) + resp.raise_for_status() + except Exception as exc: + logger.warning("HIBP check failed (fail-open): %s", exc) + return False + + for line in resp.text.splitlines(): + parts = line.split(":") + if len(parts) == 2: + candidate_suffix, count_str = parts + if hmac.compare_digest(candidate_suffix.upper(), suffix): + try: + count = int(count_str.strip()) + except ValueError: + count = 1 + if count > 0: + return True + return False + + +# ── Admin bootstrap ───────────────────────────────────────────────────────────── + +async def bootstrap_admin(session: AsyncSession) -> None: + """Idempotent admin account bootstrap (D-04, D-05, D-06). + + If the users table is empty AND settings.admin_email and settings.admin_password + are both non-empty, creates an admin User row with a Quota row. + + Logs a WARNING if env vars are missing (D-05) but never raises. + """ + if not settings.admin_email or not settings.admin_password: + logger.warning( + "Admin bootstrap skipped: ADMIN_EMAIL and/or ADMIN_PASSWORD not set (D-05). " + "Set both env vars to seed the first admin account on startup." + ) + return + + # Check if any users exist + result = await session.execute(select(User).limit(1)) + if result.scalar_one_or_none() is not None: + # Users already exist — idempotent, skip (D-04) + return + + admin_id = uuid.uuid4() + admin_user = User( + id=admin_id, + handle="admin", + email=settings.admin_email, + password_hash=hash_password(settings.admin_password), + role="admin", + is_active=True, + password_must_change=False, + ) + quota = Quota( + user_id=admin_id, + limit_bytes=104857600, # 100 MB default (D-06) + used_bytes=0, + ) + session.add(admin_user) + session.add(quota) + await session.commit() + logger.info("Admin account bootstrapped for %s", settings.admin_email) diff --git a/backend/services/email.py b/backend/services/email.py new file mode 100644 index 0000000..579932a --- /dev/null +++ b/backend/services/email.py @@ -0,0 +1,115 @@ +""" +Email service — pure Python, no FastAPI coupling. + +Sends transactional emails via SMTP when SMTP_HOST is configured; +logs the content to stdout otherwise (D-02 dev fallback). + +Security notes: + - Never raises: email failures are non-fatal; log and return + - Celery task wrapper handles retry/error reporting +""" +import logging +import smtplib +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart + +logger = logging.getLogger(__name__) + + +def send_password_reset_email(to_address: str, reset_link: str) -> None: + """Send (or log) the password reset email. + + When SMTP_HOST is not configured (dev / CI), logs the reset link to stdout + per D-02. The API response is 202 regardless — no token in the body. + + Never raises — failures are logged and the function returns normally. + """ + from config import settings # deferred to avoid module-level side effects + + if not settings.smtp_host: + # D-02: dev fallback — log token link to stdout + logger.info("DEV MODE — password reset link for %s: %s", to_address, reset_link) + return + + try: + msg = MIMEMultipart("alternative") + msg["Subject"] = "DocuVault — password reset" + msg["From"] = settings.smtp_from + msg["To"] = to_address + + text_body = ( + f"You requested a password reset for DocuVault.\n\n" + f"Click the link below (valid for 1 hour):\n{reset_link}\n\n" + "If you did not request this, ignore this email." + ) + html_body = ( + f"

You requested a password reset for DocuVault.

" + f"

Reset your password (valid 1 hour)

" + f"

If you did not request this, ignore this email.

" + ) + msg.attach(MIMEText(text_body, "plain")) + msg.attach(MIMEText(html_body, "html")) + + with smtplib.SMTP(settings.smtp_host, settings.smtp_port) as server: + server.ehlo() + server.starttls() + if settings.smtp_user: + server.login(settings.smtp_user, settings.smtp_password) + server.sendmail(settings.smtp_from, [to_address], msg.as_string()) + + logger.info("Password reset email sent to %s", to_address) + except Exception as exc: + logger.error("Failed to send password reset email to %s: %s", to_address, exc) + + +def send_security_alert_email_sync(to_address: str, user_id: str) -> None: + """Send (or log) a security alert email about suspicious refresh token reuse. + + Called by the email_tasks.py Celery task after looking up the user email. + Never raises — failures are logged. + """ + from config import settings # deferred import + + if not settings.smtp_host: + logger.warning( + "Security alert for user %s: suspicious refresh token reuse detected " + "(SMTP not configured — email not sent to %s)", + user_id, + to_address, + ) + return + + try: + msg = MIMEMultipart("alternative") + msg["Subject"] = "DocuVault — security alert: suspicious login detected" + msg["From"] = settings.smtp_from + msg["To"] = to_address + + text_body = ( + "DocuVault detected suspicious activity on your account.\n\n" + "A previously revoked refresh token was used to attempt a session refresh. " + "All active sessions have been revoked as a precaution.\n\n" + "If this was not you, please change your password immediately." + ) + html_body = ( + "

DocuVault security alert

" + "

A previously revoked refresh token was used to attempt a session refresh. " + "All active sessions have been revoked as a precaution.

" + "

If this was not you, please change your password immediately.

" + ) + msg.attach(MIMEText(text_body, "plain")) + msg.attach(MIMEText(html_body, "html")) + + with smtplib.SMTP(settings.smtp_host, settings.smtp_port) as server: + server.ehlo() + server.starttls() + if settings.smtp_user: + server.login(settings.smtp_user, settings.smtp_password) + server.sendmail(settings.smtp_from, [to_address], msg.as_string()) + + logger.info("Security alert email sent to %s (user_id=%s)", to_address, user_id) + except Exception as exc: + logger.error( + "Failed to send security alert email to %s (user_id=%s): %s", + to_address, user_id, exc + ) diff --git a/backend/tasks/email_tasks.py b/backend/tasks/email_tasks.py new file mode 100644 index 0000000..e1e37c5 --- /dev/null +++ b/backend/tasks/email_tasks.py @@ -0,0 +1,73 @@ +""" +Celery tasks for email dispatch in DocuVault. + +Tasks follow the same pattern as document_tasks.py: + - Plain sync def (Celery workers have no asyncio event loop by default) + - Async body via asyncio.run() + - All imports deferred inside _run_* functions to avoid circular imports + (see celery_app.py comment — do NOT import config at module level) + +Tasks registered here: + send_reset_email — dispatches a password-reset email to the user + send_security_alert_email — dispatches a security alert on refresh token reuse (AUTH-07) +""" +import asyncio + +from celery_app import celery_app + + +@celery_app.task(name="tasks.email_tasks.send_reset_email") +def send_reset_email(to_address: str, reset_link: str) -> dict: + """Synchronous Celery entry-point — send a password reset email. + + Called as: send_reset_email.delay(to_address, reset_link) + Delegates to the async body via asyncio.run(). + """ + return asyncio.run(_run_send_reset(to_address, reset_link)) + + +async def _run_send_reset(to_address: str, reset_link: str) -> dict: + """Async body of send_reset_email. Deferred imports to avoid circular deps.""" + from services.email import send_password_reset_email # noqa: PLC0415 + try: + send_password_reset_email(to_address, reset_link) + return {"status": "sent", "to": to_address} + except Exception as exc: + return {"status": "failed", "error": str(exc)} + + +@celery_app.task(name="tasks.email_tasks.send_security_alert_email") +def send_security_alert_email(user_id: str) -> dict: + """Synchronous Celery entry-point — send a security alert on token reuse. + + Called as: send_security_alert_email.delay(user_id) + Fetches the user's email from the DB inside the task using asyncio.run(). + On SMTP not configured: logs a WARNING per D-02 convention. + """ + return asyncio.run(_run_send_security_alert(user_id)) + + +async def _run_send_security_alert(user_id: str) -> dict: + """Async body of send_security_alert_email. Deferred imports to avoid circular deps.""" + import uuid as _uuid # noqa: PLC0415 + + from db.session import AsyncSessionLocal # noqa: PLC0415 + from db.models import User # noqa: PLC0415 + from services.email import send_security_alert_email_sync # noqa: PLC0415 + + try: + user_uuid = _uuid.UUID(user_id) + except ValueError: + return {"status": "failed", "error": f"Invalid user_id: {user_id}"} + + try: + async with AsyncSessionLocal() as session: + user = await session.get(User, user_uuid) + if user is None: + return {"status": "failed", "error": f"User {user_id} not found"} + to_address = user.email + + send_security_alert_email_sync(to_address, user_id) + return {"status": "sent", "user_id": user_id} + except Exception as exc: + return {"status": "failed", "error": str(exc)} diff --git a/backend/tests/test_task2_auth_service.py b/backend/tests/test_task2_auth_service.py new file mode 100644 index 0000000..dd68c4f --- /dev/null +++ b/backend/tests/test_task2_auth_service.py @@ -0,0 +1,226 @@ +""" +TDD tests for Task 2: services/auth.py and tasks/email_tasks.py. + +These tests should FAIL before implementation (RED phase). +""" +import pytest +import uuid + + +def test_hash_password_returns_argon2_hash(): + from services.auth import hash_password + h = hash_password("TestPass123!") + assert h.startswith("$argon2") + + +def test_verify_password_correct(): + from services.auth import hash_password, verify_password + h = hash_password("TestPass123!") + assert verify_password("TestPass123!", h) is True + + +def test_verify_password_wrong(): + from services.auth import hash_password, verify_password + h = hash_password("TestPass123!") + assert verify_password("WrongPass!", h) is False + + +def test_create_access_token_jwt_format(): + from services.auth import create_access_token + t = create_access_token("test-uid", "user") + assert isinstance(t, str) + assert t.count(".") == 2 # JWT has 3 parts separated by 2 dots + + +def test_decode_access_token_valid(): + from services.auth import create_access_token, decode_access_token + t = create_access_token("test-uid", "user") + payload = decode_access_token(t) + assert payload["sub"] == "test-uid" + assert payload["role"] == "user" + assert payload["typ"] == "access" + + +def test_decode_access_token_tampered_raises(): + from services.auth import create_access_token, decode_access_token + t = create_access_token("test-uid", "user") + tampered = t[:-5] + "XXXXX" + with pytest.raises(ValueError): + decode_access_token(tampered) + + +def test_decode_access_token_wrong_typ_raises(): + """An access token must have typ='access'; other typ values should raise ValueError.""" + from services.auth import create_password_reset_token, decode_access_token + reset_token = create_password_reset_token("test-uid") + with pytest.raises(ValueError): + decode_access_token(reset_token) + + +def test_generate_backup_codes(): + from services.auth import generate_backup_codes + codes = generate_backup_codes(10) + assert len(codes) == 10 + for code in codes: + assert isinstance(code, str) + assert len(code) == 8 + + +def test_generate_backup_codes_alphanumeric(): + from services.auth import generate_backup_codes + codes = generate_backup_codes(10) + for code in codes: + assert code.isalnum(), f"Code '{code}' contains non-alphanumeric chars" + + +def test_create_password_reset_token(): + from services.auth import create_password_reset_token, decode_password_reset_token + token = create_password_reset_token("test-uid") + assert token.count(".") == 2 + uid = decode_password_reset_token(token) + assert uid == "test-uid" + + +def test_decode_password_reset_token_wrong_typ_raises(): + """An access token must not be accepted as a password reset token.""" + from services.auth import create_access_token, decode_password_reset_token + access_token = create_access_token("test-uid", "user") + with pytest.raises(ValueError): + decode_password_reset_token(access_token) + + +def test_no_fastapi_imports_in_auth_service(): + """services/auth.py must not import FastAPI or raise HTTPException.""" + import os + import re + path = os.path.join(os.path.dirname(__file__), "..", "services", "auth.py") + with open(path) as f: + source = f.read() + # Check for actual imports (not docstring mentions) + assert not re.search(r"^(?:from|import)\s+fastapi", source, re.MULTILINE), \ + "services/auth.py must not import from fastapi" + # Check no raise HTTPException in actual code (not docstrings) + assert not re.search(r"raise\s+HTTPException", source), \ + "services/auth.py must not raise HTTPException" + + +def test_security_alert_email_referenced_in_auth(): + """rotate_refresh_token must reference send_security_alert_email (AUTH-07).""" + import os + path = os.path.join(os.path.dirname(__file__), "..", "services", "auth.py") + with open(path) as f: + source = f.read() + assert "send_security_alert_email" in source + + +def test_email_tasks_importable(): + """email_tasks.py must define both tasks (may skip if celery not installed locally).""" + try: + from tasks.email_tasks import send_reset_email, send_security_alert_email + assert callable(send_reset_email) + assert callable(send_security_alert_email) + except ModuleNotFoundError as e: + if "celery" in str(e): + pytest.skip("celery not installed in local test environment (runs in Docker)") + raise + + +@pytest.mark.asyncio +async def test_create_and_rotate_refresh_token(db_session): + """create_refresh_token creates a DB row; rotate_refresh_token returns new token.""" + import uuid as _uuid + from db.models import User, Quota + from services.auth import ( + create_refresh_token, rotate_refresh_token, hash_password + ) + + # Create a user + user = User( + id=_uuid.uuid4(), + handle="testuser", + email="test@example.com", + password_hash=hash_password("pass"), + role="user", + ) + quota = Quota(user_id=user.id, limit_bytes=104857600, used_bytes=0) + db_session.add(user) + db_session.add(quota) + await db_session.commit() + + # Create a refresh token + raw = await create_refresh_token(db_session, user.id) + assert isinstance(raw, str) + assert len(raw) > 10 + + # Rotate the token + new_raw, user_id_str = await rotate_refresh_token(db_session, raw) + assert isinstance(new_raw, str) + assert user_id_str == str(user.id) + assert new_raw != raw + + +@pytest.mark.asyncio +async def test_rotate_revoked_token_raises(db_session): + """Rotating a revoked token should raise ValueError('token_family_revoked').""" + import uuid as _uuid + from unittest.mock import patch, MagicMock + from db.models import User, Quota + from services.auth import ( + create_refresh_token, rotate_refresh_token, hash_password + ) + + user = User( + id=_uuid.uuid4(), + handle="testuser2", + email="test2@example.com", + password_hash=hash_password("pass"), + role="user", + ) + quota = Quota(user_id=user.id, limit_bytes=104857600, used_bytes=0) + db_session.add(user) + db_session.add(quota) + await db_session.commit() + + raw = await create_refresh_token(db_session, user.id) + # First rotate (valid) + new_raw, _ = await rotate_refresh_token(db_session, raw) + + # Now try to re-use the original (now revoked) token. + # Mock the celery task at the point it is imported inside rotate_refresh_token. + mock_task = MagicMock() + mock_task.delay = MagicMock() + with patch.dict("sys.modules", {"tasks.email_tasks": MagicMock(send_security_alert_email=mock_task)}): + with pytest.raises(ValueError, match="token_family_revoked"): + await rotate_refresh_token(db_session, raw) + + +@pytest.mark.asyncio +async def test_store_and_verify_backup_codes(db_session): + """store_backup_codes inserts rows; verify_backup_code matches correct code.""" + import uuid as _uuid + from db.models import User, Quota + from services.auth import ( + generate_backup_codes, store_backup_codes, verify_backup_code, hash_password + ) + + user = User( + id=_uuid.uuid4(), + handle="backupuser", + email="backup@example.com", + password_hash=hash_password("pass"), + role="user", + ) + quota = Quota(user_id=user.id, limit_bytes=104857600, used_bytes=0) + db_session.add(user) + db_session.add(quota) + await db_session.commit() + + codes = generate_backup_codes(10) + await store_backup_codes(db_session, user.id, codes) + + # Verify with correct code + assert await verify_backup_code(db_session, user.id, codes[0]) is True + # Verify with already-used code should return False + assert await verify_backup_code(db_session, user.id, codes[0]) is False + # Verify with wrong code + assert await verify_backup_code(db_session, user.id, "XXXXXXXX") is False