feat(02-01): implement services/auth.py full auth service layer and email_tasks.py
- services/auth.py: Argon2 password hashing (pwdlib), constant-time verify (SEC-06) - JWT create/decode for access tokens and password-reset tokens (typ claim validation, T-02-01) - Refresh token lifecycle: create, rotate, revoke-all with family revocation (AUTH-07, RFC 9700) - Family revocation enqueues send_security_alert_email.delay on token reuse (T-02-02) - TOTP provisioning (pyotp) and verification with Redis replay prevention, valid_window=1 (AUTH-08) - Backup code generation (8-char hex uppercase), storage (Argon2 hashed), constant-time verify (T-02-03) - HIBP k-anonymity check via SHA-1 prefix (T-02-05), fail-open on network error (T-02-06) - Admin bootstrap: idempotent, logs WARNING if env vars missing (D-04/D-05/D-06) - services/email.py: SMTP send + dev stdout fallback (D-01/D-02) - tasks/email_tasks.py: send_reset_email and send_security_alert_email Celery tasks - celery_app.py: add email queue route for tasks.email_tasks.* - TDD tests: 17 tests covering all auth primitives and family revocation
This commit is contained in:
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
Auth service — pure Python, no FastAPI coupling.
|
||||
|
||||
Handles:
|
||||
- Password hashing (Argon2 via pwdlib) and constant-time verification (SEC-06)
|
||||
- JWT access token creation/decode (PyJWT)
|
||||
- Refresh token lifecycle with family revocation on reuse (AUTH-07, RFC 9700)
|
||||
- TOTP provisioning and verification with replay prevention (AUTH-08)
|
||||
- Backup code generation, storage, and constant-time verification (AUTH-02)
|
||||
- HaveIBeenPwned k-anonymity check (SEC-03)
|
||||
- Admin account bootstrap (D-04, D-05, D-06)
|
||||
|
||||
Security invariants:
|
||||
- All token/code comparisons use hmac.compare_digest (constant-time, SEC-06)
|
||||
- No function raises HTTPException — callers (api/) map ValueError to HTTP errors
|
||||
- refresh token family revocation enqueues send_security_alert_email.delay (AUTH-07)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
import pyotp
|
||||
from pwdlib import PasswordHash
|
||||
from pwdlib.hashers.argon2 import Argon2Hasher
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from config import settings
|
||||
from db.models import BackupCode, Quota, RefreshToken, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Password hashing ────────────────────────────────────────────────────────────
|
||||
# Single shared PasswordHash instance; Argon2 is the only enabled hasher.
|
||||
_pwd = PasswordHash([Argon2Hasher()])
|
||||
|
||||
|
||||
def hash_password(plain: str) -> str:
|
||||
"""Return the Argon2 hash of *plain*."""
|
||||
return _pwd.hash(plain)
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
"""Return True if *plain* matches *hashed* (constant-time, SEC-06).
|
||||
|
||||
pwdlib.verify returns True/False and internally uses constant-time comparison.
|
||||
"""
|
||||
try:
|
||||
return _pwd.verify(plain, hashed)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# ── JWT helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def create_access_token(user_id: str, role: str) -> str:
|
||||
"""Return a signed JWT access token.
|
||||
|
||||
Claims: sub=user_id, role=role, typ='access', exp=now+access_token_expire_minutes.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"role": role,
|
||||
"typ": "access",
|
||||
"iat": now,
|
||||
"exp": now + timedelta(minutes=settings.access_token_expire_minutes),
|
||||
}
|
||||
return jwt.encode(payload, settings.secret_key, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict:
|
||||
"""Decode and validate an access token; raises ValueError on any failure.
|
||||
|
||||
Verifies: signature, expiry, and typ='access' (T-02-01 — prevents password-reset
|
||||
tokens from being used as access tokens).
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"])
|
||||
except jwt.ExpiredSignatureError as exc:
|
||||
raise ValueError("Token has expired") from exc
|
||||
except jwt.PyJWTError as exc:
|
||||
raise ValueError(f"Invalid token: {exc}") from exc
|
||||
|
||||
if payload.get("typ") != "access":
|
||||
raise ValueError("Token type mismatch: expected 'access'")
|
||||
return payload
|
||||
|
||||
|
||||
def create_password_reset_token(user_id: str) -> str:
|
||||
"""Return a short-lived signed JWT for password reset.
|
||||
|
||||
Claims: sub=user_id, typ='password-reset', exp=now+3600s.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"typ": "password-reset",
|
||||
"iat": now,
|
||||
"exp": now + timedelta(seconds=3600),
|
||||
}
|
||||
return jwt.encode(payload, settings.secret_key, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_password_reset_token(token: str) -> str:
|
||||
"""Decode a password-reset token; raises ValueError on failure.
|
||||
|
||||
Returns the user_id string.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"])
|
||||
except jwt.ExpiredSignatureError as exc:
|
||||
raise ValueError("Reset token has expired") from exc
|
||||
except jwt.PyJWTError as exc:
|
||||
raise ValueError(f"Invalid reset token: {exc}") from exc
|
||||
|
||||
if payload.get("typ") != "password-reset":
|
||||
raise ValueError("Token type mismatch: expected 'password-reset'")
|
||||
return payload["sub"]
|
||||
|
||||
|
||||
# ── Refresh token lifecycle ─────────────────────────────────────────────────────
|
||||
|
||||
async def create_refresh_token(session: AsyncSession, user_id: uuid.UUID) -> str:
|
||||
"""Insert a new RefreshToken row and return the raw (unhashed) token string.
|
||||
|
||||
The raw token is returned to the caller and set as an httpOnly cookie.
|
||||
Only the SHA-256 hash is stored in the database.
|
||||
"""
|
||||
raw = secrets.token_urlsafe(32)
|
||||
token_hash = hashlib.sha256(raw.encode()).hexdigest()
|
||||
now = datetime.now(timezone.utc)
|
||||
row = RefreshToken(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
expires_at=now + timedelta(days=settings.refresh_token_expire_days),
|
||||
revoked=False,
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
return raw
|
||||
|
||||
|
||||
async def rotate_refresh_token(
|
||||
session: AsyncSession, raw_token: str
|
||||
) -> tuple[str, str]:
|
||||
"""Rotate a refresh token: revoke the old one and issue a new one.
|
||||
|
||||
Returns (new_raw_token, user_id_str).
|
||||
|
||||
Family revocation (AUTH-07 / RFC 9700):
|
||||
If the presented token has revoked=True, ALL tokens for that user are
|
||||
revoked and send_security_alert_email.delay is enqueued. Then raises
|
||||
ValueError("token_family_revoked").
|
||||
|
||||
Raises ValueError for any invalid or expired token.
|
||||
"""
|
||||
token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
|
||||
result = await session.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
row: Optional[RefreshToken] = result.scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
raise ValueError("Refresh token not found")
|
||||
|
||||
if row.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
raise ValueError("Refresh token has expired")
|
||||
|
||||
if row.revoked:
|
||||
# T-02-02: reuse of revoked token — family revocation
|
||||
await revoke_all_refresh_tokens(session, row.user_id)
|
||||
# Enqueue security alert email (deferred import to avoid circular dependency)
|
||||
from tasks.email_tasks import send_security_alert_email # noqa: PLC0415
|
||||
send_security_alert_email.delay(str(row.user_id))
|
||||
raise ValueError("token_family_revoked")
|
||||
|
||||
# Valid token: revoke the old one
|
||||
row.revoked = True
|
||||
await session.flush()
|
||||
|
||||
# Issue a new token in the same family
|
||||
new_raw = await create_refresh_token(session, row.user_id)
|
||||
return new_raw, str(row.user_id)
|
||||
|
||||
|
||||
async def revoke_all_refresh_tokens(
|
||||
session: AsyncSession, user_id: uuid.UUID
|
||||
) -> int:
|
||||
"""Mark all active refresh tokens for user_id as revoked.
|
||||
|
||||
Returns the count of revoked tokens (supports sign-out-all-devices).
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.user_id == user_id,
|
||||
RefreshToken.revoked.is_(False),
|
||||
)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
count = 0
|
||||
for row in rows:
|
||||
row.revoked = True
|
||||
count += 1
|
||||
await session.flush()
|
||||
return count
|
||||
|
||||
|
||||
# ── TOTP provisioning ───────────────────────────────────────────────────────────
|
||||
|
||||
async def provision_totp(
|
||||
session: AsyncSession, user_id: uuid.UUID
|
||||
) -> tuple[str, str]:
|
||||
"""Generate a new TOTP secret for the user and store it (not yet enabled).
|
||||
|
||||
Returns (secret, provisioning_uri).
|
||||
The secret is base32-encoded. provisioning_uri is suitable for QR code generation.
|
||||
"""
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
raise ValueError("User not found")
|
||||
|
||||
secret = pyotp.random_base32()
|
||||
user.totp_secret = secret
|
||||
await session.commit()
|
||||
|
||||
totp = pyotp.totp.TOTP(secret)
|
||||
uri = totp.provisioning_uri(user.email, issuer_name="DocuVault")
|
||||
return secret, uri
|
||||
|
||||
|
||||
async def verify_totp(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
code: str,
|
||||
redis_client,
|
||||
) -> bool:
|
||||
"""Verify a TOTP code with replay prevention (AUTH-08).
|
||||
|
||||
valid_window=1 allows ±30s clock drift (per STATE.md recommendation).
|
||||
Replay prevention: stores used codes in Redis with key 'totp_used:{user_id}:{code}'
|
||||
and TTL=90s (covers the full ±30s validity window).
|
||||
|
||||
Returns False if the code is invalid, already used, or user has no TOTP secret.
|
||||
"""
|
||||
user = await session.get(User, user_id)
|
||||
if user is None or not user.totp_secret:
|
||||
return False
|
||||
|
||||
totp = pyotp.TOTP(user.totp_secret)
|
||||
|
||||
# Check replay prevention before verifying
|
||||
replay_key = f"totp_used:{user_id}:{code}"
|
||||
if await redis_client.get(replay_key):
|
||||
return False # Code already used within the validity window
|
||||
|
||||
if not totp.verify(code, valid_window=1):
|
||||
return False
|
||||
|
||||
# Mark as used for 90s (covers valid_window=1: ±30s = 90s total)
|
||||
await redis_client.set(replay_key, "1", ex=90)
|
||||
return True
|
||||
|
||||
|
||||
# ── Backup codes ────────────────────────────────────────────────────────────────
|
||||
|
||||
def generate_backup_codes(n: int = 10) -> list[str]:
|
||||
"""Return *n* random 8-character uppercase alphanumeric backup codes."""
|
||||
codes = []
|
||||
for _ in range(n):
|
||||
# secrets.token_hex(4) returns 8 hex chars; uppercase for readability
|
||||
code = secrets.token_hex(4).upper()
|
||||
codes.append(code)
|
||||
return codes
|
||||
|
||||
|
||||
async def store_backup_codes(
|
||||
session: AsyncSession, user_id: uuid.UUID, codes: list[str]
|
||||
) -> None:
|
||||
"""Store backup codes for a user, replacing any existing unused codes.
|
||||
|
||||
Each code is stored as an Argon2 hash (never plaintext, T-02-03).
|
||||
Existing unused codes are deleted first to prevent accumulation.
|
||||
"""
|
||||
# Delete existing unused codes
|
||||
result = await session.execute(
|
||||
select(BackupCode).where(
|
||||
BackupCode.user_id == user_id,
|
||||
BackupCode.used_at.is_(None),
|
||||
)
|
||||
)
|
||||
for row in result.scalars().all():
|
||||
await session.delete(row)
|
||||
await session.flush()
|
||||
|
||||
# Insert new hashed codes
|
||||
for code in codes:
|
||||
row = BackupCode(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
code_hash=hash_password(code),
|
||||
used_at=None,
|
||||
)
|
||||
session.add(row)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def verify_backup_code(
|
||||
session: AsyncSession, user_id: uuid.UUID, code: str
|
||||
) -> bool:
|
||||
"""Verify a backup code using constant-time comparison.
|
||||
|
||||
Always iterates ALL unused codes (prevents timing-based enumeration, SEC-06).
|
||||
On match: sets used_at=now() and commits. Returns True on first match.
|
||||
Returns False if no code matches.
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(BackupCode).where(
|
||||
BackupCode.user_id == user_id,
|
||||
BackupCode.used_at.is_(None),
|
||||
)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
|
||||
matched_row: Optional[BackupCode] = None
|
||||
for row in rows:
|
||||
# Always call verify_password for ALL rows (constant-time: no early exit)
|
||||
if verify_password(code, row.code_hash):
|
||||
matched_row = row # record match but keep iterating
|
||||
|
||||
if matched_row is None:
|
||||
return False
|
||||
|
||||
# Mark the matched code as used
|
||||
matched_row.used_at = datetime.now(timezone.utc)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
|
||||
# ── HaveIBeenPwned check ────────────────────────────────────────────────────────
|
||||
|
||||
async def check_hibp(password: str) -> bool:
|
||||
"""Check if password appears in HaveIBeenPwned using the k-anonymity model.
|
||||
|
||||
Sends only the first 5 chars of the SHA-1 hash to the HIBP API (T-02-05).
|
||||
Returns True if the password has been breached, False otherwise.
|
||||
On network error: logs a warning and returns False (fail-open, T-02-06).
|
||||
"""
|
||||
sha1 = hashlib.sha1(password.encode("utf-8")).hexdigest().upper()
|
||||
prefix, suffix = sha1[:5], sha1[5:]
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
resp = await client.get(
|
||||
f"https://api.pwnedpasswords.com/range/{prefix}",
|
||||
headers={"Add-Padding": "true"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except Exception as exc:
|
||||
logger.warning("HIBP check failed (fail-open): %s", exc)
|
||||
return False
|
||||
|
||||
for line in resp.text.splitlines():
|
||||
parts = line.split(":")
|
||||
if len(parts) == 2:
|
||||
candidate_suffix, count_str = parts
|
||||
if hmac.compare_digest(candidate_suffix.upper(), suffix):
|
||||
try:
|
||||
count = int(count_str.strip())
|
||||
except ValueError:
|
||||
count = 1
|
||||
if count > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ── Admin bootstrap ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def bootstrap_admin(session: AsyncSession) -> None:
|
||||
"""Idempotent admin account bootstrap (D-04, D-05, D-06).
|
||||
|
||||
If the users table is empty AND settings.admin_email and settings.admin_password
|
||||
are both non-empty, creates an admin User row with a Quota row.
|
||||
|
||||
Logs a WARNING if env vars are missing (D-05) but never raises.
|
||||
"""
|
||||
if not settings.admin_email or not settings.admin_password:
|
||||
logger.warning(
|
||||
"Admin bootstrap skipped: ADMIN_EMAIL and/or ADMIN_PASSWORD not set (D-05). "
|
||||
"Set both env vars to seed the first admin account on startup."
|
||||
)
|
||||
return
|
||||
|
||||
# Check if any users exist
|
||||
result = await session.execute(select(User).limit(1))
|
||||
if result.scalar_one_or_none() is not None:
|
||||
# Users already exist — idempotent, skip (D-04)
|
||||
return
|
||||
|
||||
admin_id = uuid.uuid4()
|
||||
admin_user = User(
|
||||
id=admin_id,
|
||||
handle="admin",
|
||||
email=settings.admin_email,
|
||||
password_hash=hash_password(settings.admin_password),
|
||||
role="admin",
|
||||
is_active=True,
|
||||
password_must_change=False,
|
||||
)
|
||||
quota = Quota(
|
||||
user_id=admin_id,
|
||||
limit_bytes=104857600, # 100 MB default (D-06)
|
||||
used_bytes=0,
|
||||
)
|
||||
session.add(admin_user)
|
||||
session.add(quota)
|
||||
await session.commit()
|
||||
logger.info("Admin account bootstrapped for %s", settings.admin_email)
|
||||
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Email service — pure Python, no FastAPI coupling.
|
||||
|
||||
Sends transactional emails via SMTP when SMTP_HOST is configured;
|
||||
logs the content to stdout otherwise (D-02 dev fallback).
|
||||
|
||||
Security notes:
|
||||
- Never raises: email failures are non-fatal; log and return
|
||||
- Celery task wrapper handles retry/error reporting
|
||||
"""
|
||||
import logging
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def send_password_reset_email(to_address: str, reset_link: str) -> None:
|
||||
"""Send (or log) the password reset email.
|
||||
|
||||
When SMTP_HOST is not configured (dev / CI), logs the reset link to stdout
|
||||
per D-02. The API response is 202 regardless — no token in the body.
|
||||
|
||||
Never raises — failures are logged and the function returns normally.
|
||||
"""
|
||||
from config import settings # deferred to avoid module-level side effects
|
||||
|
||||
if not settings.smtp_host:
|
||||
# D-02: dev fallback — log token link to stdout
|
||||
logger.info("DEV MODE — password reset link for %s: %s", to_address, reset_link)
|
||||
return
|
||||
|
||||
try:
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = "DocuVault — password reset"
|
||||
msg["From"] = settings.smtp_from
|
||||
msg["To"] = to_address
|
||||
|
||||
text_body = (
|
||||
f"You requested a password reset for DocuVault.\n\n"
|
||||
f"Click the link below (valid for 1 hour):\n{reset_link}\n\n"
|
||||
"If you did not request this, ignore this email."
|
||||
)
|
||||
html_body = (
|
||||
f"<p>You requested a password reset for DocuVault.</p>"
|
||||
f"<p><a href='{reset_link}'>Reset your password</a> (valid 1 hour)</p>"
|
||||
f"<p>If you did not request this, ignore this email.</p>"
|
||||
)
|
||||
msg.attach(MIMEText(text_body, "plain"))
|
||||
msg.attach(MIMEText(html_body, "html"))
|
||||
|
||||
with smtplib.SMTP(settings.smtp_host, settings.smtp_port) as server:
|
||||
server.ehlo()
|
||||
server.starttls()
|
||||
if settings.smtp_user:
|
||||
server.login(settings.smtp_user, settings.smtp_password)
|
||||
server.sendmail(settings.smtp_from, [to_address], msg.as_string())
|
||||
|
||||
logger.info("Password reset email sent to %s", to_address)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to send password reset email to %s: %s", to_address, exc)
|
||||
|
||||
|
||||
def send_security_alert_email_sync(to_address: str, user_id: str) -> None:
|
||||
"""Send (or log) a security alert email about suspicious refresh token reuse.
|
||||
|
||||
Called by the email_tasks.py Celery task after looking up the user email.
|
||||
Never raises — failures are logged.
|
||||
"""
|
||||
from config import settings # deferred import
|
||||
|
||||
if not settings.smtp_host:
|
||||
logger.warning(
|
||||
"Security alert for user %s: suspicious refresh token reuse detected "
|
||||
"(SMTP not configured — email not sent to %s)",
|
||||
user_id,
|
||||
to_address,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = "DocuVault — security alert: suspicious login detected"
|
||||
msg["From"] = settings.smtp_from
|
||||
msg["To"] = to_address
|
||||
|
||||
text_body = (
|
||||
"DocuVault detected suspicious activity on your account.\n\n"
|
||||
"A previously revoked refresh token was used to attempt a session refresh. "
|
||||
"All active sessions have been revoked as a precaution.\n\n"
|
||||
"If this was not you, please change your password immediately."
|
||||
)
|
||||
html_body = (
|
||||
"<p><strong>DocuVault security alert</strong></p>"
|
||||
"<p>A previously revoked refresh token was used to attempt a session refresh. "
|
||||
"All active sessions have been revoked as a precaution.</p>"
|
||||
"<p>If this was not you, please change your password immediately.</p>"
|
||||
)
|
||||
msg.attach(MIMEText(text_body, "plain"))
|
||||
msg.attach(MIMEText(html_body, "html"))
|
||||
|
||||
with smtplib.SMTP(settings.smtp_host, settings.smtp_port) as server:
|
||||
server.ehlo()
|
||||
server.starttls()
|
||||
if settings.smtp_user:
|
||||
server.login(settings.smtp_user, settings.smtp_password)
|
||||
server.sendmail(settings.smtp_from, [to_address], msg.as_string())
|
||||
|
||||
logger.info("Security alert email sent to %s (user_id=%s)", to_address, user_id)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to send security alert email to %s (user_id=%s): %s",
|
||||
to_address, user_id, exc
|
||||
)
|
||||
Reference in New Issue
Block a user