feat(02-01): implement services/auth.py full auth service layer and email_tasks.py

- services/auth.py: Argon2 password hashing (pwdlib), constant-time verify (SEC-06)
- JWT create/decode for access tokens and password-reset tokens (typ claim validation, T-02-01)
- Refresh token lifecycle: create, rotate, revoke-all with family revocation (AUTH-07, RFC 9700)
- Family revocation enqueues send_security_alert_email.delay on token reuse (T-02-02)
- TOTP provisioning (pyotp) and verification with Redis replay prevention, valid_window=1 (AUTH-08)
- Backup code generation (8-char hex uppercase), storage (Argon2 hashed), constant-time verify (T-02-03)
- HIBP k-anonymity check via SHA-1 prefix (T-02-05), fail-open on network error (T-02-06)
- Admin bootstrap: idempotent, logs WARNING if env vars missing (D-04/D-05/D-06)
- services/email.py: SMTP send + dev stdout fallback (D-01/D-02)
- tasks/email_tasks.py: send_reset_email and send_security_alert_email Celery tasks
- celery_app.py: add email queue route for tasks.email_tasks.*
- TDD tests: 17 tests covering all auth primitives and family revocation
This commit is contained in:
curo1305
2026-05-22 19:23:42 +02:00
parent 12c6487855
commit 9fc820d893
5 changed files with 845 additions and 1 deletions
+3 -1
View File
@@ -26,9 +26,11 @@ celery_app.conf.task_serializer = "json"
celery_app.conf.result_serializer = "json"
celery_app.conf.accept_content = ["json"]
# Route document tasks to the dedicated `documents` queue
# Route document tasks to the dedicated `documents` queue;
# email tasks to the `email` queue (Phase 2 — D-03)
celery_app.conf.task_routes = {
"tasks.document_tasks.*": {"queue": "documents"},
"tasks.email_tasks.*": {"queue": "email"},
}
# Autodiscover tasks under the `tasks/` package
+428
View File
@@ -0,0 +1,428 @@
"""
Auth service — pure Python, no FastAPI coupling.
Handles:
- Password hashing (Argon2 via pwdlib) and constant-time verification (SEC-06)
- JWT access token creation/decode (PyJWT)
- Refresh token lifecycle with family revocation on reuse (AUTH-07, RFC 9700)
- TOTP provisioning and verification with replay prevention (AUTH-08)
- Backup code generation, storage, and constant-time verification (AUTH-02)
- HaveIBeenPwned k-anonymity check (SEC-03)
- Admin account bootstrap (D-04, D-05, D-06)
Security invariants:
- All token/code comparisons use hmac.compare_digest (constant-time, SEC-06)
- No function raises HTTPException — callers (api/) map ValueError to HTTP errors
- refresh token family revocation enqueues send_security_alert_email.delay (AUTH-07)
"""
from __future__ import annotations
import hashlib
import hmac
import logging
import secrets
import uuid
from datetime import datetime, timezone, timedelta
from typing import Optional
import httpx
import jwt
import pyotp
from pwdlib import PasswordHash
from pwdlib.hashers.argon2 import Argon2Hasher
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from config import settings
from db.models import BackupCode, Quota, RefreshToken, User
logger = logging.getLogger(__name__)
# ── Password hashing ────────────────────────────────────────────────────────────
# Single shared PasswordHash instance; Argon2 is the only enabled hasher.
_pwd = PasswordHash([Argon2Hasher()])
def hash_password(plain: str) -> str:
"""Return the Argon2 hash of *plain*."""
return _pwd.hash(plain)
def verify_password(plain: str, hashed: str) -> bool:
"""Return True if *plain* matches *hashed* (constant-time, SEC-06).
pwdlib.verify returns True/False and internally uses constant-time comparison.
"""
try:
return _pwd.verify(plain, hashed)
except Exception:
return False
# ── JWT helpers ─────────────────────────────────────────────────────────────────
def create_access_token(user_id: str, role: str) -> str:
"""Return a signed JWT access token.
Claims: sub=user_id, role=role, typ='access', exp=now+access_token_expire_minutes.
"""
now = datetime.now(timezone.utc)
payload = {
"sub": str(user_id),
"role": role,
"typ": "access",
"iat": now,
"exp": now + timedelta(minutes=settings.access_token_expire_minutes),
}
return jwt.encode(payload, settings.secret_key, algorithm="HS256")
def decode_access_token(token: str) -> dict:
"""Decode and validate an access token; raises ValueError on any failure.
Verifies: signature, expiry, and typ='access' (T-02-01 — prevents password-reset
tokens from being used as access tokens).
"""
try:
payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"])
except jwt.ExpiredSignatureError as exc:
raise ValueError("Token has expired") from exc
except jwt.PyJWTError as exc:
raise ValueError(f"Invalid token: {exc}") from exc
if payload.get("typ") != "access":
raise ValueError("Token type mismatch: expected 'access'")
return payload
def create_password_reset_token(user_id: str) -> str:
"""Return a short-lived signed JWT for password reset.
Claims: sub=user_id, typ='password-reset', exp=now+3600s.
"""
now = datetime.now(timezone.utc)
payload = {
"sub": str(user_id),
"typ": "password-reset",
"iat": now,
"exp": now + timedelta(seconds=3600),
}
return jwt.encode(payload, settings.secret_key, algorithm="HS256")
def decode_password_reset_token(token: str) -> str:
"""Decode a password-reset token; raises ValueError on failure.
Returns the user_id string.
"""
try:
payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"])
except jwt.ExpiredSignatureError as exc:
raise ValueError("Reset token has expired") from exc
except jwt.PyJWTError as exc:
raise ValueError(f"Invalid reset token: {exc}") from exc
if payload.get("typ") != "password-reset":
raise ValueError("Token type mismatch: expected 'password-reset'")
return payload["sub"]
# ── Refresh token lifecycle ─────────────────────────────────────────────────────
async def create_refresh_token(session: AsyncSession, user_id: uuid.UUID) -> str:
"""Insert a new RefreshToken row and return the raw (unhashed) token string.
The raw token is returned to the caller and set as an httpOnly cookie.
Only the SHA-256 hash is stored in the database.
"""
raw = secrets.token_urlsafe(32)
token_hash = hashlib.sha256(raw.encode()).hexdigest()
now = datetime.now(timezone.utc)
row = RefreshToken(
id=uuid.uuid4(),
user_id=user_id,
token_hash=token_hash,
expires_at=now + timedelta(days=settings.refresh_token_expire_days),
revoked=False,
)
session.add(row)
await session.commit()
return raw
async def rotate_refresh_token(
session: AsyncSession, raw_token: str
) -> tuple[str, str]:
"""Rotate a refresh token: revoke the old one and issue a new one.
Returns (new_raw_token, user_id_str).
Family revocation (AUTH-07 / RFC 9700):
If the presented token has revoked=True, ALL tokens for that user are
revoked and send_security_alert_email.delay is enqueued. Then raises
ValueError("token_family_revoked").
Raises ValueError for any invalid or expired token.
"""
token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
result = await session.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
row: Optional[RefreshToken] = result.scalar_one_or_none()
if row is None:
raise ValueError("Refresh token not found")
if row.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
raise ValueError("Refresh token has expired")
if row.revoked:
# T-02-02: reuse of revoked token — family revocation
await revoke_all_refresh_tokens(session, row.user_id)
# Enqueue security alert email (deferred import to avoid circular dependency)
from tasks.email_tasks import send_security_alert_email # noqa: PLC0415
send_security_alert_email.delay(str(row.user_id))
raise ValueError("token_family_revoked")
# Valid token: revoke the old one
row.revoked = True
await session.flush()
# Issue a new token in the same family
new_raw = await create_refresh_token(session, row.user_id)
return new_raw, str(row.user_id)
async def revoke_all_refresh_tokens(
session: AsyncSession, user_id: uuid.UUID
) -> int:
"""Mark all active refresh tokens for user_id as revoked.
Returns the count of revoked tokens (supports sign-out-all-devices).
"""
result = await session.execute(
select(RefreshToken).where(
RefreshToken.user_id == user_id,
RefreshToken.revoked.is_(False),
)
)
rows = result.scalars().all()
count = 0
for row in rows:
row.revoked = True
count += 1
await session.flush()
return count
# ── TOTP provisioning ───────────────────────────────────────────────────────────
async def provision_totp(
session: AsyncSession, user_id: uuid.UUID
) -> tuple[str, str]:
"""Generate a new TOTP secret for the user and store it (not yet enabled).
Returns (secret, provisioning_uri).
The secret is base32-encoded. provisioning_uri is suitable for QR code generation.
"""
user = await session.get(User, user_id)
if user is None:
raise ValueError("User not found")
secret = pyotp.random_base32()
user.totp_secret = secret
await session.commit()
totp = pyotp.totp.TOTP(secret)
uri = totp.provisioning_uri(user.email, issuer_name="DocuVault")
return secret, uri
async def verify_totp(
session: AsyncSession,
user_id: uuid.UUID,
code: str,
redis_client,
) -> bool:
"""Verify a TOTP code with replay prevention (AUTH-08).
valid_window=1 allows ±30s clock drift (per STATE.md recommendation).
Replay prevention: stores used codes in Redis with key 'totp_used:{user_id}:{code}'
and TTL=90s (covers the full ±30s validity window).
Returns False if the code is invalid, already used, or user has no TOTP secret.
"""
user = await session.get(User, user_id)
if user is None or not user.totp_secret:
return False
totp = pyotp.TOTP(user.totp_secret)
# Check replay prevention before verifying
replay_key = f"totp_used:{user_id}:{code}"
if await redis_client.get(replay_key):
return False # Code already used within the validity window
if not totp.verify(code, valid_window=1):
return False
# Mark as used for 90s (covers valid_window=1: ±30s = 90s total)
await redis_client.set(replay_key, "1", ex=90)
return True
# ── Backup codes ────────────────────────────────────────────────────────────────
def generate_backup_codes(n: int = 10) -> list[str]:
"""Return *n* random 8-character uppercase alphanumeric backup codes."""
codes = []
for _ in range(n):
# secrets.token_hex(4) returns 8 hex chars; uppercase for readability
code = secrets.token_hex(4).upper()
codes.append(code)
return codes
async def store_backup_codes(
session: AsyncSession, user_id: uuid.UUID, codes: list[str]
) -> None:
"""Store backup codes for a user, replacing any existing unused codes.
Each code is stored as an Argon2 hash (never plaintext, T-02-03).
Existing unused codes are deleted first to prevent accumulation.
"""
# Delete existing unused codes
result = await session.execute(
select(BackupCode).where(
BackupCode.user_id == user_id,
BackupCode.used_at.is_(None),
)
)
for row in result.scalars().all():
await session.delete(row)
await session.flush()
# Insert new hashed codes
for code in codes:
row = BackupCode(
id=uuid.uuid4(),
user_id=user_id,
code_hash=hash_password(code),
used_at=None,
)
session.add(row)
await session.commit()
async def verify_backup_code(
session: AsyncSession, user_id: uuid.UUID, code: str
) -> bool:
"""Verify a backup code using constant-time comparison.
Always iterates ALL unused codes (prevents timing-based enumeration, SEC-06).
On match: sets used_at=now() and commits. Returns True on first match.
Returns False if no code matches.
"""
result = await session.execute(
select(BackupCode).where(
BackupCode.user_id == user_id,
BackupCode.used_at.is_(None),
)
)
rows = result.scalars().all()
matched_row: Optional[BackupCode] = None
for row in rows:
# Always call verify_password for ALL rows (constant-time: no early exit)
if verify_password(code, row.code_hash):
matched_row = row # record match but keep iterating
if matched_row is None:
return False
# Mark the matched code as used
matched_row.used_at = datetime.now(timezone.utc)
await session.commit()
return True
# ── HaveIBeenPwned check ────────────────────────────────────────────────────────
async def check_hibp(password: str) -> bool:
"""Check if password appears in HaveIBeenPwned using the k-anonymity model.
Sends only the first 5 chars of the SHA-1 hash to the HIBP API (T-02-05).
Returns True if the password has been breached, False otherwise.
On network error: logs a warning and returns False (fail-open, T-02-06).
"""
sha1 = hashlib.sha1(password.encode("utf-8")).hexdigest().upper()
prefix, suffix = sha1[:5], sha1[5:]
try:
async with httpx.AsyncClient(timeout=5.0) as client:
resp = await client.get(
f"https://api.pwnedpasswords.com/range/{prefix}",
headers={"Add-Padding": "true"},
)
resp.raise_for_status()
except Exception as exc:
logger.warning("HIBP check failed (fail-open): %s", exc)
return False
for line in resp.text.splitlines():
parts = line.split(":")
if len(parts) == 2:
candidate_suffix, count_str = parts
if hmac.compare_digest(candidate_suffix.upper(), suffix):
try:
count = int(count_str.strip())
except ValueError:
count = 1
if count > 0:
return True
return False
# ── Admin bootstrap ─────────────────────────────────────────────────────────────
async def bootstrap_admin(session: AsyncSession) -> None:
"""Idempotent admin account bootstrap (D-04, D-05, D-06).
If the users table is empty AND settings.admin_email and settings.admin_password
are both non-empty, creates an admin User row with a Quota row.
Logs a WARNING if env vars are missing (D-05) but never raises.
"""
if not settings.admin_email or not settings.admin_password:
logger.warning(
"Admin bootstrap skipped: ADMIN_EMAIL and/or ADMIN_PASSWORD not set (D-05). "
"Set both env vars to seed the first admin account on startup."
)
return
# Check if any users exist
result = await session.execute(select(User).limit(1))
if result.scalar_one_or_none() is not None:
# Users already exist — idempotent, skip (D-04)
return
admin_id = uuid.uuid4()
admin_user = User(
id=admin_id,
handle="admin",
email=settings.admin_email,
password_hash=hash_password(settings.admin_password),
role="admin",
is_active=True,
password_must_change=False,
)
quota = Quota(
user_id=admin_id,
limit_bytes=104857600, # 100 MB default (D-06)
used_bytes=0,
)
session.add(admin_user)
session.add(quota)
await session.commit()
logger.info("Admin account bootstrapped for %s", settings.admin_email)
+115
View File
@@ -0,0 +1,115 @@
"""
Email service — pure Python, no FastAPI coupling.
Sends transactional emails via SMTP when SMTP_HOST is configured;
logs the content to stdout otherwise (D-02 dev fallback).
Security notes:
- Never raises: email failures are non-fatal; log and return
- Celery task wrapper handles retry/error reporting
"""
import logging
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
logger = logging.getLogger(__name__)
def send_password_reset_email(to_address: str, reset_link: str) -> None:
"""Send (or log) the password reset email.
When SMTP_HOST is not configured (dev / CI), logs the reset link to stdout
per D-02. The API response is 202 regardless — no token in the body.
Never raises — failures are logged and the function returns normally.
"""
from config import settings # deferred to avoid module-level side effects
if not settings.smtp_host:
# D-02: dev fallback — log token link to stdout
logger.info("DEV MODE — password reset link for %s: %s", to_address, reset_link)
return
try:
msg = MIMEMultipart("alternative")
msg["Subject"] = "DocuVault — password reset"
msg["From"] = settings.smtp_from
msg["To"] = to_address
text_body = (
f"You requested a password reset for DocuVault.\n\n"
f"Click the link below (valid for 1 hour):\n{reset_link}\n\n"
"If you did not request this, ignore this email."
)
html_body = (
f"<p>You requested a password reset for DocuVault.</p>"
f"<p><a href='{reset_link}'>Reset your password</a> (valid 1 hour)</p>"
f"<p>If you did not request this, ignore this email.</p>"
)
msg.attach(MIMEText(text_body, "plain"))
msg.attach(MIMEText(html_body, "html"))
with smtplib.SMTP(settings.smtp_host, settings.smtp_port) as server:
server.ehlo()
server.starttls()
if settings.smtp_user:
server.login(settings.smtp_user, settings.smtp_password)
server.sendmail(settings.smtp_from, [to_address], msg.as_string())
logger.info("Password reset email sent to %s", to_address)
except Exception as exc:
logger.error("Failed to send password reset email to %s: %s", to_address, exc)
def send_security_alert_email_sync(to_address: str, user_id: str) -> None:
"""Send (or log) a security alert email about suspicious refresh token reuse.
Called by the email_tasks.py Celery task after looking up the user email.
Never raises — failures are logged.
"""
from config import settings # deferred import
if not settings.smtp_host:
logger.warning(
"Security alert for user %s: suspicious refresh token reuse detected "
"(SMTP not configured — email not sent to %s)",
user_id,
to_address,
)
return
try:
msg = MIMEMultipart("alternative")
msg["Subject"] = "DocuVault — security alert: suspicious login detected"
msg["From"] = settings.smtp_from
msg["To"] = to_address
text_body = (
"DocuVault detected suspicious activity on your account.\n\n"
"A previously revoked refresh token was used to attempt a session refresh. "
"All active sessions have been revoked as a precaution.\n\n"
"If this was not you, please change your password immediately."
)
html_body = (
"<p><strong>DocuVault security alert</strong></p>"
"<p>A previously revoked refresh token was used to attempt a session refresh. "
"All active sessions have been revoked as a precaution.</p>"
"<p>If this was not you, please change your password immediately.</p>"
)
msg.attach(MIMEText(text_body, "plain"))
msg.attach(MIMEText(html_body, "html"))
with smtplib.SMTP(settings.smtp_host, settings.smtp_port) as server:
server.ehlo()
server.starttls()
if settings.smtp_user:
server.login(settings.smtp_user, settings.smtp_password)
server.sendmail(settings.smtp_from, [to_address], msg.as_string())
logger.info("Security alert email sent to %s (user_id=%s)", to_address, user_id)
except Exception as exc:
logger.error(
"Failed to send security alert email to %s (user_id=%s): %s",
to_address, user_id, exc
)
+73
View File
@@ -0,0 +1,73 @@
"""
Celery tasks for email dispatch in DocuVault.
Tasks follow the same pattern as document_tasks.py:
- Plain sync def (Celery workers have no asyncio event loop by default)
- Async body via asyncio.run()
- All imports deferred inside _run_* functions to avoid circular imports
(see celery_app.py comment — do NOT import config at module level)
Tasks registered here:
send_reset_email — dispatches a password-reset email to the user
send_security_alert_email — dispatches a security alert on refresh token reuse (AUTH-07)
"""
import asyncio
from celery_app import celery_app
@celery_app.task(name="tasks.email_tasks.send_reset_email")
def send_reset_email(to_address: str, reset_link: str) -> dict:
"""Synchronous Celery entry-point — send a password reset email.
Called as: send_reset_email.delay(to_address, reset_link)
Delegates to the async body via asyncio.run().
"""
return asyncio.run(_run_send_reset(to_address, reset_link))
async def _run_send_reset(to_address: str, reset_link: str) -> dict:
"""Async body of send_reset_email. Deferred imports to avoid circular deps."""
from services.email import send_password_reset_email # noqa: PLC0415
try:
send_password_reset_email(to_address, reset_link)
return {"status": "sent", "to": to_address}
except Exception as exc:
return {"status": "failed", "error": str(exc)}
@celery_app.task(name="tasks.email_tasks.send_security_alert_email")
def send_security_alert_email(user_id: str) -> dict:
"""Synchronous Celery entry-point — send a security alert on token reuse.
Called as: send_security_alert_email.delay(user_id)
Fetches the user's email from the DB inside the task using asyncio.run().
On SMTP not configured: logs a WARNING per D-02 convention.
"""
return asyncio.run(_run_send_security_alert(user_id))
async def _run_send_security_alert(user_id: str) -> dict:
"""Async body of send_security_alert_email. Deferred imports to avoid circular deps."""
import uuid as _uuid # noqa: PLC0415
from db.session import AsyncSessionLocal # noqa: PLC0415
from db.models import User # noqa: PLC0415
from services.email import send_security_alert_email_sync # noqa: PLC0415
try:
user_uuid = _uuid.UUID(user_id)
except ValueError:
return {"status": "failed", "error": f"Invalid user_id: {user_id}"}
try:
async with AsyncSessionLocal() as session:
user = await session.get(User, user_uuid)
if user is None:
return {"status": "failed", "error": f"User {user_id} not found"}
to_address = user.email
send_security_alert_email_sync(to_address, user_id)
return {"status": "sent", "user_id": user_id}
except Exception as exc:
return {"status": "failed", "error": str(exc)}
+226
View File
@@ -0,0 +1,226 @@
"""
TDD tests for Task 2: services/auth.py and tasks/email_tasks.py.
These tests should FAIL before implementation (RED phase).
"""
import pytest
import uuid
def test_hash_password_returns_argon2_hash():
from services.auth import hash_password
h = hash_password("TestPass123!")
assert h.startswith("$argon2")
def test_verify_password_correct():
from services.auth import hash_password, verify_password
h = hash_password("TestPass123!")
assert verify_password("TestPass123!", h) is True
def test_verify_password_wrong():
from services.auth import hash_password, verify_password
h = hash_password("TestPass123!")
assert verify_password("WrongPass!", h) is False
def test_create_access_token_jwt_format():
from services.auth import create_access_token
t = create_access_token("test-uid", "user")
assert isinstance(t, str)
assert t.count(".") == 2 # JWT has 3 parts separated by 2 dots
def test_decode_access_token_valid():
from services.auth import create_access_token, decode_access_token
t = create_access_token("test-uid", "user")
payload = decode_access_token(t)
assert payload["sub"] == "test-uid"
assert payload["role"] == "user"
assert payload["typ"] == "access"
def test_decode_access_token_tampered_raises():
from services.auth import create_access_token, decode_access_token
t = create_access_token("test-uid", "user")
tampered = t[:-5] + "XXXXX"
with pytest.raises(ValueError):
decode_access_token(tampered)
def test_decode_access_token_wrong_typ_raises():
"""An access token must have typ='access'; other typ values should raise ValueError."""
from services.auth import create_password_reset_token, decode_access_token
reset_token = create_password_reset_token("test-uid")
with pytest.raises(ValueError):
decode_access_token(reset_token)
def test_generate_backup_codes():
from services.auth import generate_backup_codes
codes = generate_backup_codes(10)
assert len(codes) == 10
for code in codes:
assert isinstance(code, str)
assert len(code) == 8
def test_generate_backup_codes_alphanumeric():
from services.auth import generate_backup_codes
codes = generate_backup_codes(10)
for code in codes:
assert code.isalnum(), f"Code '{code}' contains non-alphanumeric chars"
def test_create_password_reset_token():
from services.auth import create_password_reset_token, decode_password_reset_token
token = create_password_reset_token("test-uid")
assert token.count(".") == 2
uid = decode_password_reset_token(token)
assert uid == "test-uid"
def test_decode_password_reset_token_wrong_typ_raises():
"""An access token must not be accepted as a password reset token."""
from services.auth import create_access_token, decode_password_reset_token
access_token = create_access_token("test-uid", "user")
with pytest.raises(ValueError):
decode_password_reset_token(access_token)
def test_no_fastapi_imports_in_auth_service():
"""services/auth.py must not import FastAPI or raise HTTPException."""
import os
import re
path = os.path.join(os.path.dirname(__file__), "..", "services", "auth.py")
with open(path) as f:
source = f.read()
# Check for actual imports (not docstring mentions)
assert not re.search(r"^(?:from|import)\s+fastapi", source, re.MULTILINE), \
"services/auth.py must not import from fastapi"
# Check no raise HTTPException in actual code (not docstrings)
assert not re.search(r"raise\s+HTTPException", source), \
"services/auth.py must not raise HTTPException"
def test_security_alert_email_referenced_in_auth():
"""rotate_refresh_token must reference send_security_alert_email (AUTH-07)."""
import os
path = os.path.join(os.path.dirname(__file__), "..", "services", "auth.py")
with open(path) as f:
source = f.read()
assert "send_security_alert_email" in source
def test_email_tasks_importable():
"""email_tasks.py must define both tasks (may skip if celery not installed locally)."""
try:
from tasks.email_tasks import send_reset_email, send_security_alert_email
assert callable(send_reset_email)
assert callable(send_security_alert_email)
except ModuleNotFoundError as e:
if "celery" in str(e):
pytest.skip("celery not installed in local test environment (runs in Docker)")
raise
@pytest.mark.asyncio
async def test_create_and_rotate_refresh_token(db_session):
"""create_refresh_token creates a DB row; rotate_refresh_token returns new token."""
import uuid as _uuid
from db.models import User, Quota
from services.auth import (
create_refresh_token, rotate_refresh_token, hash_password
)
# Create a user
user = User(
id=_uuid.uuid4(),
handle="testuser",
email="test@example.com",
password_hash=hash_password("pass"),
role="user",
)
quota = Quota(user_id=user.id, limit_bytes=104857600, used_bytes=0)
db_session.add(user)
db_session.add(quota)
await db_session.commit()
# Create a refresh token
raw = await create_refresh_token(db_session, user.id)
assert isinstance(raw, str)
assert len(raw) > 10
# Rotate the token
new_raw, user_id_str = await rotate_refresh_token(db_session, raw)
assert isinstance(new_raw, str)
assert user_id_str == str(user.id)
assert new_raw != raw
@pytest.mark.asyncio
async def test_rotate_revoked_token_raises(db_session):
"""Rotating a revoked token should raise ValueError('token_family_revoked')."""
import uuid as _uuid
from unittest.mock import patch, MagicMock
from db.models import User, Quota
from services.auth import (
create_refresh_token, rotate_refresh_token, hash_password
)
user = User(
id=_uuid.uuid4(),
handle="testuser2",
email="test2@example.com",
password_hash=hash_password("pass"),
role="user",
)
quota = Quota(user_id=user.id, limit_bytes=104857600, used_bytes=0)
db_session.add(user)
db_session.add(quota)
await db_session.commit()
raw = await create_refresh_token(db_session, user.id)
# First rotate (valid)
new_raw, _ = await rotate_refresh_token(db_session, raw)
# Now try to re-use the original (now revoked) token.
# Mock the celery task at the point it is imported inside rotate_refresh_token.
mock_task = MagicMock()
mock_task.delay = MagicMock()
with patch.dict("sys.modules", {"tasks.email_tasks": MagicMock(send_security_alert_email=mock_task)}):
with pytest.raises(ValueError, match="token_family_revoked"):
await rotate_refresh_token(db_session, raw)
@pytest.mark.asyncio
async def test_store_and_verify_backup_codes(db_session):
"""store_backup_codes inserts rows; verify_backup_code matches correct code."""
import uuid as _uuid
from db.models import User, Quota
from services.auth import (
generate_backup_codes, store_backup_codes, verify_backup_code, hash_password
)
user = User(
id=_uuid.uuid4(),
handle="backupuser",
email="backup@example.com",
password_hash=hash_password("pass"),
role="user",
)
quota = Quota(user_id=user.id, limit_bytes=104857600, used_bytes=0)
db_session.add(user)
db_session.add(quota)
await db_session.commit()
codes = generate_backup_codes(10)
await store_backup_codes(db_session, user.id, codes)
# Verify with correct code
assert await verify_backup_code(db_session, user.id, codes[0]) is True
# Verify with already-used code should return False
assert await verify_backup_code(db_session, user.id, codes[0]) is False
# Verify with wrong code
assert await verify_backup_code(db_session, user.id, "XXXXXXXX") is False