Files
kite/backend/tests/test_auth_api.py
T
curo1305 1d425d4392 test(02-02): add failing tests for auth API endpoints
RED phase - 17 tests covering register, login, TOTP, backup codes,
per-account rate limiting, Origin validation, change-password, and
password_must_change flow.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-22 19:35:31 +02:00

426 lines
16 KiB
Python

"""
Tests for backend/api/auth.py (Plan 02-02, Task 1).
Covers: register, login (TOTP, backup codes, password_must_change), refresh,
logout, me, change-password, per-account rate limiting, Origin validation,
and CSP headers.
Uses the async_client fixture from conftest.py.
Note on mock Redis: app.state.redis is set to a FakeRedis-like dict-based
mock so that per-account rate limiting can be exercised in isolation.
"""
from __future__ import annotations
import hashlib
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from db.models import BackupCode, Quota, User
# ── Helpers ──────────────────────────────────────────────────────────────────
async def _register(async_client, handle="testuser", email="t@example.com", password="ValidPass12!"):
    """POST a registration request and hand back the raw HTTP response.

    The defaults describe the canonical test account shared by most tests.
    """
    payload = {"handle": handle, "email": email, "password": password}
    response = await async_client.post("/api/auth/register", json=payload)
    return response
async def _login(async_client, email="t@example.com", password="ValidPass12!", **extra):
    """POST credentials to the login endpoint and return the raw response.

    Any extra keyword arguments (e.g. ``totp_code``, ``backup_code``) are
    merged into the JSON body next to ``email`` and ``password``.
    """
    return await async_client.post(
        "/api/auth/login",
        json={"email": email, "password": password, **extra},
    )
# ── Fake Redis for per-account rate limiting ──────────────────────────────────
class FakeRedis:
    """In-memory stand-in for the async Redis client used by the auth API.

    Implements only ``get``, ``set``, ``incr``, ``expire`` and ``close``, with
    per-key expiry. That is enough to exercise rate limits and TOTP replay
    prevention without a Redis server.

    Each entry in ``_store`` is a ``(value, deadline)`` tuple. ``deadline`` is
    a POSIX timestamp, or ``None`` for keys without a TTL. Expired keys are
    evicted lazily whenever they are touched, so no operation ever observes
    (or revives) an expired value.
    """

    def __init__(self):
        self._store: dict = {}

    @staticmethod
    def _now() -> float:
        """Current time as a POSIX timestamp (UTC)."""
        return datetime.now(timezone.utc).timestamp()

    def _live_entry(self, key):
        """Return the ``(value, deadline)`` for *key*, or ``None`` if absent/expired.

        An expired entry is deleted, so subsequent writes start from scratch.
        """
        entry = self._store.get(key)
        if entry is None:
            return None
        deadline = entry[1]
        if deadline is not None and self._now() > deadline:
            del self._store[key]
            return None
        return entry

    async def get(self, key):
        """Return the stored value, or ``None`` if the key is missing or expired."""
        entry = self._live_entry(key)
        return None if entry is None else entry[0]

    async def incr(self, key):
        """Increment the integer at *key* and return the new value.

        As with Redis ``INCR``, a missing *or expired* key starts from 0 (so
        the result is 1) with no TTL. An existing key keeps its TTL.
        """
        entry = self._live_entry(key)
        if entry is None:
            self._store[key] = (1, None)
            return 1
        value, deadline = entry
        new_value = int(value) + 1
        self._store[key] = (new_value, deadline)
        return new_value

    async def expire(self, key, seconds):
        """Give *key* a TTL of *seconds*.

        Returns True on success, or False if the key does not exist. An expired
        key counts as missing, so it is never resurrected.
        """
        entry = self._live_entry(key)
        if entry is None:
            return False
        self._store[key] = (entry[0], self._now() + seconds)
        return True

    async def set(self, key, value, ex=None, nx=False):
        """Store *value* at *key*, optionally with a TTL of *ex* seconds.

        With ``nx=True`` the write only happens if the key is absent (or
        expired), mirroring Redis ``SET ... NX``. Returns ``None`` when the
        write is skipped and ``True`` otherwise, as redis-py does.
        """
        if nx and self._live_entry(key) is not None:
            return None
        deadline = None if ex is None else self._now() + ex
        self._store[key] = (value, deadline)
        return True

    async def close(self):
        """No-op: there is no connection to release."""
def _reset_rate_limiter(limiter) -> None:
    """Clear a slowapi limiter's stored hit counters, warning instead of failing.

    Prefers the public ``Limiter.reset()`` and falls back to resetting the
    private ``_storage`` object for limiter objects that lack it. A failure
    is reported as a RuntimeWarning rather than silently ignored. Otherwise,
    stale IP counters would make the rate-limit tests order-dependent with no
    visible cause.
    """
    import warnings

    reset = getattr(limiter, "reset", None)
    try:
        if callable(reset):
            reset()
        else:
            limiter._storage.reset()
    except Exception as exc:  # best-effort: a failed reset must not break every test
        warnings.warn(
            f"Could not reset rate-limiter storage ({exc!r}); "
            "IP-level limits may leak between tests.",
            RuntimeWarning,
            stacklevel=2,
        )


@pytest_asyncio.fixture
async def authed_client(db_session: AsyncSession):
    """Yield an httpx AsyncClient bound to the app, with test doubles installed.

    Setup:
      * overrides the ``get_db`` dependency so handlers use ``db_session``;
      * installs a fresh ``FakeRedis`` on ``app.state.redis`` so per-account
        rate-limit state starts empty for every test;
      * resets the auth module's rate limiter so IP-level counts from one
        test don't bleed into the next.

    Teardown always runs, even if closing the client raises. It clears the
    dependency overrides and restores whatever ``app.state.redis`` held before
    the fixture ran (``None`` if it was unset), so no state leaks between
    test files.
    """
    from deps.db import get_db
    from main import app
    from api.auth import limiter as auth_limiter

    previous_redis = getattr(app.state, "redis", None)
    app.dependency_overrides[get_db] = lambda: db_session
    app.state.redis = FakeRedis()
    try:
        _reset_rate_limiter(auth_limiter)
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
        app.state.redis = previous_redis
# ── Tests — register ──────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_register_success(authed_client):
    """A valid registration yields 201 and echoes back id, handle and email."""
    response = await _register(authed_client)
    assert response.status_code == 201, response.text

    created = response.json()
    assert "id" in created
    assert created["handle"] == "testuser"
    assert created["email"] == "t@example.com"
@pytest.mark.asyncio
async def test_register_weak_password(authed_client):
    """POST /api/auth/register with a too-short password returns 422.

    The handle follows the same shape as the handles other tests register
    successfully (the original single-character "u" might itself be rejected).
    That keeps the password as the only field expected to trigger the error.
    """
    resp = await authed_client.post(
        "/api/auth/register",
        json={"handle": "weakpw", "email": "u@example.com", "password": "short"},
    )
    assert resp.status_code == 422, resp.text
@pytest.mark.asyncio
async def test_register_duplicate_email(authed_client):
    """A second registration reusing an email address is rejected with 409."""
    shared_email = "dup@example.com"
    await _register(authed_client, handle="u1", email=shared_email)
    second = await _register(authed_client, handle="u2", email=shared_email)

    assert second.status_code == 409
    detail = second.json()["detail"]
    assert "already in use" in detail.lower()
# ── Tests — login ─────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_login_wrong_password(authed_client):
    """Logging in to an existing account with the wrong password returns 401."""
    await _register(authed_client)
    response = await _login(authed_client, password="WrongPass99!")

    assert response.status_code == 401
    assert "Incorrect email or password" in response.json()["detail"]
@pytest.mark.asyncio
async def test_login_success(authed_client):
    """Register, then log in.

    Expects 200 with an access_token, plus a Set-Cookie carrying the HttpOnly
    and SameSite=Strict attributes.
    """
    await _register(authed_client)
    resp = await _login(authed_client)
    assert resp.status_code == 200, resp.text
    assert "access_token" in resp.json()

    # Cookie attribute names are case-insensitive (RFC 6265), so a single
    # lower-cased comparison covers every spelling.
    cookie_header = resp.headers.get("set-cookie", "").lower()
    assert cookie_header, "login response did not set a cookie"
    assert "httponly" in cookie_header
    assert "samesite=strict" in cookie_header
@pytest.mark.asyncio
async def test_me_requires_auth(authed_client):
    """GET /api/auth/me without a Bearer token is rejected with 401 or 403.

    The status for a missing Authorization header depends on how the
    HTTPBearer security dependency reports it. Older FastAPI releases use 403
    and newer ones may use 401, so both are accepted.
    """
    resp = await authed_client.get("/api/auth/me")
    assert resp.status_code in (401, 403), resp.text
@pytest.mark.asyncio
async def test_login_password_must_change(authed_client, db_session: AsyncSession):
    """A user flagged password_must_change gets requires_password_change=True.

    The response is 200, but no access_token is returned and no cookie is set.
    """
    email = "pmc@example.com"
    await _register(authed_client, handle="pmcu", email=email)

    # Flip the flag directly in the database.
    user = (await db_session.execute(select(User).where(User.email == email))).scalar_one()
    user.password_must_change = True
    await db_session.commit()

    response = await _login(authed_client, email=email)
    assert response.status_code == 200
    payload = response.json()
    assert payload.get("requires_password_change") is True
    assert "access_token" not in payload
    # No cookie may be issued before the password has been changed.
    assert "set-cookie" not in response.headers
@pytest.mark.asyncio
async def test_change_password_breach(authed_client):
    """POST /api/auth/change-password with a breached new password returns 422.

    ``services.auth.check_hibp`` is patched to report a breach. The test also
    asserts the patched function was actually consulted. A wrong patch target
    then shows up as an explicit failure, not as a confusing status mismatch
    (or an unstubbed external lookup).
    """
    await _register(authed_client, handle="cpb", email="cpb@example.com")
    login_resp = await _login(authed_client, email="cpb@example.com")
    assert login_resp.status_code == 200, login_resp.text
    token = login_resp.json()["access_token"]

    with patch("services.auth.check_hibp", return_value=True) as mock_hibp:
        resp = await authed_client.post(
            "/api/auth/change-password",
            json={"current_password": "ValidPass12!", "new_password": "StrongNew99!@"},
            headers={"Authorization": f"Bearer {token}"},
        )

    assert mock_hibp.called, "check_hibp was never called; is the patch target correct?"
    assert resp.status_code == 422
    assert "breach" in resp.json()["detail"].lower()
@pytest.mark.asyncio
async def test_change_password_wrong_current(authed_client):
    """POST /api/auth/change-password with a wrong current_password returns 401.

    ``services.auth.check_hibp`` is stubbed to "not breached" (as in the
    success test). This way the breach lookup, presumably an external HIBP
    call, is never made for this request, whatever order the endpoint
    performs its checks in.
    """
    await _register(authed_client, handle="cpw", email="cpw@example.com")
    login_resp = await _login(authed_client, email="cpw@example.com")
    assert login_resp.status_code == 200, login_resp.text
    token = login_resp.json()["access_token"]

    with patch("services.auth.check_hibp", return_value=False):
        resp = await authed_client.post(
            "/api/auth/change-password",
            json={"current_password": "WrongCurrent9!", "new_password": "NewStrong99!@"},
            headers={"Authorization": f"Bearer {token}"},
        )

    assert resp.status_code == 401
@pytest.mark.asyncio
async def test_change_password_success(authed_client):
    """POST /api/auth/change-password with the correct current password and a strong new one.

    Returns 200 with "Password updated". The test then confirms the change
    took effect: the new password logs in, and the old one is rejected.
    """
    email = "cps@example.com"
    new_password = "NewStrong99!@"
    await _register(authed_client, handle="cps", email=email)
    login_resp = await _login(authed_client, email=email)
    assert login_resp.status_code == 200, login_resp.text
    token = login_resp.json()["access_token"]

    with patch("services.auth.check_hibp", return_value=False):
        resp = await authed_client.post(
            "/api/auth/change-password",
            json={"current_password": "ValidPass12!", "new_password": new_password},
            headers={"Authorization": f"Bearer {token}"},
        )
    assert resp.status_code == 200, resp.text
    assert resp.json()["message"] == "Password updated"

    # The response message alone does not prove the hash was replaced.
    new_login = await _login(authed_client, email=email, password=new_password)
    assert new_login.status_code == 200, new_login.text
    old_login = await _login(authed_client, email=email)
    assert old_login.status_code == 401
# ── Tests — Origin validation middleware ──────────────────────────────────────
@pytest.mark.asyncio
async def test_origin_rejected(authed_client):
    """A login POST carrying a foreign Origin header is refused with 403."""
    credentials = {"email": "x@x.com", "password": "any"}
    hostile_headers = {"Origin": "https://evil.example"}

    response = await authed_client.post(
        "/api/auth/login", json=credentials, headers=hostile_headers
    )
    assert response.status_code == 403
@pytest.mark.asyncio
async def test_origin_allowed(authed_client):
    """A login POST from an allowed Origin is passed through to the auth logic."""
    credentials = {"email": "nobody@example.com", "password": "any"}
    trusted_headers = {"Origin": "http://localhost:5173"}

    response = await authed_client.post(
        "/api/auth/login", json=credentials, headers=trusted_headers
    )
    # Anything other than 403 means the Origin check let the request through
    # (it may then fail authentication with 401, or succeed with 200).
    assert response.status_code != 403
# ── Tests — per-account rate limiting ────────────────────────────────────────
@pytest.mark.asyncio
async def test_per_account_rate_limit(authed_client):
    """11 consecutive failed logins for the same email: the 11th returns 429.

    The account is never registered, so each of the first 10 attempts must
    be a plain authentication failure (401). Accepting 200 there would hide
    a login bug.
    """
    email = "ratelimit@example.com"
    payload = {"email": email, "password": "WrongPass99!"}

    for attempt in range(1, 11):
        resp = await authed_client.post("/api/auth/login", json=payload)
        assert resp.status_code == 401, f"Unexpected {resp.status_code} on attempt {attempt}"

    # 11th attempt should be rate limited.
    resp = await authed_client.post("/api/auth/login", json=payload)
    assert resp.status_code == 429, f"Expected 429 on 11th attempt, got {resp.status_code}: {resp.text}"
# ── Tests — backup codes ──────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_login_backup_code_success(authed_client, db_session: AsyncSession):
    """A valid, unused backup code logs a TOTP-enabled user in and is then marked used."""
    from services.auth import hash_password

    email = "bcu@example.com"
    await _register(authed_client, handle="bcu", email=email)
    user = (await db_session.execute(select(User).where(User.email == email))).scalar_one()

    # Turn TOTP on so that login has to take the second-factor branch.
    user.totp_enabled = True
    user.totp_secret = "JBSWY3DPEHPK3PXP"  # dummy secret

    code = "ABCD1234"
    stored_code = BackupCode(
        id=uuid.uuid4(),
        user_id=user.id,
        code_hash=hash_password(code),
        used_at=None,
    )
    db_session.add(stored_code)
    await db_session.commit()

    response = await _login(authed_client, email=email, backup_code=code)
    assert response.status_code == 200, response.text
    assert "access_token" in response.json()

    # The code must now be recorded as consumed.
    await db_session.refresh(stored_code)
    assert stored_code.used_at is not None
@pytest.mark.asyncio
async def test_login_backup_code_reuse(authed_client, db_session: AsyncSession):
    """A backup code works once; presenting it a second time is rejected with 401."""
    from services.auth import hash_password

    email = "bcr@example.com"
    await _register(authed_client, handle="bcr", email=email)
    user = (await db_session.execute(select(User).where(User.email == email))).scalar_one()
    user.totp_enabled = True
    user.totp_secret = "JBSWY3DPEHPK3PXP"

    code = "EFGH5678"
    db_session.add(
        BackupCode(
            id=uuid.uuid4(),
            user_id=user.id,
            code_hash=hash_password(code),
            used_at=None,
        )
    )
    await db_session.commit()

    first = await _login(authed_client, email=email, backup_code=code)
    assert first.status_code == 200, first.text

    second = await _login(authed_client, email=email, backup_code=code)
    assert second.status_code == 401
    assert "Invalid or already used code" in second.json()["detail"]
@pytest.mark.asyncio
async def test_login_backup_code_invalid(authed_client, db_session: AsyncSession):
    """A backup code that was never issued to the user is rejected with 401."""
    email = "bci@example.com"
    await _register(authed_client, handle="bci", email=email)

    lookup = await db_session.execute(select(User).where(User.email == email))
    user = lookup.scalar_one()
    user.totp_enabled, user.totp_secret = True, "JBSWY3DPEHPK3PXP"
    await db_session.commit()

    response = await _login(authed_client, email=email, backup_code="XXXXXXXX")
    assert response.status_code == 401
    assert "Invalid or already used code" in response.json()["detail"]
@pytest.mark.asyncio
async def test_login_totp_takes_precedence(authed_client, db_session: AsyncSession):
    """When both totp_code and backup_code are sent, only the TOTP path runs.

    ``verify_totp`` is patched to reject the code while the backup code is
    valid. A 401, together with an untouched backup code, shows that the
    backup-code branch was never taken.
    """
    from services.auth import hash_password

    email = "ttp@example.com"
    await _register(authed_client, handle="ttp", email=email)
    user = (await db_session.execute(select(User).where(User.email == email))).scalar_one()
    user.totp_enabled = True
    user.totp_secret = "JBSWY3DPEHPK3PXP"

    backup_plaintext = "TOTP1234"
    unused_code = BackupCode(
        id=uuid.uuid4(),
        user_id=user.id,
        code_hash=hash_password(backup_plaintext),
        used_at=None,
    )
    db_session.add(unused_code)
    await db_session.commit()

    with patch("services.auth.verify_totp", return_value=False) as fake_verify_totp:
        response = await _login(
            authed_client,
            email=email,
            totp_code="000000",
            backup_code=backup_plaintext,
        )

    # TOTP verification was attempted exactly once...
    fake_verify_totp.assert_called_once()
    # ...and the valid backup code was left unconsumed.
    await db_session.refresh(unused_code)
    assert unused_code.used_at is None, "Backup code should NOT have been consumed when totp_code is provided"
    assert response.status_code == 401