9fc820d893
- services/auth.py: Argon2 password hashing (pwdlib), constant-time verify (SEC-06)
- JWT create/decode for access tokens and password-reset tokens (typ claim validation, T-02-01)
- Refresh token lifecycle: create, rotate, revoke-all with family revocation (AUTH-07, RFC 9700)
- Family revocation enqueues send_security_alert_email.delay on token reuse (T-02-02)
- TOTP provisioning (pyotp) and verification with Redis replay prevention, valid_window=1 (AUTH-08)
- Backup code generation (8-char uppercase hex), storage (Argon2-hashed), constant-time verify (T-02-03)
- HIBP k-anonymity check via SHA-1 prefix (T-02-05), fail-open on network error (T-02-06)
- Admin bootstrap: idempotent; logs a WARNING if env vars are missing (D-04/D-05/D-06)
- services/email.py: SMTP send + dev stdout fallback (D-01/D-02)
- tasks/email_tasks.py: send_reset_email and send_security_alert_email Celery tasks
- celery_app.py: add email queue route for tasks.email_tasks.*
- TDD tests: 17 tests covering all auth primitives and family revocation
227 lines
7.8 KiB
Python
"""
TDD tests for Task 2: services/auth.py and tasks/email_tasks.py.

These tests were written first (RED phase) and were expected to fail until
the implementation existed.
"""

import uuid

import pytest


def test_hash_password_returns_argon2_hash():
    """hash_password must return an encoded Argon2 hash (a string starting with "$argon2")."""
    from services.auth import hash_password

    hashed = hash_password("TestPass123!")
    assert hashed.startswith("$argon2")


def test_verify_password_correct():
    """verify_password accepts the exact password that was hashed."""
    from services.auth import hash_password, verify_password

    password = "TestPass123!"
    stored = hash_password(password)
    assert verify_password(password, stored) is True


def test_verify_password_wrong():
    """verify_password rejects a password that differs from the one that was hashed."""
    from services.auth import hash_password, verify_password

    stored = hash_password("TestPass123!")
    outcome = verify_password("WrongPass!", stored)
    assert outcome is False


def test_create_access_token_jwt_format():
    """create_access_token returns a compact JWT string: header.payload.signature.

    Counting dots alone would also accept strings such as ".." or "h.p.".
    Each segment is therefore required to be non-empty. An empty third segment
    is the shape of an unsecured ("alg": "none") JWT, which must never be issued.
    """
    from services.auth import create_access_token

    t = create_access_token("test-uid", "user")
    assert isinstance(t, str)
    parts = t.split(".")
    assert len(parts) == 3  # JWT has 3 parts separated by 2 dots
    assert all(parts), f"JWT has an empty segment: {parts!r}"


def test_decode_access_token_valid():
    """decode_access_token returns the subject and role given to create_access_token, with typ='access'."""
    from services.auth import create_access_token, decode_access_token

    claims = decode_access_token(create_access_token("test-uid", "user"))
    expected = {"sub": "test-uid", "role": "user", "typ": "access"}
    for claim_name, claim_value in expected.items():
        assert claims[claim_name] == claim_value


def test_decode_access_token_tampered_raises():
    """A token whose trailing characters were overwritten must be rejected with ValueError."""
    from services.auth import create_access_token, decode_access_token

    genuine = create_access_token("test-uid", "user")
    # Overwrite the final five characters (the tail of the signature segment).
    forged = genuine[:-5] + "XXXXX"
    with pytest.raises(ValueError):
        decode_access_token(forged)


def test_decode_access_token_wrong_typ_raises():
    """decode_access_token must refuse tokens of another type (here a password-reset token)."""
    from services.auth import create_password_reset_token, decode_access_token

    other_kind = create_password_reset_token("test-uid")
    with pytest.raises(ValueError):
        decode_access_token(other_kind)


def test_generate_backup_codes():
    """generate_backup_codes(n) returns n codes, each an 8-character string."""
    from services.auth import generate_backup_codes

    batch = generate_backup_codes(10)
    assert len(batch) == 10
    for backup_code in batch:
        assert isinstance(backup_code, str)
        assert len(backup_code) == 8


def test_generate_backup_codes_alphanumeric():
    """Every generated backup code consists only of ASCII letters and digits.

    str.isalnum() alone is Unicode-aware and would also accept characters such
    as 'é' or full-width digits. ASCII-ness is therefore checked explicitly.
    """
    from services.auth import generate_backup_codes

    codes = generate_backup_codes(10)
    for code in codes:
        assert code.isascii() and code.isalnum(), (
            f"Code '{code}' contains non-ASCII-alphanumeric chars"
        )


def test_create_password_reset_token():
    """create_password_reset_token yields a three-segment JWT that decodes back to the user id.

    Each segment must be non-empty; a dot count alone would accept "..".
    """
    from services.auth import create_password_reset_token, decode_password_reset_token

    token = create_password_reset_token("test-uid")
    parts = token.split(".")
    assert len(parts) == 3  # header.payload.signature
    assert all(parts), f"JWT has an empty segment: {parts!r}"

    uid = decode_password_reset_token(token)
    assert uid == "test-uid"


def test_decode_password_reset_token_wrong_typ_raises():
    """decode_password_reset_token must refuse an access token (typ mismatch) with ValueError."""
    from services.auth import create_access_token, decode_password_reset_token

    wrong_kind = create_access_token("test-uid", "user")
    with pytest.raises(ValueError):
        decode_password_reset_token(wrong_kind)


def test_no_fastapi_imports_in_auth_service():
    """services/auth.py must not import FastAPI or raise HTTPException.

    The module is parsed with ``ast`` rather than scanned with regexes, so that:
    - imports nested inside functions are caught, not only column-0 ones;
    - mentions in docstrings and comments are ignored by both checks.
    """
    import ast
    from pathlib import Path

    path = Path(__file__).parent / ".." / "services" / "auth.py"
    tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))

    imported_modules = []
    raised_names = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            imported_modules.extend(alias.name for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            imported_modules.append(node.module)
        elif isinstance(node, ast.Raise) and node.exc is not None:
            # Handle both `raise HTTPException(...)` and a bare `raise HTTPException`,
            # as well as qualified forms such as `raise fastapi.HTTPException(...)`.
            exc = node.exc.func if isinstance(node.exc, ast.Call) else node.exc
            if isinstance(exc, ast.Attribute):
                raised_names.append(exc.attr)
            elif isinstance(exc, ast.Name):
                raised_names.append(exc.id)

    # Prefix match mirrors the original regex (also rejects e.g. fastapi_users).
    fastapi_imports = [m for m in imported_modules if m.startswith("fastapi")]
    assert not fastapi_imports, \
        f"services/auth.py must not import from fastapi: {fastapi_imports}"
    assert "HTTPException" not in raised_names, \
        "services/auth.py must not raise HTTPException"


def test_security_alert_email_referenced_in_auth():
    """services/auth.py must use send_security_alert_email in code (AUTH-07).

    rotate_refresh_token is expected to enqueue this task on refresh-token reuse.
    The whole module is checked. The name must appear as an imported name, a
    variable reference, or an attribute access; a mention only in a comment or
    docstring does not count.
    """
    import ast
    from pathlib import Path

    path = Path(__file__).parent / ".." / "services" / "auth.py"
    tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))

    identifiers = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Name):
            identifiers.add(node.id)
        elif isinstance(node, ast.Attribute):
            identifiers.add(node.attr)
        elif isinstance(node, (ast.Import, ast.ImportFrom)):
            identifiers.update(alias.name for alias in node.names)

    assert "send_security_alert_email" in identifiers


def test_email_tasks_importable():
    """email_tasks.py must define both tasks (may skip if celery not installed locally).

    Only a missing ``celery`` package (or one of its submodules) triggers the skip.
    The match uses the exception's ``name`` attribute rather than the message
    text. A substring test would also skip unrelated failures such as a missing
    ``celery_app`` project module.
    """
    try:
        from tasks.email_tasks import send_reset_email, send_security_alert_email
    except ModuleNotFoundError as e:
        # e.name is the module that could not be found; it may be None if the
        # exception was raised manually.
        if (e.name or "").split(".")[0] == "celery":
            pytest.skip("celery not installed in local test environment (runs in Docker)")
        raise

    assert callable(send_reset_email)
    assert callable(send_security_alert_email)


@pytest.mark.asyncio
async def test_create_and_rotate_refresh_token(db_session):
    """A freshly issued refresh token rotates into a different token for the same user."""
    from db.models import Quota, User
    from services.auth import create_refresh_token, hash_password, rotate_refresh_token

    # Persist an owner (plus a 100 MiB quota row) for the token.
    owner = User(
        id=uuid.uuid4(),
        handle="testuser",
        email="test@example.com",
        password_hash=hash_password("pass"),
        role="user",
    )
    db_session.add_all([
        owner,
        Quota(user_id=owner.id, limit_bytes=100 * 1024 * 1024, used_bytes=0),
    ])
    await db_session.commit()

    issued = await create_refresh_token(db_session, owner.id)
    assert isinstance(issued, str)
    assert len(issued) > 10

    rotated, returned_uid = await rotate_refresh_token(db_session, issued)
    assert isinstance(rotated, str)
    assert returned_uid == str(owner.id)
    assert rotated != issued


@pytest.mark.asyncio
async def test_rotate_revoked_token_raises(db_session):
    """Reusing an already-rotated refresh token revokes the family and alerts the user.

    rotate_refresh_token must:
    - raise ValueError('token_family_revoked');
    - enqueue send_security_alert_email via ``.delay`` (AUTH-07, T-02-02).
    """
    from unittest.mock import MagicMock, patch

    from db.models import User, Quota
    from services.auth import (
        create_refresh_token, rotate_refresh_token, hash_password
    )

    user = User(
        id=uuid.uuid4(),
        handle="testuser2",
        email="test2@example.com",
        password_hash=hash_password("pass"),
        role="user",
    )
    quota = Quota(user_id=user.id, limit_bytes=104857600, used_bytes=0)
    db_session.add(user)
    db_session.add(quota)
    await db_session.commit()

    raw = await create_refresh_token(db_session, user.id)
    # First rotation is legitimate and revokes `raw`.
    await rotate_refresh_token(db_session, raw)

    # Replay the original (now revoked) token. The Celery task is imported inside
    # rotate_refresh_token, so replacing the module in sys.modules intercepts it.
    mock_task = MagicMock()
    fake_email_tasks = MagicMock(send_security_alert_email=mock_task)
    with patch.dict("sys.modules", {"tasks.email_tasks": fake_email_tasks}):
        with pytest.raises(ValueError, match="token_family_revoked"):
            await rotate_refresh_token(db_session, raw)

    # Without this the mock above is never checked and T-02-02 goes untested.
    mock_task.delay.assert_called()


@pytest.mark.asyncio
async def test_store_and_verify_backup_codes(db_session):
    """Stored backup codes verify once; a reused or unknown code is rejected."""
    from db.models import Quota, User
    from services.auth import (
        generate_backup_codes,
        hash_password,
        store_backup_codes,
        verify_backup_code,
    )

    account = User(
        id=uuid.uuid4(),
        handle="backupuser",
        email="backup@example.com",
        password_hash=hash_password("pass"),
        role="user",
    )
    db_session.add_all([
        account,
        Quota(user_id=account.id, limit_bytes=100 * 1024 * 1024, used_bytes=0),
    ])
    await db_session.commit()

    issued_codes = generate_backup_codes(10)
    await store_backup_codes(db_session, account.id, issued_codes)

    first_code = issued_codes[0]
    # First use of a valid code succeeds.
    assert await verify_backup_code(db_session, account.id, first_code) is True
    # The same code cannot be used twice.
    assert await verify_backup_code(db_session, account.id, first_code) is False
    # A code that was never stored is rejected.
    assert await verify_backup_code(db_session, account.id, "XXXXXXXX") is False