From 1d425d4392b22e968bced08f448b42e8294ff428 Mon Sep 17 00:00:00 2001 From: curo1305 Date: Fri, 22 May 2026 19:35:31 +0200 Subject: [PATCH] test(02-02): add failing tests for auth API endpoints RED phase - 17 tests covering register, login, TOTP, backup codes, per-account rate limiting, Origin validation, change-password, and password_must_change flow. Co-Authored-By: Claude Sonnet 4.6 --- backend/tests/test_auth_api.py | 425 +++++++++++++++++++++++++++++++++ 1 file changed, 425 insertions(+) create mode 100644 backend/tests/test_auth_api.py diff --git a/backend/tests/test_auth_api.py b/backend/tests/test_auth_api.py new file mode 100644 index 0000000..7619baf --- /dev/null +++ b/backend/tests/test_auth_api.py @@ -0,0 +1,425 @@ +""" +Tests for backend/api/auth.py (Plan 02-02, Task 1). + +Covers: register, login (TOTP, backup codes, password_must_change), refresh, +logout, me, change-password, per-account rate limiting, Origin validation, +and CSP headers. + +Uses the async_client fixture from conftest.py. + +Note on mock Redis: app.state.redis is set to a FakeRedis-like dict-based +mock so that per-account rate limiting can be exercised in isolation. +""" +from __future__ import annotations + +import hashlib +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from db.models import BackupCode, Quota, User + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +async def _register(async_client, handle="testuser", email="t@example.com", password="ValidPass12!"): + """Register a user and return the response.""" + return await async_client.post( + "/api/auth/register", + json={"handle": handle, "email": email, "password": password}, + ) + + +async def _login(async_client, email="t@example.com", password="ValidPass12!", **extra): + body = {"email": email, "password": password} + body.update(extra) + return await async_client.post("/api/auth/login", json=body) + + +# ── Fake Redis for per-account rate limiting ────────────────────────────────── + +class FakeRedis: + """In-memory fake Redis for testing rate limits and TOTP replay prevention.""" + + def __init__(self): + self._store: dict = {} + + async def get(self, key): + entry = self._store.get(key) + if entry is None: + return None + val, exp = entry + if exp is not None and datetime.now(timezone.utc).timestamp() > exp: + del self._store[key] + return None + return val + + async def incr(self, key): + entry = self._store.get(key) + if entry is None: + self._store[key] = (1, None) + return 1 + val, exp = entry + new_val = val + 1 + self._store[key] = (new_val, exp) + return new_val + + async def expire(self, key, seconds): + if key in self._store: + val, _ = self._store[key] + deadline = datetime.now(timezone.utc).timestamp() + seconds + self._store[key] = (val, deadline) + + async def set(self, key, value, ex=None): + deadline = None + if ex is not None: + deadline = datetime.now(timezone.utc).timestamp() + ex + self._store[key] = (value, deadline) + + async def close(self): + pass + + +@pytest_asyncio.fixture +async def authed_client(db_session: AsyncSession): + """Async HTTP test client with DB override AND fake Redis on app.state.redis. + + Each test gets a fresh FakeRedis instance (reset state). The slowapi + in-memory rate limiter is also reset between tests so that IP-level + counts from one test don't bleed into the next. + """ + from deps.db import get_db + from main import app + from api.auth import limiter as auth_limiter + + app.dependency_overrides[get_db] = lambda: db_session + fake_redis = FakeRedis() + app.state.redis = fake_redis + + # Reset slowapi's in-memory storage so previous test IP counters don't + # interfere with the current test's rate limit assertions. + try: + auth_limiter._storage.reset() + except Exception: + try: + # Fallback: clear internal storage dict directly + storage_obj = auth_limiter._storage + if hasattr(storage_obj, "_storage"): + storage_obj._storage.clear() + except Exception: + pass + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: + yield c + + app.dependency_overrides.clear() + # Reset redis to avoid leaking state between test files + app.state.redis = None + + +# ── Tests — register ────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_register_success(authed_client): + """POST /api/auth/register with valid data returns 201 with id and handle.""" + resp = await _register(authed_client) + assert resp.status_code == 201, resp.text + data = resp.json() + assert "id" in data + assert data["handle"] == "testuser" + assert data["email"] == "t@example.com" + + +@pytest.mark.asyncio +async def test_register_weak_password(authed_client): + """POST /api/auth/register with a short password returns 422.""" + resp = await authed_client.post( + "/api/auth/register", + json={"handle": "u", "email": "u@example.com", "password": "short"}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_register_duplicate_email(authed_client): + """Registering the same email twice returns 409.""" + await _register(authed_client, handle="u1", email="dup@example.com") + resp = await _register(authed_client, handle="u2", email="dup@example.com") + assert resp.status_code == 409 + assert "already in use" in resp.json()["detail"].lower() + + +# ── Tests — login ───────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_login_wrong_password(authed_client): + """POST /api/auth/login with wrong password returns 401.""" + await _register(authed_client) + resp = await _login(authed_client, password="WrongPass99!") + assert resp.status_code == 401 + assert "Incorrect email or password" in resp.json()["detail"] + + +@pytest.mark.asyncio +async def test_login_success(authed_client): + """Register then login returns 200 with access_token; Set-Cookie has HttpOnly + SameSite.""" + await _register(authed_client) + resp = await _login(authed_client) + assert resp.status_code == 200, resp.text + data = resp.json() + assert "access_token" in data + cookie_header = resp.headers.get("set-cookie", "") + assert "HttpOnly" in cookie_header or "httponly" in cookie_header.lower() + assert "SameSite=Strict" in cookie_header or "samesite=strict" in cookie_header.lower() + + +@pytest.mark.asyncio +async def test_me_requires_auth(authed_client): + """GET /api/auth/me without Bearer token returns 403 (HTTPBearer default).""" + resp = await authed_client.get("/api/auth/me") + # HTTPBearer raises 403 when no Authorization header present + assert resp.status_code in (401, 403) + + +@pytest.mark.asyncio +async def test_login_password_must_change(authed_client, db_session: AsyncSession): + """Login for user with password_must_change=True returns 200 with requires_password_change=True, no Set-Cookie.""" + await _register(authed_client, handle="pmcu", email="pmc@example.com") + # Flip the flag in the DB + result = await db_session.execute(select(User).where(User.email == "pmc@example.com")) + user = result.scalar_one() + user.password_must_change = True + await db_session.commit() + + resp = await _login(authed_client, email="pmc@example.com") + assert resp.status_code == 200 + data = resp.json() + assert data.get("requires_password_change") is True + assert "access_token" not in data + # No Set-Cookie header + assert "set-cookie" not in resp.headers + + +@pytest.mark.asyncio +async def test_change_password_breach(authed_client): + """POST /api/auth/change-password with breached new password returns 422.""" + await _register(authed_client, handle="cpb", email="cpb@example.com") + login_resp = await _login(authed_client, email="cpb@example.com") + token = login_resp.json()["access_token"] + + with patch("services.auth.check_hibp", return_value=True) as mock_hibp: + resp = await authed_client.post( + "/api/auth/change-password", + json={"current_password": "ValidPass12!", "new_password": "StrongNew99!@"}, + headers={"Authorization": f"Bearer {token}"}, + ) + + assert resp.status_code == 422 + assert "breach" in resp.json()["detail"].lower() + + +@pytest.mark.asyncio +async def test_change_password_wrong_current(authed_client): + """POST /api/auth/change-password with wrong current_password returns 401.""" + await _register(authed_client, handle="cpw", email="cpw@example.com") + login_resp = await _login(authed_client, email="cpw@example.com") + token = login_resp.json()["access_token"] + + resp = await authed_client.post( + "/api/auth/change-password", + json={"current_password": "WrongCurrent9!", "new_password": "NewStrong99!@"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_change_password_success(authed_client): + """POST /api/auth/change-password with correct current + strong new password returns 200.""" + await _register(authed_client, handle="cps", email="cps@example.com") + login_resp = await _login(authed_client, email="cps@example.com") + token = login_resp.json()["access_token"] + + with patch("services.auth.check_hibp", return_value=False): + resp = await authed_client.post( + "/api/auth/change-password", + json={"current_password": "ValidPass12!", "new_password": "NewStrong99!@"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + assert resp.json()["message"] == "Password updated" + + +# ── Tests — Origin validation middleware ────────────────────────────────────── + +@pytest.mark.asyncio +async def test_origin_rejected(authed_client): + """POST /api/auth/login with a cross-origin Origin header returns 403.""" + resp = await authed_client.post( + "/api/auth/login", + json={"email": "x@x.com", "password": "any"}, + headers={"Origin": "https://evil.example"}, + ) + assert resp.status_code == 403 + + +@pytest.mark.asyncio +async def test_origin_allowed(authed_client): + """POST /api/auth/login with an allowed Origin proceeds to auth (not 403).""" + resp = await authed_client.post( + "/api/auth/login", + json={"email": "nobody@example.com", "password": "any"}, + headers={"Origin": "http://localhost:5173"}, + ) + # Not 403 — may be 401 (wrong credentials) or 200 + assert resp.status_code != 403 + + +# ── Tests — per-account rate limiting ──────────────────────────────────────── + +@pytest.mark.asyncio +async def test_per_account_rate_limit(authed_client): + """11 consecutive login attempts with the same email returns 429 on the 11th.""" + email = "ratelimit@example.com" + for i in range(10): + resp = await authed_client.post( + "/api/auth/login", + json={"email": email, "password": "WrongPass99!"}, + ) + assert resp.status_code in (401, 200), f"Unexpected {resp.status_code} on attempt {i+1}" + + # 11th attempt should be rate limited + resp = await authed_client.post( + "/api/auth/login", + json={"email": email, "password": "WrongPass99!"}, + ) + assert resp.status_code == 429, f"Expected 429 on 11th attempt, got {resp.status_code}: {resp.text}" + + +# ── Tests — backup codes ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_login_backup_code_success(authed_client, db_session: AsyncSession): + """Login with a valid backup code returns 200 + access_token and invalidates the code.""" + await _register(authed_client, handle="bcu", email="bcu@example.com") + + result = await db_session.execute(select(User).where(User.email == "bcu@example.com")) + user = result.scalar_one() + + # Enable TOTP on the user (so the backup code path is exercised) + user.totp_enabled = True + user.totp_secret = "JBSWY3DPEHPK3PXP" # dummy secret + + # Store a backup code + plaintext_code = "ABCD1234" + from services.auth import hash_password + backup_code_row = BackupCode( + id=uuid.uuid4(), + user_id=user.id, + code_hash=hash_password(plaintext_code), + used_at=None, + ) + db_session.add(backup_code_row) + await db_session.commit() + + resp = await _login(authed_client, email="bcu@example.com", backup_code=plaintext_code) + assert resp.status_code == 200, resp.text + data = resp.json() + assert "access_token" in data + + # Verify the code was marked used + await db_session.refresh(backup_code_row) + assert backup_code_row.used_at is not None + + +@pytest.mark.asyncio +async def test_login_backup_code_reuse(authed_client, db_session: AsyncSession): + """Using the same backup code a second time returns 401.""" + await _register(authed_client, handle="bcr", email="bcr@example.com") + + result = await db_session.execute(select(User).where(User.email == "bcr@example.com")) + user = result.scalar_one() + user.totp_enabled = True + user.totp_secret = "JBSWY3DPEHPK3PXP" + + plaintext_code = "EFGH5678" + from services.auth import hash_password + backup_code_row = BackupCode( + id=uuid.uuid4(), + user_id=user.id, + code_hash=hash_password(plaintext_code), + used_at=None, + ) + db_session.add(backup_code_row) + await db_session.commit() + + # First use: should succeed + resp1 = await _login(authed_client, email="bcr@example.com", backup_code=plaintext_code) + assert resp1.status_code == 200, resp1.text + + # Second use: should fail + resp2 = await _login(authed_client, email="bcr@example.com", backup_code=plaintext_code) + assert resp2.status_code == 401 + assert "Invalid or already used code" in resp2.json()["detail"] + + +@pytest.mark.asyncio +async def test_login_backup_code_invalid(authed_client, db_session: AsyncSession): + """POST /api/auth/login with an unknown backup code returns 401.""" + await _register(authed_client, handle="bci", email="bci@example.com") + + result = await db_session.execute(select(User).where(User.email == "bci@example.com")) + user = result.scalar_one() + user.totp_enabled = True + user.totp_secret = "JBSWY3DPEHPK3PXP" + await db_session.commit() + + resp = await _login(authed_client, email="bci@example.com", backup_code="XXXXXXXX") + assert resp.status_code == 401 + assert "Invalid or already used code" in resp.json()["detail"] + + +@pytest.mark.asyncio +async def test_login_totp_takes_precedence(authed_client, db_session: AsyncSession): + """When both totp_code and backup_code are provided, the TOTP path is used (not backup).""" + await _register(authed_client, handle="ttp", email="ttp@example.com") + + result = await db_session.execute(select(User).where(User.email == "ttp@example.com")) + user = result.scalar_one() + user.totp_enabled = True + user.totp_secret = "JBSWY3DPEHPK3PXP" + + # Store a valid backup code + plaintext_code = "TOTP1234" + from services.auth import hash_password + backup_code_row = BackupCode( + id=uuid.uuid4(), + user_id=user.id, + code_hash=hash_password(plaintext_code), + used_at=None, + ) + db_session.add(backup_code_row) + await db_session.commit() + + # Use both fields — totp_code is wrong, backup_code is valid + # If TOTP takes precedence, the wrong totp_code should cause 401 + with patch("services.auth.verify_totp", return_value=False) as mock_totp: + resp = await _login( + authed_client, + email="ttp@example.com", + totp_code="000000", + backup_code=plaintext_code, + ) + # verify_totp was called (TOTP path taken) + mock_totp.assert_called_once() + + # The backup code path was NOT taken — code remains unused + await db_session.refresh(backup_code_row) + assert backup_code_row.used_at is None, "Backup code should NOT have been consumed when totp_code is provided" + assert resp.status_code == 401