""" Tests for backend/api/auth.py (Plan 02-02, Task 1). Covers: register, login (TOTP, backup codes, password_must_change), refresh, logout, me, change-password, per-account rate limiting, Origin validation, and CSP headers. Uses the async_client fixture from conftest.py. Note on mock Redis: app.state.redis is set to a FakeRedis-like dict-based mock so that per-account rate limiting can be exercised in isolation. """ from __future__ import annotations import hashlib import uuid from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest import pytest_asyncio from httpx import ASGITransport, AsyncClient from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from db.models import BackupCode, Quota, User # ── Helpers ────────────────────────────────────────────────────────────────── async def _register(async_client, handle="testuser", email="t@example.com", password="ValidPass12!"): """Register a user and return the response.""" return await async_client.post( "/api/auth/register", json={"handle": handle, "email": email, "password": password}, ) async def _login(async_client, email="t@example.com", password="ValidPass12!", **extra): body = {"email": email, "password": password} body.update(extra) return await async_client.post("/api/auth/login", json=body) # ── Fake Redis for per-account rate limiting ────────────────────────────────── class FakeRedis: """In-memory fake Redis for testing rate limits and TOTP replay prevention.""" def __init__(self): self._store: dict = {} async def get(self, key): entry = self._store.get(key) if entry is None: return None val, exp = entry if exp is not None and datetime.now(timezone.utc).timestamp() > exp: del self._store[key] return None return val async def incr(self, key): entry = self._store.get(key) if entry is None: self._store[key] = (1, None) return 1 val, exp = entry new_val = val + 1 self._store[key] = (new_val, exp) return new_val async def expire(self, key, seconds): if key in self._store: val, _ = self._store[key] deadline = datetime.now(timezone.utc).timestamp() + seconds self._store[key] = (val, deadline) async def set(self, key, value, ex=None): deadline = None if ex is not None: deadline = datetime.now(timezone.utc).timestamp() + ex self._store[key] = (value, deadline) async def close(self): pass @pytest_asyncio.fixture async def authed_client(db_session: AsyncSession): """Async HTTP test client with DB override AND fake Redis on app.state.redis. Each test gets a fresh FakeRedis instance (reset state). The slowapi in-memory rate limiter is also reset between tests so that IP-level counts from one test don't bleed into the next. """ from deps.db import get_db from main import app from api.auth import limiter as auth_limiter app.dependency_overrides[get_db] = lambda: db_session fake_redis = FakeRedis() app.state.redis = fake_redis # Reset slowapi's in-memory storage so previous test IP counters don't # interfere with the current test's rate limit assertions. try: auth_limiter._storage.reset() except Exception: try: # Fallback: clear internal storage dict directly storage_obj = auth_limiter._storage if hasattr(storage_obj, "_storage"): storage_obj._storage.clear() except Exception: pass async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: yield c app.dependency_overrides.clear() # Reset redis to avoid leaking state between test files app.state.redis = None # ── Tests — register ────────────────────────────────────────────────────────── @pytest.mark.asyncio async def test_register_success(authed_client): """POST /api/auth/register with valid data returns 201 with id and handle.""" resp = await _register(authed_client) assert resp.status_code == 201, resp.text data = resp.json() assert "id" in data assert data["handle"] == "testuser" assert data["email"] == "t@example.com" @pytest.mark.asyncio async def test_register_weak_password(authed_client): """POST /api/auth/register with a short password returns 422.""" resp = await authed_client.post( "/api/auth/register", json={"handle": "u", "email": "u@example.com", "password": "short"}, ) assert resp.status_code == 422 @pytest.mark.asyncio async def test_register_duplicate_email(authed_client): """Registering the same email twice returns 409.""" await _register(authed_client, handle="u1", email="dup@example.com") resp = await _register(authed_client, handle="u2", email="dup@example.com") assert resp.status_code == 409 assert "already in use" in resp.json()["detail"].lower() # ── Tests — login ───────────────────────────────────────────────────────────── @pytest.mark.asyncio async def test_login_wrong_password(authed_client): """POST /api/auth/login with wrong password returns 401.""" await _register(authed_client) resp = await _login(authed_client, password="WrongPass99!") assert resp.status_code == 401 assert "Incorrect email or password" in resp.json()["detail"] @pytest.mark.asyncio async def test_login_success(authed_client): """Register then login returns 200 with access_token; Set-Cookie has HttpOnly + SameSite.""" await _register(authed_client) resp = await _login(authed_client) assert resp.status_code == 200, resp.text data = resp.json() assert "access_token" in data cookie_header = resp.headers.get("set-cookie", "") assert "HttpOnly" in cookie_header or "httponly" in cookie_header.lower() assert "SameSite=Strict" in cookie_header or "samesite=strict" in cookie_header.lower() @pytest.mark.asyncio async def test_me_requires_auth(authed_client): """GET /api/auth/me without Bearer token returns 403 (HTTPBearer default).""" resp = await authed_client.get("/api/auth/me") # HTTPBearer raises 403 when no Authorization header present assert resp.status_code in (401, 403) @pytest.mark.asyncio async def test_login_password_must_change(authed_client, db_session: AsyncSession): """Login for user with password_must_change=True returns 200 with requires_password_change=True, no Set-Cookie.""" await _register(authed_client, handle="pmcu", email="pmc@example.com") # Flip the flag in the DB result = await db_session.execute(select(User).where(User.email == "pmc@example.com")) user = result.scalar_one() user.password_must_change = True await db_session.commit() resp = await _login(authed_client, email="pmc@example.com") assert resp.status_code == 200 data = resp.json() assert data.get("requires_password_change") is True assert "access_token" not in data # No Set-Cookie header assert "set-cookie" not in resp.headers @pytest.mark.asyncio async def test_change_password_breach(authed_client): """POST /api/auth/change-password with breached new password returns 422.""" await _register(authed_client, handle="cpb", email="cpb@example.com") login_resp = await _login(authed_client, email="cpb@example.com") token = login_resp.json()["access_token"] with patch("services.auth.check_hibp", return_value=True) as mock_hibp: resp = await authed_client.post( "/api/auth/change-password", json={"current_password": "ValidPass12!", "new_password": "StrongNew99!@"}, headers={"Authorization": f"Bearer {token}"}, ) assert resp.status_code == 422 assert "breach" in resp.json()["detail"].lower() @pytest.mark.asyncio async def test_change_password_wrong_current(authed_client): """POST /api/auth/change-password with wrong current_password returns 401.""" await _register(authed_client, handle="cpw", email="cpw@example.com") login_resp = await _login(authed_client, email="cpw@example.com") token = login_resp.json()["access_token"] resp = await authed_client.post( "/api/auth/change-password", json={"current_password": "WrongCurrent9!", "new_password": "NewStrong99!@"}, headers={"Authorization": f"Bearer {token}"}, ) assert resp.status_code == 401 @pytest.mark.asyncio async def test_change_password_success(authed_client): """POST /api/auth/change-password with correct current + strong new password returns 200.""" await _register(authed_client, handle="cps", email="cps@example.com") login_resp = await _login(authed_client, email="cps@example.com") token = login_resp.json()["access_token"] with patch("services.auth.check_hibp", return_value=False): resp = await authed_client.post( "/api/auth/change-password", json={"current_password": "ValidPass12!", "new_password": "NewStrong99!@"}, headers={"Authorization": f"Bearer {token}"}, ) assert resp.status_code == 200 assert resp.json()["message"] == "Password updated" # ── Tests — Origin validation middleware ────────────────────────────────────── @pytest.mark.asyncio async def test_origin_rejected(authed_client): """POST /api/auth/login with a cross-origin Origin header returns 403.""" resp = await authed_client.post( "/api/auth/login", json={"email": "x@x.com", "password": "any"}, headers={"Origin": "https://evil.example"}, ) assert resp.status_code == 403 @pytest.mark.asyncio async def test_origin_allowed(authed_client): """POST /api/auth/login with an allowed Origin proceeds to auth (not 403).""" resp = await authed_client.post( "/api/auth/login", json={"email": "nobody@example.com", "password": "any"}, headers={"Origin": "http://localhost:5173"}, ) # Not 403 — may be 401 (wrong credentials) or 200 assert resp.status_code != 403 # ── Tests — per-account rate limiting ──────────────────────────────────────── @pytest.mark.asyncio async def test_per_account_rate_limit(authed_client): """11 consecutive login attempts with the same email returns 429 on the 11th.""" email = "ratelimit@example.com" for i in range(10): resp = await authed_client.post( "/api/auth/login", json={"email": email, "password": "WrongPass99!"}, ) assert resp.status_code in (401, 200), f"Unexpected {resp.status_code} on attempt {i+1}" # 11th attempt should be rate limited resp = await authed_client.post( "/api/auth/login", json={"email": email, "password": "WrongPass99!"}, ) assert resp.status_code == 429, f"Expected 429 on 11th attempt, got {resp.status_code}: {resp.text}" # ── Tests — backup codes ────────────────────────────────────────────────────── @pytest.mark.asyncio async def test_login_backup_code_success(authed_client, db_session: AsyncSession): """Login with a valid backup code returns 200 + access_token and invalidates the code.""" await _register(authed_client, handle="bcu", email="bcu@example.com") result = await db_session.execute(select(User).where(User.email == "bcu@example.com")) user = result.scalar_one() # Enable TOTP on the user (so the backup code path is exercised) user.totp_enabled = True user.totp_secret = "JBSWY3DPEHPK3PXP" # dummy secret # Store a backup code plaintext_code = "ABCD1234" from services.auth import hash_password backup_code_row = BackupCode( id=uuid.uuid4(), user_id=user.id, code_hash=hash_password(plaintext_code), used_at=None, ) db_session.add(backup_code_row) await db_session.commit() resp = await _login(authed_client, email="bcu@example.com", backup_code=plaintext_code) assert resp.status_code == 200, resp.text data = resp.json() assert "access_token" in data # Verify the code was marked used await db_session.refresh(backup_code_row) assert backup_code_row.used_at is not None @pytest.mark.asyncio async def test_login_backup_code_reuse(authed_client, db_session: AsyncSession): """Using the same backup code a second time returns 401.""" await _register(authed_client, handle="bcr", email="bcr@example.com") result = await db_session.execute(select(User).where(User.email == "bcr@example.com")) user = result.scalar_one() user.totp_enabled = True user.totp_secret = "JBSWY3DPEHPK3PXP" plaintext_code = "EFGH5678" from services.auth import hash_password backup_code_row = BackupCode( id=uuid.uuid4(), user_id=user.id, code_hash=hash_password(plaintext_code), used_at=None, ) db_session.add(backup_code_row) await db_session.commit() # First use: should succeed resp1 = await _login(authed_client, email="bcr@example.com", backup_code=plaintext_code) assert resp1.status_code == 200, resp1.text # Second use: should fail resp2 = await _login(authed_client, email="bcr@example.com", backup_code=plaintext_code) assert resp2.status_code == 401 assert "Invalid or already used code" in resp2.json()["detail"] @pytest.mark.asyncio async def test_login_backup_code_invalid(authed_client, db_session: AsyncSession): """POST /api/auth/login with an unknown backup code returns 401.""" await _register(authed_client, handle="bci", email="bci@example.com") result = await db_session.execute(select(User).where(User.email == "bci@example.com")) user = result.scalar_one() user.totp_enabled = True user.totp_secret = "JBSWY3DPEHPK3PXP" await db_session.commit() resp = await _login(authed_client, email="bci@example.com", backup_code="XXXXXXXX") assert resp.status_code == 401 assert "Invalid or already used code" in resp.json()["detail"] @pytest.mark.asyncio async def test_login_totp_takes_precedence(authed_client, db_session: AsyncSession): """When both totp_code and backup_code are provided, the TOTP path is used (not backup).""" await _register(authed_client, handle="ttp", email="ttp@example.com") result = await db_session.execute(select(User).where(User.email == "ttp@example.com")) user = result.scalar_one() user.totp_enabled = True user.totp_secret = "JBSWY3DPEHPK3PXP" # Store a valid backup code plaintext_code = "TOTP1234" from services.auth import hash_password backup_code_row = BackupCode( id=uuid.uuid4(), user_id=user.id, code_hash=hash_password(plaintext_code), used_at=None, ) db_session.add(backup_code_row) await db_session.commit() # Use both fields — totp_code is wrong, backup_code is valid # If TOTP takes precedence, the wrong totp_code should cause 401 with patch("services.auth.verify_totp", return_value=False) as mock_totp: resp = await _login( authed_client, email="ttp@example.com", totp_code="000000", backup_code=plaintext_code, ) # verify_totp was called (TOTP path taken) mock_totp.assert_called_once() # The backup code path was NOT taken — code remains unused await db_session.refresh(backup_code_row) assert backup_code_row.used_at is None, "Backup code should NOT have been consumed when totp_code is provided" assert resp.status_code == 401