feat(02-01): implement services/auth.py full auth service layer and email_tasks.py
- services/auth.py: Argon2 password hashing (pwdlib), constant-time verify (SEC-06) - JWT create/decode for access tokens and password-reset tokens (typ claim validation, T-02-01) - Refresh token lifecycle: create, rotate, revoke-all with family revocation (AUTH-07, RFC 9700) - Family revocation enqueues send_security_alert_email.delay on token reuse (T-02-02) - TOTP provisioning (pyotp) and verification with Redis replay prevention, valid_window=1 (AUTH-08) - Backup code generation (8-char hex uppercase), storage (Argon2 hashed), constant-time verify (T-02-03) - HIBP k-anonymity check via SHA-1 prefix (T-02-05), fail-open on network error (T-02-06) - Admin bootstrap: idempotent, logs WARNING if env vars missing (D-04/D-05/D-06) - services/email.py: SMTP send + dev stdout fallback (D-01/D-02) - tasks/email_tasks.py: send_reset_email and send_security_alert_email Celery tasks - celery_app.py: add email queue route for tasks.email_tasks.* - TDD tests: 17 tests covering all auth primitives and family revocation
This commit is contained in:
@@ -26,9 +26,11 @@ celery_app.conf.task_serializer = "json"
|
||||
celery_app.conf.result_serializer = "json"
|
||||
celery_app.conf.accept_content = ["json"]
|
||||
|
||||
# Route document tasks to the dedicated `documents` queue
|
||||
# Route document tasks to the dedicated `documents` queue;
|
||||
# email tasks to the `email` queue (Phase 2 — D-03)
|
||||
celery_app.conf.task_routes = {
|
||||
"tasks.document_tasks.*": {"queue": "documents"},
|
||||
"tasks.email_tasks.*": {"queue": "email"},
|
||||
}
|
||||
|
||||
# Autodiscover tasks under the `tasks/` package
|
||||
|
||||
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
Auth service — pure Python, no FastAPI coupling.
|
||||
|
||||
Handles:
|
||||
- Password hashing (Argon2 via pwdlib) and constant-time verification (SEC-06)
|
||||
- JWT access token creation/decode (PyJWT)
|
||||
- Refresh token lifecycle with family revocation on reuse (AUTH-07, RFC 9700)
|
||||
- TOTP provisioning and verification with replay prevention (AUTH-08)
|
||||
- Backup code generation, storage, and constant-time verification (AUTH-02)
|
||||
- HaveIBeenPwned k-anonymity check (SEC-03)
|
||||
- Admin account bootstrap (D-04, D-05, D-06)
|
||||
|
||||
Security invariants:
|
||||
- All token/code comparisons use hmac.compare_digest (constant-time, SEC-06)
|
||||
- No function raises HTTPException — callers (api/) map ValueError to HTTP errors
|
||||
- refresh token family revocation enqueues send_security_alert_email.delay (AUTH-07)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
import pyotp
|
||||
from pwdlib import PasswordHash
|
||||
from pwdlib.hashers.argon2 import Argon2Hasher
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from config import settings
|
||||
from db.models import BackupCode, Quota, RefreshToken, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Password hashing ────────────────────────────────────────────────────────────
|
||||
# Single shared PasswordHash instance; Argon2 is the only enabled hasher.
|
||||
_pwd = PasswordHash([Argon2Hasher()])
|
||||
|
||||
|
||||
def hash_password(plain: str) -> str:
|
||||
"""Return the Argon2 hash of *plain*."""
|
||||
return _pwd.hash(plain)
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
"""Return True if *plain* matches *hashed* (constant-time, SEC-06).
|
||||
|
||||
pwdlib.verify returns True/False and internally uses constant-time comparison.
|
||||
"""
|
||||
try:
|
||||
return _pwd.verify(plain, hashed)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# ── JWT helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def create_access_token(user_id: str, role: str) -> str:
|
||||
"""Return a signed JWT access token.
|
||||
|
||||
Claims: sub=user_id, role=role, typ='access', exp=now+access_token_expire_minutes.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"role": role,
|
||||
"typ": "access",
|
||||
"iat": now,
|
||||
"exp": now + timedelta(minutes=settings.access_token_expire_minutes),
|
||||
}
|
||||
return jwt.encode(payload, settings.secret_key, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict:
|
||||
"""Decode and validate an access token; raises ValueError on any failure.
|
||||
|
||||
Verifies: signature, expiry, and typ='access' (T-02-01 — prevents password-reset
|
||||
tokens from being used as access tokens).
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"])
|
||||
except jwt.ExpiredSignatureError as exc:
|
||||
raise ValueError("Token has expired") from exc
|
||||
except jwt.PyJWTError as exc:
|
||||
raise ValueError(f"Invalid token: {exc}") from exc
|
||||
|
||||
if payload.get("typ") != "access":
|
||||
raise ValueError("Token type mismatch: expected 'access'")
|
||||
return payload
|
||||
|
||||
|
||||
def create_password_reset_token(user_id: str) -> str:
|
||||
"""Return a short-lived signed JWT for password reset.
|
||||
|
||||
Claims: sub=user_id, typ='password-reset', exp=now+3600s.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"typ": "password-reset",
|
||||
"iat": now,
|
||||
"exp": now + timedelta(seconds=3600),
|
||||
}
|
||||
return jwt.encode(payload, settings.secret_key, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_password_reset_token(token: str) -> str:
|
||||
"""Decode a password-reset token; raises ValueError on failure.
|
||||
|
||||
Returns the user_id string.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"])
|
||||
except jwt.ExpiredSignatureError as exc:
|
||||
raise ValueError("Reset token has expired") from exc
|
||||
except jwt.PyJWTError as exc:
|
||||
raise ValueError(f"Invalid reset token: {exc}") from exc
|
||||
|
||||
if payload.get("typ") != "password-reset":
|
||||
raise ValueError("Token type mismatch: expected 'password-reset'")
|
||||
return payload["sub"]
|
||||
|
||||
|
||||
# ── Refresh token lifecycle ─────────────────────────────────────────────────────
|
||||
|
||||
async def create_refresh_token(session: AsyncSession, user_id: uuid.UUID) -> str:
|
||||
"""Insert a new RefreshToken row and return the raw (unhashed) token string.
|
||||
|
||||
The raw token is returned to the caller and set as an httpOnly cookie.
|
||||
Only the SHA-256 hash is stored in the database.
|
||||
"""
|
||||
raw = secrets.token_urlsafe(32)
|
||||
token_hash = hashlib.sha256(raw.encode()).hexdigest()
|
||||
now = datetime.now(timezone.utc)
|
||||
row = RefreshToken(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
expires_at=now + timedelta(days=settings.refresh_token_expire_days),
|
||||
revoked=False,
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
return raw
|
||||
|
||||
|
||||
async def rotate_refresh_token(
|
||||
session: AsyncSession, raw_token: str
|
||||
) -> tuple[str, str]:
|
||||
"""Rotate a refresh token: revoke the old one and issue a new one.
|
||||
|
||||
Returns (new_raw_token, user_id_str).
|
||||
|
||||
Family revocation (AUTH-07 / RFC 9700):
|
||||
If the presented token has revoked=True, ALL tokens for that user are
|
||||
revoked and send_security_alert_email.delay is enqueued. Then raises
|
||||
ValueError("token_family_revoked").
|
||||
|
||||
Raises ValueError for any invalid or expired token.
|
||||
"""
|
||||
token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
|
||||
result = await session.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
row: Optional[RefreshToken] = result.scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
raise ValueError("Refresh token not found")
|
||||
|
||||
if row.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
raise ValueError("Refresh token has expired")
|
||||
|
||||
if row.revoked:
|
||||
# T-02-02: reuse of revoked token — family revocation
|
||||
await revoke_all_refresh_tokens(session, row.user_id)
|
||||
# Enqueue security alert email (deferred import to avoid circular dependency)
|
||||
from tasks.email_tasks import send_security_alert_email # noqa: PLC0415
|
||||
send_security_alert_email.delay(str(row.user_id))
|
||||
raise ValueError("token_family_revoked")
|
||||
|
||||
# Valid token: revoke the old one
|
||||
row.revoked = True
|
||||
await session.flush()
|
||||
|
||||
# Issue a new token in the same family
|
||||
new_raw = await create_refresh_token(session, row.user_id)
|
||||
return new_raw, str(row.user_id)
|
||||
|
||||
|
||||
async def revoke_all_refresh_tokens(
|
||||
session: AsyncSession, user_id: uuid.UUID
|
||||
) -> int:
|
||||
"""Mark all active refresh tokens for user_id as revoked.
|
||||
|
||||
Returns the count of revoked tokens (supports sign-out-all-devices).
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.user_id == user_id,
|
||||
RefreshToken.revoked.is_(False),
|
||||
)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
count = 0
|
||||
for row in rows:
|
||||
row.revoked = True
|
||||
count += 1
|
||||
await session.flush()
|
||||
return count
|
||||
|
||||
|
||||
# ── TOTP provisioning ───────────────────────────────────────────────────────────
|
||||
|
||||
async def provision_totp(
|
||||
session: AsyncSession, user_id: uuid.UUID
|
||||
) -> tuple[str, str]:
|
||||
"""Generate a new TOTP secret for the user and store it (not yet enabled).
|
||||
|
||||
Returns (secret, provisioning_uri).
|
||||
The secret is base32-encoded. provisioning_uri is suitable for QR code generation.
|
||||
"""
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
raise ValueError("User not found")
|
||||
|
||||
secret = pyotp.random_base32()
|
||||
user.totp_secret = secret
|
||||
await session.commit()
|
||||
|
||||
totp = pyotp.totp.TOTP(secret)
|
||||
uri = totp.provisioning_uri(user.email, issuer_name="DocuVault")
|
||||
return secret, uri
|
||||
|
||||
|
||||
async def verify_totp(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
code: str,
|
||||
redis_client,
|
||||
) -> bool:
|
||||
"""Verify a TOTP code with replay prevention (AUTH-08).
|
||||
|
||||
valid_window=1 allows ±30s clock drift (per STATE.md recommendation).
|
||||
Replay prevention: stores used codes in Redis with key 'totp_used:{user_id}:{code}'
|
||||
and TTL=90s (covers the full ±30s validity window).
|
||||
|
||||
Returns False if the code is invalid, already used, or user has no TOTP secret.
|
||||
"""
|
||||
user = await session.get(User, user_id)
|
||||
if user is None or not user.totp_secret:
|
||||
return False
|
||||
|
||||
totp = pyotp.TOTP(user.totp_secret)
|
||||
|
||||
# Check replay prevention before verifying
|
||||
replay_key = f"totp_used:{user_id}:{code}"
|
||||
if await redis_client.get(replay_key):
|
||||
return False # Code already used within the validity window
|
||||
|
||||
if not totp.verify(code, valid_window=1):
|
||||
return False
|
||||
|
||||
# Mark as used for 90s (covers valid_window=1: ±30s = 90s total)
|
||||
await redis_client.set(replay_key, "1", ex=90)
|
||||
return True
|
||||
|
||||
|
||||
# ── Backup codes ────────────────────────────────────────────────────────────────
|
||||
|
||||
def generate_backup_codes(n: int = 10) -> list[str]:
|
||||
"""Return *n* random 8-character uppercase alphanumeric backup codes."""
|
||||
codes = []
|
||||
for _ in range(n):
|
||||
# secrets.token_hex(4) returns 8 hex chars; uppercase for readability
|
||||
code = secrets.token_hex(4).upper()
|
||||
codes.append(code)
|
||||
return codes
|
||||
|
||||
|
||||
async def store_backup_codes(
|
||||
session: AsyncSession, user_id: uuid.UUID, codes: list[str]
|
||||
) -> None:
|
||||
"""Store backup codes for a user, replacing any existing unused codes.
|
||||
|
||||
Each code is stored as an Argon2 hash (never plaintext, T-02-03).
|
||||
Existing unused codes are deleted first to prevent accumulation.
|
||||
"""
|
||||
# Delete existing unused codes
|
||||
result = await session.execute(
|
||||
select(BackupCode).where(
|
||||
BackupCode.user_id == user_id,
|
||||
BackupCode.used_at.is_(None),
|
||||
)
|
||||
)
|
||||
for row in result.scalars().all():
|
||||
await session.delete(row)
|
||||
await session.flush()
|
||||
|
||||
# Insert new hashed codes
|
||||
for code in codes:
|
||||
row = BackupCode(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
code_hash=hash_password(code),
|
||||
used_at=None,
|
||||
)
|
||||
session.add(row)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def verify_backup_code(
|
||||
session: AsyncSession, user_id: uuid.UUID, code: str
|
||||
) -> bool:
|
||||
"""Verify a backup code using constant-time comparison.
|
||||
|
||||
Always iterates ALL unused codes (prevents timing-based enumeration, SEC-06).
|
||||
On match: sets used_at=now() and commits. Returns True on first match.
|
||||
Returns False if no code matches.
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(BackupCode).where(
|
||||
BackupCode.user_id == user_id,
|
||||
BackupCode.used_at.is_(None),
|
||||
)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
|
||||
matched_row: Optional[BackupCode] = None
|
||||
for row in rows:
|
||||
# Always call verify_password for ALL rows (constant-time: no early exit)
|
||||
if verify_password(code, row.code_hash):
|
||||
matched_row = row # record match but keep iterating
|
||||
|
||||
if matched_row is None:
|
||||
return False
|
||||
|
||||
# Mark the matched code as used
|
||||
matched_row.used_at = datetime.now(timezone.utc)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
|
||||
# ── HaveIBeenPwned check ────────────────────────────────────────────────────────
|
||||
|
||||
async def check_hibp(password: str) -> bool:
|
||||
"""Check if password appears in HaveIBeenPwned using the k-anonymity model.
|
||||
|
||||
Sends only the first 5 chars of the SHA-1 hash to the HIBP API (T-02-05).
|
||||
Returns True if the password has been breached, False otherwise.
|
||||
On network error: logs a warning and returns False (fail-open, T-02-06).
|
||||
"""
|
||||
sha1 = hashlib.sha1(password.encode("utf-8")).hexdigest().upper()
|
||||
prefix, suffix = sha1[:5], sha1[5:]
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
resp = await client.get(
|
||||
f"https://api.pwnedpasswords.com/range/{prefix}",
|
||||
headers={"Add-Padding": "true"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except Exception as exc:
|
||||
logger.warning("HIBP check failed (fail-open): %s", exc)
|
||||
return False
|
||||
|
||||
for line in resp.text.splitlines():
|
||||
parts = line.split(":")
|
||||
if len(parts) == 2:
|
||||
candidate_suffix, count_str = parts
|
||||
if hmac.compare_digest(candidate_suffix.upper(), suffix):
|
||||
try:
|
||||
count = int(count_str.strip())
|
||||
except ValueError:
|
||||
count = 1
|
||||
if count > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ── Admin bootstrap ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def bootstrap_admin(session: AsyncSession) -> None:
|
||||
"""Idempotent admin account bootstrap (D-04, D-05, D-06).
|
||||
|
||||
If the users table is empty AND settings.admin_email and settings.admin_password
|
||||
are both non-empty, creates an admin User row with a Quota row.
|
||||
|
||||
Logs a WARNING if env vars are missing (D-05) but never raises.
|
||||
"""
|
||||
if not settings.admin_email or not settings.admin_password:
|
||||
logger.warning(
|
||||
"Admin bootstrap skipped: ADMIN_EMAIL and/or ADMIN_PASSWORD not set (D-05). "
|
||||
"Set both env vars to seed the first admin account on startup."
|
||||
)
|
||||
return
|
||||
|
||||
# Check if any users exist
|
||||
result = await session.execute(select(User).limit(1))
|
||||
if result.scalar_one_or_none() is not None:
|
||||
# Users already exist — idempotent, skip (D-04)
|
||||
return
|
||||
|
||||
admin_id = uuid.uuid4()
|
||||
admin_user = User(
|
||||
id=admin_id,
|
||||
handle="admin",
|
||||
email=settings.admin_email,
|
||||
password_hash=hash_password(settings.admin_password),
|
||||
role="admin",
|
||||
is_active=True,
|
||||
password_must_change=False,
|
||||
)
|
||||
quota = Quota(
|
||||
user_id=admin_id,
|
||||
limit_bytes=104857600, # 100 MB default (D-06)
|
||||
used_bytes=0,
|
||||
)
|
||||
session.add(admin_user)
|
||||
session.add(quota)
|
||||
await session.commit()
|
||||
logger.info("Admin account bootstrapped for %s", settings.admin_email)
|
||||
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Email service — pure Python, no FastAPI coupling.
|
||||
|
||||
Sends transactional emails via SMTP when SMTP_HOST is configured;
|
||||
logs the content to stdout otherwise (D-02 dev fallback).
|
||||
|
||||
Security notes:
|
||||
- Never raises: email failures are non-fatal; log and return
|
||||
- Celery task wrapper handles retry/error reporting
|
||||
"""
|
||||
import logging
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def send_password_reset_email(to_address: str, reset_link: str) -> None:
|
||||
"""Send (or log) the password reset email.
|
||||
|
||||
When SMTP_HOST is not configured (dev / CI), logs the reset link to stdout
|
||||
per D-02. The API response is 202 regardless — no token in the body.
|
||||
|
||||
Never raises — failures are logged and the function returns normally.
|
||||
"""
|
||||
from config import settings # deferred to avoid module-level side effects
|
||||
|
||||
if not settings.smtp_host:
|
||||
# D-02: dev fallback — log token link to stdout
|
||||
logger.info("DEV MODE — password reset link for %s: %s", to_address, reset_link)
|
||||
return
|
||||
|
||||
try:
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = "DocuVault — password reset"
|
||||
msg["From"] = settings.smtp_from
|
||||
msg["To"] = to_address
|
||||
|
||||
text_body = (
|
||||
f"You requested a password reset for DocuVault.\n\n"
|
||||
f"Click the link below (valid for 1 hour):\n{reset_link}\n\n"
|
||||
"If you did not request this, ignore this email."
|
||||
)
|
||||
html_body = (
|
||||
f"<p>You requested a password reset for DocuVault.</p>"
|
||||
f"<p><a href='{reset_link}'>Reset your password</a> (valid 1 hour)</p>"
|
||||
f"<p>If you did not request this, ignore this email.</p>"
|
||||
)
|
||||
msg.attach(MIMEText(text_body, "plain"))
|
||||
msg.attach(MIMEText(html_body, "html"))
|
||||
|
||||
with smtplib.SMTP(settings.smtp_host, settings.smtp_port) as server:
|
||||
server.ehlo()
|
||||
server.starttls()
|
||||
if settings.smtp_user:
|
||||
server.login(settings.smtp_user, settings.smtp_password)
|
||||
server.sendmail(settings.smtp_from, [to_address], msg.as_string())
|
||||
|
||||
logger.info("Password reset email sent to %s", to_address)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to send password reset email to %s: %s", to_address, exc)
|
||||
|
||||
|
||||
def send_security_alert_email_sync(to_address: str, user_id: str) -> None:
|
||||
"""Send (or log) a security alert email about suspicious refresh token reuse.
|
||||
|
||||
Called by the email_tasks.py Celery task after looking up the user email.
|
||||
Never raises — failures are logged.
|
||||
"""
|
||||
from config import settings # deferred import
|
||||
|
||||
if not settings.smtp_host:
|
||||
logger.warning(
|
||||
"Security alert for user %s: suspicious refresh token reuse detected "
|
||||
"(SMTP not configured — email not sent to %s)",
|
||||
user_id,
|
||||
to_address,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = "DocuVault — security alert: suspicious login detected"
|
||||
msg["From"] = settings.smtp_from
|
||||
msg["To"] = to_address
|
||||
|
||||
text_body = (
|
||||
"DocuVault detected suspicious activity on your account.\n\n"
|
||||
"A previously revoked refresh token was used to attempt a session refresh. "
|
||||
"All active sessions have been revoked as a precaution.\n\n"
|
||||
"If this was not you, please change your password immediately."
|
||||
)
|
||||
html_body = (
|
||||
"<p><strong>DocuVault security alert</strong></p>"
|
||||
"<p>A previously revoked refresh token was used to attempt a session refresh. "
|
||||
"All active sessions have been revoked as a precaution.</p>"
|
||||
"<p>If this was not you, please change your password immediately.</p>"
|
||||
)
|
||||
msg.attach(MIMEText(text_body, "plain"))
|
||||
msg.attach(MIMEText(html_body, "html"))
|
||||
|
||||
with smtplib.SMTP(settings.smtp_host, settings.smtp_port) as server:
|
||||
server.ehlo()
|
||||
server.starttls()
|
||||
if settings.smtp_user:
|
||||
server.login(settings.smtp_user, settings.smtp_password)
|
||||
server.sendmail(settings.smtp_from, [to_address], msg.as_string())
|
||||
|
||||
logger.info("Security alert email sent to %s (user_id=%s)", to_address, user_id)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to send security alert email to %s (user_id=%s): %s",
|
||||
to_address, user_id, exc
|
||||
)
|
||||
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Celery tasks for email dispatch in DocuVault.
|
||||
|
||||
Tasks follow the same pattern as document_tasks.py:
|
||||
- Plain sync def (Celery workers have no asyncio event loop by default)
|
||||
- Async body via asyncio.run()
|
||||
- All imports deferred inside _run_* functions to avoid circular imports
|
||||
(see celery_app.py comment — do NOT import config at module level)
|
||||
|
||||
Tasks registered here:
|
||||
send_reset_email — dispatches a password-reset email to the user
|
||||
send_security_alert_email — dispatches a security alert on refresh token reuse (AUTH-07)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from celery_app import celery_app
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.email_tasks.send_reset_email")
|
||||
def send_reset_email(to_address: str, reset_link: str) -> dict:
|
||||
"""Synchronous Celery entry-point — send a password reset email.
|
||||
|
||||
Called as: send_reset_email.delay(to_address, reset_link)
|
||||
Delegates to the async body via asyncio.run().
|
||||
"""
|
||||
return asyncio.run(_run_send_reset(to_address, reset_link))
|
||||
|
||||
|
||||
async def _run_send_reset(to_address: str, reset_link: str) -> dict:
|
||||
"""Async body of send_reset_email. Deferred imports to avoid circular deps."""
|
||||
from services.email import send_password_reset_email # noqa: PLC0415
|
||||
try:
|
||||
send_password_reset_email(to_address, reset_link)
|
||||
return {"status": "sent", "to": to_address}
|
||||
except Exception as exc:
|
||||
return {"status": "failed", "error": str(exc)}
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.email_tasks.send_security_alert_email")
|
||||
def send_security_alert_email(user_id: str) -> dict:
|
||||
"""Synchronous Celery entry-point — send a security alert on token reuse.
|
||||
|
||||
Called as: send_security_alert_email.delay(user_id)
|
||||
Fetches the user's email from the DB inside the task using asyncio.run().
|
||||
On SMTP not configured: logs a WARNING per D-02 convention.
|
||||
"""
|
||||
return asyncio.run(_run_send_security_alert(user_id))
|
||||
|
||||
|
||||
async def _run_send_security_alert(user_id: str) -> dict:
|
||||
"""Async body of send_security_alert_email. Deferred imports to avoid circular deps."""
|
||||
import uuid as _uuid # noqa: PLC0415
|
||||
|
||||
from db.session import AsyncSessionLocal # noqa: PLC0415
|
||||
from db.models import User # noqa: PLC0415
|
||||
from services.email import send_security_alert_email_sync # noqa: PLC0415
|
||||
|
||||
try:
|
||||
user_uuid = _uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
return {"status": "failed", "error": f"Invalid user_id: {user_id}"}
|
||||
|
||||
try:
|
||||
async with AsyncSessionLocal() as session:
|
||||
user = await session.get(User, user_uuid)
|
||||
if user is None:
|
||||
return {"status": "failed", "error": f"User {user_id} not found"}
|
||||
to_address = user.email
|
||||
|
||||
send_security_alert_email_sync(to_address, user_id)
|
||||
return {"status": "sent", "user_id": user_id}
|
||||
except Exception as exc:
|
||||
return {"status": "failed", "error": str(exc)}
|
||||
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
TDD tests for Task 2: services/auth.py and tasks/email_tasks.py.
|
||||
|
||||
These tests should FAIL before implementation (RED phase).
|
||||
"""
|
||||
import pytest
|
||||
import uuid
|
||||
|
||||
|
||||
def test_hash_password_returns_argon2_hash():
|
||||
from services.auth import hash_password
|
||||
h = hash_password("TestPass123!")
|
||||
assert h.startswith("$argon2")
|
||||
|
||||
|
||||
def test_verify_password_correct():
|
||||
from services.auth import hash_password, verify_password
|
||||
h = hash_password("TestPass123!")
|
||||
assert verify_password("TestPass123!", h) is True
|
||||
|
||||
|
||||
def test_verify_password_wrong():
|
||||
from services.auth import hash_password, verify_password
|
||||
h = hash_password("TestPass123!")
|
||||
assert verify_password("WrongPass!", h) is False
|
||||
|
||||
|
||||
def test_create_access_token_jwt_format():
|
||||
from services.auth import create_access_token
|
||||
t = create_access_token("test-uid", "user")
|
||||
assert isinstance(t, str)
|
||||
assert t.count(".") == 2 # JWT has 3 parts separated by 2 dots
|
||||
|
||||
|
||||
def test_decode_access_token_valid():
|
||||
from services.auth import create_access_token, decode_access_token
|
||||
t = create_access_token("test-uid", "user")
|
||||
payload = decode_access_token(t)
|
||||
assert payload["sub"] == "test-uid"
|
||||
assert payload["role"] == "user"
|
||||
assert payload["typ"] == "access"
|
||||
|
||||
|
||||
def test_decode_access_token_tampered_raises():
|
||||
from services.auth import create_access_token, decode_access_token
|
||||
t = create_access_token("test-uid", "user")
|
||||
tampered = t[:-5] + "XXXXX"
|
||||
with pytest.raises(ValueError):
|
||||
decode_access_token(tampered)
|
||||
|
||||
|
||||
def test_decode_access_token_wrong_typ_raises():
|
||||
"""An access token must have typ='access'; other typ values should raise ValueError."""
|
||||
from services.auth import create_password_reset_token, decode_access_token
|
||||
reset_token = create_password_reset_token("test-uid")
|
||||
with pytest.raises(ValueError):
|
||||
decode_access_token(reset_token)
|
||||
|
||||
|
||||
def test_generate_backup_codes():
|
||||
from services.auth import generate_backup_codes
|
||||
codes = generate_backup_codes(10)
|
||||
assert len(codes) == 10
|
||||
for code in codes:
|
||||
assert isinstance(code, str)
|
||||
assert len(code) == 8
|
||||
|
||||
|
||||
def test_generate_backup_codes_alphanumeric():
|
||||
from services.auth import generate_backup_codes
|
||||
codes = generate_backup_codes(10)
|
||||
for code in codes:
|
||||
assert code.isalnum(), f"Code '{code}' contains non-alphanumeric chars"
|
||||
|
||||
|
||||
def test_create_password_reset_token():
|
||||
from services.auth import create_password_reset_token, decode_password_reset_token
|
||||
token = create_password_reset_token("test-uid")
|
||||
assert token.count(".") == 2
|
||||
uid = decode_password_reset_token(token)
|
||||
assert uid == "test-uid"
|
||||
|
||||
|
||||
def test_decode_password_reset_token_wrong_typ_raises():
|
||||
"""An access token must not be accepted as a password reset token."""
|
||||
from services.auth import create_access_token, decode_password_reset_token
|
||||
access_token = create_access_token("test-uid", "user")
|
||||
with pytest.raises(ValueError):
|
||||
decode_password_reset_token(access_token)
|
||||
|
||||
|
||||
def test_no_fastapi_imports_in_auth_service():
|
||||
"""services/auth.py must not import FastAPI or raise HTTPException."""
|
||||
import os
|
||||
import re
|
||||
path = os.path.join(os.path.dirname(__file__), "..", "services", "auth.py")
|
||||
with open(path) as f:
|
||||
source = f.read()
|
||||
# Check for actual imports (not docstring mentions)
|
||||
assert not re.search(r"^(?:from|import)\s+fastapi", source, re.MULTILINE), \
|
||||
"services/auth.py must not import from fastapi"
|
||||
# Check no raise HTTPException in actual code (not docstrings)
|
||||
assert not re.search(r"raise\s+HTTPException", source), \
|
||||
"services/auth.py must not raise HTTPException"
|
||||
|
||||
|
||||
def test_security_alert_email_referenced_in_auth():
|
||||
"""rotate_refresh_token must reference send_security_alert_email (AUTH-07)."""
|
||||
import os
|
||||
path = os.path.join(os.path.dirname(__file__), "..", "services", "auth.py")
|
||||
with open(path) as f:
|
||||
source = f.read()
|
||||
assert "send_security_alert_email" in source
|
||||
|
||||
|
||||
def test_email_tasks_importable():
|
||||
"""email_tasks.py must define both tasks (may skip if celery not installed locally)."""
|
||||
try:
|
||||
from tasks.email_tasks import send_reset_email, send_security_alert_email
|
||||
assert callable(send_reset_email)
|
||||
assert callable(send_security_alert_email)
|
||||
except ModuleNotFoundError as e:
|
||||
if "celery" in str(e):
|
||||
pytest.skip("celery not installed in local test environment (runs in Docker)")
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_rotate_refresh_token(db_session):
|
||||
"""create_refresh_token creates a DB row; rotate_refresh_token returns new token."""
|
||||
import uuid as _uuid
|
||||
from db.models import User, Quota
|
||||
from services.auth import (
|
||||
create_refresh_token, rotate_refresh_token, hash_password
|
||||
)
|
||||
|
||||
# Create a user
|
||||
user = User(
|
||||
id=_uuid.uuid4(),
|
||||
handle="testuser",
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("pass"),
|
||||
role="user",
|
||||
)
|
||||
quota = Quota(user_id=user.id, limit_bytes=104857600, used_bytes=0)
|
||||
db_session.add(user)
|
||||
db_session.add(quota)
|
||||
await db_session.commit()
|
||||
|
||||
# Create a refresh token
|
||||
raw = await create_refresh_token(db_session, user.id)
|
||||
assert isinstance(raw, str)
|
||||
assert len(raw) > 10
|
||||
|
||||
# Rotate the token
|
||||
new_raw, user_id_str = await rotate_refresh_token(db_session, raw)
|
||||
assert isinstance(new_raw, str)
|
||||
assert user_id_str == str(user.id)
|
||||
assert new_raw != raw
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rotate_revoked_token_raises(db_session):
|
||||
"""Rotating a revoked token should raise ValueError('token_family_revoked')."""
|
||||
import uuid as _uuid
|
||||
from unittest.mock import patch, MagicMock
|
||||
from db.models import User, Quota
|
||||
from services.auth import (
|
||||
create_refresh_token, rotate_refresh_token, hash_password
|
||||
)
|
||||
|
||||
user = User(
|
||||
id=_uuid.uuid4(),
|
||||
handle="testuser2",
|
||||
email="test2@example.com",
|
||||
password_hash=hash_password("pass"),
|
||||
role="user",
|
||||
)
|
||||
quota = Quota(user_id=user.id, limit_bytes=104857600, used_bytes=0)
|
||||
db_session.add(user)
|
||||
db_session.add(quota)
|
||||
await db_session.commit()
|
||||
|
||||
raw = await create_refresh_token(db_session, user.id)
|
||||
# First rotate (valid)
|
||||
new_raw, _ = await rotate_refresh_token(db_session, raw)
|
||||
|
||||
# Now try to re-use the original (now revoked) token.
|
||||
# Mock the celery task at the point it is imported inside rotate_refresh_token.
|
||||
mock_task = MagicMock()
|
||||
mock_task.delay = MagicMock()
|
||||
with patch.dict("sys.modules", {"tasks.email_tasks": MagicMock(send_security_alert_email=mock_task)}):
|
||||
with pytest.raises(ValueError, match="token_family_revoked"):
|
||||
await rotate_refresh_token(db_session, raw)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_and_verify_backup_codes(db_session):
|
||||
"""store_backup_codes inserts rows; verify_backup_code matches correct code."""
|
||||
import uuid as _uuid
|
||||
from db.models import User, Quota
|
||||
from services.auth import (
|
||||
generate_backup_codes, store_backup_codes, verify_backup_code, hash_password
|
||||
)
|
||||
|
||||
user = User(
|
||||
id=_uuid.uuid4(),
|
||||
handle="backupuser",
|
||||
email="backup@example.com",
|
||||
password_hash=hash_password("pass"),
|
||||
role="user",
|
||||
)
|
||||
quota = Quota(user_id=user.id, limit_bytes=104857600, used_bytes=0)
|
||||
db_session.add(user)
|
||||
db_session.add(quota)
|
||||
await db_session.commit()
|
||||
|
||||
codes = generate_backup_codes(10)
|
||||
await store_backup_codes(db_session, user.id, codes)
|
||||
|
||||
# Verify with correct code
|
||||
assert await verify_backup_code(db_session, user.id, codes[0]) is True
|
||||
# Verify with already-used code should return False
|
||||
assert await verify_backup_code(db_session, user.id, codes[0]) is False
|
||||
# Verify with wrong code
|
||||
assert await verify_backup_code(db_session, user.id, "XXXXXXXX") is False
|
||||
Reference in New Issue
Block a user