test(02-02): add failing tests for auth API endpoints
RED phase - 17 tests covering register, login, TOTP, backup codes, per-account rate limiting, Origin validation, change-password, and password_must_change flow. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,425 @@
|
|||||||
|
"""
|
||||||
|
Tests for backend/api/auth.py (Plan 02-02, Task 1).
|
||||||
|
|
||||||
|
Covers: register, login (TOTP, backup codes, password_must_change), refresh,
|
||||||
|
logout, me, change-password, per-account rate limiting, Origin validation,
|
||||||
|
and CSP headers.
|
||||||
|
|
||||||
|
Uses the async_client fixture from conftest.py.
|
||||||
|
|
||||||
|
Note on mock Redis: app.state.redis is set to a FakeRedis-like dict-based
|
||||||
|
mock so that per-account rate limiting can be exercised in isolation.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from db.models import BackupCode, Quota, User
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _register(async_client, handle="testuser", email="t@example.com", password="ValidPass12!"):
|
||||||
|
"""Register a user and return the response."""
|
||||||
|
return await async_client.post(
|
||||||
|
"/api/auth/register",
|
||||||
|
json={"handle": handle, "email": email, "password": password},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _login(async_client, email="t@example.com", password="ValidPass12!", **extra):
|
||||||
|
body = {"email": email, "password": password}
|
||||||
|
body.update(extra)
|
||||||
|
return await async_client.post("/api/auth/login", json=body)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fake Redis for per-account rate limiting ──────────────────────────────────
|
||||||
|
|
||||||
|
class FakeRedis:
|
||||||
|
"""In-memory fake Redis for testing rate limits and TOTP replay prevention."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._store: dict = {}
|
||||||
|
|
||||||
|
async def get(self, key):
|
||||||
|
entry = self._store.get(key)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
val, exp = entry
|
||||||
|
if exp is not None and datetime.now(timezone.utc).timestamp() > exp:
|
||||||
|
del self._store[key]
|
||||||
|
return None
|
||||||
|
return val
|
||||||
|
|
||||||
|
async def incr(self, key):
|
||||||
|
entry = self._store.get(key)
|
||||||
|
if entry is None:
|
||||||
|
self._store[key] = (1, None)
|
||||||
|
return 1
|
||||||
|
val, exp = entry
|
||||||
|
new_val = val + 1
|
||||||
|
self._store[key] = (new_val, exp)
|
||||||
|
return new_val
|
||||||
|
|
||||||
|
async def expire(self, key, seconds):
|
||||||
|
if key in self._store:
|
||||||
|
val, _ = self._store[key]
|
||||||
|
deadline = datetime.now(timezone.utc).timestamp() + seconds
|
||||||
|
self._store[key] = (val, deadline)
|
||||||
|
|
||||||
|
async def set(self, key, value, ex=None):
|
||||||
|
deadline = None
|
||||||
|
if ex is not None:
|
||||||
|
deadline = datetime.now(timezone.utc).timestamp() + ex
|
||||||
|
self._store[key] = (value, deadline)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def authed_client(db_session: AsyncSession):
|
||||||
|
"""Async HTTP test client with DB override AND fake Redis on app.state.redis.
|
||||||
|
|
||||||
|
Each test gets a fresh FakeRedis instance (reset state). The slowapi
|
||||||
|
in-memory rate limiter is also reset between tests so that IP-level
|
||||||
|
counts from one test don't bleed into the next.
|
||||||
|
"""
|
||||||
|
from deps.db import get_db
|
||||||
|
from main import app
|
||||||
|
from api.auth import limiter as auth_limiter
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = lambda: db_session
|
||||||
|
fake_redis = FakeRedis()
|
||||||
|
app.state.redis = fake_redis
|
||||||
|
|
||||||
|
# Reset slowapi's in-memory storage so previous test IP counters don't
|
||||||
|
# interfere with the current test's rate limit assertions.
|
||||||
|
try:
|
||||||
|
auth_limiter._storage.reset()
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
# Fallback: clear internal storage dict directly
|
||||||
|
storage_obj = auth_limiter._storage
|
||||||
|
if hasattr(storage_obj, "_storage"):
|
||||||
|
storage_obj._storage.clear()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
# Reset redis to avoid leaking state between test files
|
||||||
|
app.state.redis = None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests — register ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_success(authed_client):
|
||||||
|
"""POST /api/auth/register with valid data returns 201 with id and handle."""
|
||||||
|
resp = await _register(authed_client)
|
||||||
|
assert resp.status_code == 201, resp.text
|
||||||
|
data = resp.json()
|
||||||
|
assert "id" in data
|
||||||
|
assert data["handle"] == "testuser"
|
||||||
|
assert data["email"] == "t@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_weak_password(authed_client):
|
||||||
|
"""POST /api/auth/register with a short password returns 422."""
|
||||||
|
resp = await authed_client.post(
|
||||||
|
"/api/auth/register",
|
||||||
|
json={"handle": "u", "email": "u@example.com", "password": "short"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_duplicate_email(authed_client):
|
||||||
|
"""Registering the same email twice returns 409."""
|
||||||
|
await _register(authed_client, handle="u1", email="dup@example.com")
|
||||||
|
resp = await _register(authed_client, handle="u2", email="dup@example.com")
|
||||||
|
assert resp.status_code == 409
|
||||||
|
assert "already in use" in resp.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests — login ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_wrong_password(authed_client):
|
||||||
|
"""POST /api/auth/login with wrong password returns 401."""
|
||||||
|
await _register(authed_client)
|
||||||
|
resp = await _login(authed_client, password="WrongPass99!")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
assert "Incorrect email or password" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_success(authed_client):
|
||||||
|
"""Register then login returns 200 with access_token; Set-Cookie has HttpOnly + SameSite."""
|
||||||
|
await _register(authed_client)
|
||||||
|
resp = await _login(authed_client)
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
cookie_header = resp.headers.get("set-cookie", "")
|
||||||
|
assert "HttpOnly" in cookie_header or "httponly" in cookie_header.lower()
|
||||||
|
assert "SameSite=Strict" in cookie_header or "samesite=strict" in cookie_header.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_me_requires_auth(authed_client):
|
||||||
|
"""GET /api/auth/me without Bearer token returns 403 (HTTPBearer default)."""
|
||||||
|
resp = await authed_client.get("/api/auth/me")
|
||||||
|
# HTTPBearer raises 403 when no Authorization header present
|
||||||
|
assert resp.status_code in (401, 403)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_password_must_change(authed_client, db_session: AsyncSession):
|
||||||
|
"""Login for user with password_must_change=True returns 200 with requires_password_change=True, no Set-Cookie."""
|
||||||
|
await _register(authed_client, handle="pmcu", email="pmc@example.com")
|
||||||
|
# Flip the flag in the DB
|
||||||
|
result = await db_session.execute(select(User).where(User.email == "pmc@example.com"))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.password_must_change = True
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
resp = await _login(authed_client, email="pmc@example.com")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data.get("requires_password_change") is True
|
||||||
|
assert "access_token" not in data
|
||||||
|
# No Set-Cookie header
|
||||||
|
assert "set-cookie" not in resp.headers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_change_password_breach(authed_client):
|
||||||
|
"""POST /api/auth/change-password with breached new password returns 422."""
|
||||||
|
await _register(authed_client, handle="cpb", email="cpb@example.com")
|
||||||
|
login_resp = await _login(authed_client, email="cpb@example.com")
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
|
||||||
|
with patch("services.auth.check_hibp", return_value=True) as mock_hibp:
|
||||||
|
resp = await authed_client.post(
|
||||||
|
"/api/auth/change-password",
|
||||||
|
json={"current_password": "ValidPass12!", "new_password": "StrongNew99!@"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 422
|
||||||
|
assert "breach" in resp.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_change_password_wrong_current(authed_client):
|
||||||
|
"""POST /api/auth/change-password with wrong current_password returns 401."""
|
||||||
|
await _register(authed_client, handle="cpw", email="cpw@example.com")
|
||||||
|
login_resp = await _login(authed_client, email="cpw@example.com")
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
|
||||||
|
resp = await authed_client.post(
|
||||||
|
"/api/auth/change-password",
|
||||||
|
json={"current_password": "WrongCurrent9!", "new_password": "NewStrong99!@"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_change_password_success(authed_client):
|
||||||
|
"""POST /api/auth/change-password with correct current + strong new password returns 200."""
|
||||||
|
await _register(authed_client, handle="cps", email="cps@example.com")
|
||||||
|
login_resp = await _login(authed_client, email="cps@example.com")
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
|
||||||
|
with patch("services.auth.check_hibp", return_value=False):
|
||||||
|
resp = await authed_client.post(
|
||||||
|
"/api/auth/change-password",
|
||||||
|
json={"current_password": "ValidPass12!", "new_password": "NewStrong99!@"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["message"] == "Password updated"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests — Origin validation middleware ──────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_origin_rejected(authed_client):
|
||||||
|
"""POST /api/auth/login with a cross-origin Origin header returns 403."""
|
||||||
|
resp = await authed_client.post(
|
||||||
|
"/api/auth/login",
|
||||||
|
json={"email": "x@x.com", "password": "any"},
|
||||||
|
headers={"Origin": "https://evil.example"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_origin_allowed(authed_client):
|
||||||
|
"""POST /api/auth/login with an allowed Origin proceeds to auth (not 403)."""
|
||||||
|
resp = await authed_client.post(
|
||||||
|
"/api/auth/login",
|
||||||
|
json={"email": "nobody@example.com", "password": "any"},
|
||||||
|
headers={"Origin": "http://localhost:5173"},
|
||||||
|
)
|
||||||
|
# Not 403 — may be 401 (wrong credentials) or 200
|
||||||
|
assert resp.status_code != 403
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests — per-account rate limiting ────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_per_account_rate_limit(authed_client):
|
||||||
|
"""11 consecutive login attempts with the same email returns 429 on the 11th."""
|
||||||
|
email = "ratelimit@example.com"
|
||||||
|
for i in range(10):
|
||||||
|
resp = await authed_client.post(
|
||||||
|
"/api/auth/login",
|
||||||
|
json={"email": email, "password": "WrongPass99!"},
|
||||||
|
)
|
||||||
|
assert resp.status_code in (401, 200), f"Unexpected {resp.status_code} on attempt {i+1}"
|
||||||
|
|
||||||
|
# 11th attempt should be rate limited
|
||||||
|
resp = await authed_client.post(
|
||||||
|
"/api/auth/login",
|
||||||
|
json={"email": email, "password": "WrongPass99!"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 429, f"Expected 429 on 11th attempt, got {resp.status_code}: {resp.text}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests — backup codes ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_backup_code_success(authed_client, db_session: AsyncSession):
|
||||||
|
"""Login with a valid backup code returns 200 + access_token and invalidates the code."""
|
||||||
|
await _register(authed_client, handle="bcu", email="bcu@example.com")
|
||||||
|
|
||||||
|
result = await db_session.execute(select(User).where(User.email == "bcu@example.com"))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
# Enable TOTP on the user (so the backup code path is exercised)
|
||||||
|
user.totp_enabled = True
|
||||||
|
user.totp_secret = "JBSWY3DPEHPK3PXP" # dummy secret
|
||||||
|
|
||||||
|
# Store a backup code
|
||||||
|
plaintext_code = "ABCD1234"
|
||||||
|
from services.auth import hash_password
|
||||||
|
backup_code_row = BackupCode(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
user_id=user.id,
|
||||||
|
code_hash=hash_password(plaintext_code),
|
||||||
|
used_at=None,
|
||||||
|
)
|
||||||
|
db_session.add(backup_code_row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
resp = await _login(authed_client, email="bcu@example.com", backup_code=plaintext_code)
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
|
||||||
|
# Verify the code was marked used
|
||||||
|
await db_session.refresh(backup_code_row)
|
||||||
|
assert backup_code_row.used_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_backup_code_reuse(authed_client, db_session: AsyncSession):
|
||||||
|
"""Using the same backup code a second time returns 401."""
|
||||||
|
await _register(authed_client, handle="bcr", email="bcr@example.com")
|
||||||
|
|
||||||
|
result = await db_session.execute(select(User).where(User.email == "bcr@example.com"))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.totp_enabled = True
|
||||||
|
user.totp_secret = "JBSWY3DPEHPK3PXP"
|
||||||
|
|
||||||
|
plaintext_code = "EFGH5678"
|
||||||
|
from services.auth import hash_password
|
||||||
|
backup_code_row = BackupCode(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
user_id=user.id,
|
||||||
|
code_hash=hash_password(plaintext_code),
|
||||||
|
used_at=None,
|
||||||
|
)
|
||||||
|
db_session.add(backup_code_row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# First use: should succeed
|
||||||
|
resp1 = await _login(authed_client, email="bcr@example.com", backup_code=plaintext_code)
|
||||||
|
assert resp1.status_code == 200, resp1.text
|
||||||
|
|
||||||
|
# Second use: should fail
|
||||||
|
resp2 = await _login(authed_client, email="bcr@example.com", backup_code=plaintext_code)
|
||||||
|
assert resp2.status_code == 401
|
||||||
|
assert "Invalid or already used code" in resp2.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_backup_code_invalid(authed_client, db_session: AsyncSession):
|
||||||
|
"""POST /api/auth/login with an unknown backup code returns 401."""
|
||||||
|
await _register(authed_client, handle="bci", email="bci@example.com")
|
||||||
|
|
||||||
|
result = await db_session.execute(select(User).where(User.email == "bci@example.com"))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.totp_enabled = True
|
||||||
|
user.totp_secret = "JBSWY3DPEHPK3PXP"
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
resp = await _login(authed_client, email="bci@example.com", backup_code="XXXXXXXX")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
assert "Invalid or already used code" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_totp_takes_precedence(authed_client, db_session: AsyncSession):
|
||||||
|
"""When both totp_code and backup_code are provided, the TOTP path is used (not backup)."""
|
||||||
|
await _register(authed_client, handle="ttp", email="ttp@example.com")
|
||||||
|
|
||||||
|
result = await db_session.execute(select(User).where(User.email == "ttp@example.com"))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.totp_enabled = True
|
||||||
|
user.totp_secret = "JBSWY3DPEHPK3PXP"
|
||||||
|
|
||||||
|
# Store a valid backup code
|
||||||
|
plaintext_code = "TOTP1234"
|
||||||
|
from services.auth import hash_password
|
||||||
|
backup_code_row = BackupCode(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
user_id=user.id,
|
||||||
|
code_hash=hash_password(plaintext_code),
|
||||||
|
used_at=None,
|
||||||
|
)
|
||||||
|
db_session.add(backup_code_row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Use both fields — totp_code is wrong, backup_code is valid
|
||||||
|
# If TOTP takes precedence, the wrong totp_code should cause 401
|
||||||
|
with patch("services.auth.verify_totp", return_value=False) as mock_totp:
|
||||||
|
resp = await _login(
|
||||||
|
authed_client,
|
||||||
|
email="ttp@example.com",
|
||||||
|
totp_code="000000",
|
||||||
|
backup_code=plaintext_code,
|
||||||
|
)
|
||||||
|
# verify_totp was called (TOTP path taken)
|
||||||
|
mock_totp.assert_called_once()
|
||||||
|
|
||||||
|
# The backup code path was NOT taken — code remains unused
|
||||||
|
await db_session.refresh(backup_code_row)
|
||||||
|
assert backup_code_row.used_at is None, "Backup code should NOT have been consumed when totp_code is provided"
|
||||||
|
assert resp.status_code == 401
|
||||||
Reference in New Issue
Block a user