a5994d9ff4
Includes planning artifacts (03-CONTEXT, 03-DISCUSSION-LOG, 03-02-SUMMARY), integration test script, MinIO/auth/docker fixes, and local dev account reference. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
430 lines
15 KiB
Python
430 lines
15 KiB
Python
"""
|
|
Auth service — pure Python, no FastAPI coupling.

Handles:
- Password hashing (Argon2 via pwdlib) and constant-time verification (SEC-06)
- JWT access token creation/decode (PyJWT)
- Refresh token lifecycle with family revocation on reuse (AUTH-07, RFC 9700)
- TOTP provisioning and verification with replay prevention (AUTH-08)
- Backup code generation, storage, and constant-time verification (AUTH-02)
- HaveIBeenPwned k-anonymity check (SEC-03)
- Admin account bootstrap (D-04, D-05, D-06)

Security invariants:
- Secret comparisons are constant-time (SEC-06): password and backup-code checks
  go through pwdlib verification, HIBP suffix matching uses hmac.compare_digest
- No function raises HTTPException — callers (api/) map ValueError to HTTP errors
- Refresh token family revocation enqueues send_security_alert_email.delay (AUTH-07)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import hmac
|
|
import logging
|
|
import secrets
|
|
import uuid
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
import jwt
|
|
import pyotp
|
|
from pwdlib import PasswordHash
|
|
from pwdlib.hashers.argon2 import Argon2Hasher
|
|
from sqlalchemy import select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from config import settings
|
|
from db.models import BackupCode, Quota, RefreshToken, User
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# ── Password hashing ────────────────────────────────────────────────────────────
# One PasswordHash instance is shared by every function below; Argon2 is the
# only hasher it is configured with.
_pwd = PasswordHash([Argon2Hasher()])
|
|
|
|
|
|
def hash_password(plain: str) -> str:
    """Hash *plain* with the shared Argon2 hasher and return the encoded hash."""
    encoded = _pwd.hash(plain)
    return encoded
|
|
|
|
|
|
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against the stored hash *hashed* (SEC-06).

    Verification is delegated to the shared pwdlib ``PasswordHash`` instance,
    which is relied on for the constant-time comparison. Any ``Exception``
    raised while verifying is treated as a failed match, so callers only ever
    see ``True`` or ``False``.
    """
    try:
        matches = _pwd.verify(plain, hashed)
    except Exception:
        # Deliberate: a verification error counts as "no match".
        return False
    return matches
|
|
|
|
|
|
# ── JWT helpers ─────────────────────────────────────────────────────────────────
|
|
|
|
def create_access_token(user_id: str, role: str) -> str:
    """Build and sign an HS256 access JWT for *user_id*.

    Claims written: ``sub`` (stringified *user_id*), ``role``, ``typ='access'``,
    ``iat`` (current UTC time) and ``exp`` (``iat`` plus
    ``settings.access_token_expire_minutes`` minutes). The token is signed with
    ``settings.secret_key``.
    """
    issued_at = datetime.now(timezone.utc)
    claims = {
        "sub": str(user_id),
        "role": role,
        "typ": "access",
        "iat": issued_at,
        "exp": issued_at + timedelta(minutes=settings.access_token_expire_minutes),
    }
    signed = jwt.encode(claims, settings.secret_key, algorithm="HS256")
    return signed
|
|
|
|
|
|
def decode_access_token(token: str) -> dict:
    """Decode *token* and return its claims, or raise ``ValueError``.

    The token must carry a valid HS256 signature made with
    ``settings.secret_key``, must not be expired, and must have
    ``typ == 'access'`` (T-02-01: a password-reset token is signed with the same
    key, so the type claim is what keeps it from being accepted here).

    Raises:
        ValueError: the token is expired, otherwise rejected by PyJWT, or has
            the wrong ``typ`` claim.
    """
    try:
        claims = jwt.decode(token, settings.secret_key, algorithms=["HS256"])
    except jwt.ExpiredSignatureError as exc:
        raise ValueError("Token has expired") from exc
    except jwt.PyJWTError as exc:
        raise ValueError(f"Invalid token: {exc}") from exc

    if claims.get("typ") == "access":
        return claims
    raise ValueError("Token type mismatch: expected 'access'")
|
|
|
|
|
|
def create_password_reset_token(user_id: str) -> str:
    """Build and sign a one-hour HS256 JWT used for password reset.

    Claims written: ``sub`` (stringified *user_id*), ``typ='password-reset'``,
    ``iat`` (current UTC time) and ``exp`` (``iat`` plus 3600 seconds).
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(seconds=3600)
    claims = {
        "sub": str(user_id),
        "typ": "password-reset",
        "iat": issued_at,
        "exp": expires_at,
    }
    return jwt.encode(claims, settings.secret_key, algorithm="HS256")
|
|
|
|
|
|
def decode_password_reset_token(token: str) -> str:
    """Decode a password-reset token and return the user id it was issued for.

    The token must carry a valid HS256 signature made with
    ``settings.secret_key``, must not be expired, must contain both ``exp`` and
    ``sub`` claims, and must have ``typ == 'password-reset'``.

    Returns:
        The ``sub`` claim (the user id as a string).

    Raises:
        ValueError: on every failure — expired, bad signature, malformed token,
            missing ``exp``/``sub`` claim, or wrong ``typ``.
    """
    try:
        payload = jwt.decode(
            token,
            settings.secret_key,
            algorithms=["HS256"],
            # Require the claims this function relies on. Without this, a signed
            # token lacking "sub" would surface as KeyError instead of
            # ValueError, and one lacking "exp" would never expire.
            options={"require": ["exp", "sub"]},
        )
    except jwt.ExpiredSignatureError as exc:
        raise ValueError("Reset token has expired") from exc
    except jwt.PyJWTError as exc:
        # Also covers the missing-required-claim error raised by PyJWT.
        raise ValueError(f"Invalid reset token: {exc}") from exc

    if payload.get("typ") != "password-reset":
        raise ValueError("Token type mismatch: expected 'password-reset'")
    return payload["sub"]
|
|
|
|
|
|
# ── Refresh token lifecycle ─────────────────────────────────────────────────────
|
|
|
|
async def create_refresh_token(session: AsyncSession, user_id: uuid.UUID) -> str:
    """Persist a new refresh token for *user_id* and return its raw value.

    A random URL-safe token is generated; only its SHA-256 hex digest is written
    to the ``RefreshToken`` row, together with an expiry of
    ``settings.refresh_token_expire_days`` days from now. The row is committed
    before returning. The raw token is handed back to the caller and is not
    stored anywhere by this function.
    """
    raw_token = secrets.token_urlsafe(32)
    digest = hashlib.sha256(raw_token.encode()).hexdigest()
    created_at = datetime.now(timezone.utc)

    session.add(
        RefreshToken(
            id=uuid.uuid4(),
            user_id=user_id,
            token_hash=digest,
            expires_at=created_at + timedelta(days=settings.refresh_token_expire_days),
            revoked=False,
        )
    )
    await session.commit()
    return raw_token
|
|
|
|
|
|
async def rotate_refresh_token(
    session: AsyncSession, raw_token: str
) -> tuple[str, str]:
    """Rotate a refresh token: revoke the presented one and issue a replacement.

    The presented token is looked up by its SHA-256 hex digest.

    Family revocation (AUTH-07 / RFC 9700):
        If the presented token is already revoked, every active refresh token
        of that user is revoked, the revocation is committed, a security alert
        email task is enqueued, and ``ValueError("token_family_revoked")`` is
        raised.

    Returns:
        ``(new_raw_token, user_id_str)``.

    Raises:
        ValueError: the token is unknown, expired, or was already revoked.
    """
    token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
    result = await session.execute(
        select(RefreshToken).where(RefreshToken.token_hash == token_hash)
    )
    row: Optional[RefreshToken] = result.scalar_one_or_none()

    if row is None:
        raise ValueError("Refresh token not found")

    # Only attach UTC when the stored value is naive; an aware value must be
    # compared as-is (replace() would relabel it without converting).
    expires_at = row.expires_at
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    if expires_at < datetime.now(timezone.utc):
        raise ValueError("Refresh token has expired")

    # Read the owner once, before any commit, so nothing below touches ORM
    # attributes that a commit may have expired.
    user_id = row.user_id

    if row.revoked:
        # T-02-02: reuse of revoked token — family revocation.
        await revoke_all_refresh_tokens(session, user_id)
        # revoke_all_refresh_tokens only flushes. Commit here so the revocation
        # survives the ValueError below even if the caller rolls back, and so
        # it is durable before the alert is enqueued.
        await session.commit()
        # Enqueue security alert email (deferred import to avoid circular dependency)
        from tasks.email_tasks import send_security_alert_email  # noqa: PLC0415

        send_security_alert_email.delay(str(user_id))
        raise ValueError("token_family_revoked")

    # Valid token: revoke the old one ...
    row.revoked = True
    await session.flush()

    # ... and issue its replacement (create_refresh_token commits both changes).
    new_raw = await create_refresh_token(session, user_id)
    return new_raw, str(user_id)
|
|
|
|
|
|
async def revoke_all_refresh_tokens(
    session: AsyncSession, user_id: uuid.UUID
) -> int:
    """Flag every not-yet-revoked refresh token of *user_id* as revoked.

    The changes are flushed to the session but not committed here.

    Returns:
        How many tokens were switched to revoked.
    """
    active_query = select(RefreshToken).where(
        RefreshToken.user_id == user_id,
        RefreshToken.revoked.is_(False),
    )
    active_tokens = (await session.execute(active_query)).scalars().all()

    for token in active_tokens:
        token.revoked = True

    await session.flush()
    return len(active_tokens)
|
|
|
|
|
|
# ── TOTP provisioning ───────────────────────────────────────────────────────────
|
|
|
|
async def provision_totp(
    session: AsyncSession, user_id: uuid.UUID
) -> tuple[str, str]:
    """Generate a new TOTP secret for the user, store it, and commit.

    Returns:
        ``(secret, provisioning_uri)`` — *secret* is a random base32 string and
        *provisioning_uri* is built from it with the user's email and the
        issuer name ``DocuVault``.

    Raises:
        ValueError: no user exists with *user_id*.
    """
    user = await session.get(User, user_id)
    if user is None:
        raise ValueError("User not found")

    secret = pyotp.random_base32()

    # Build the URI before committing: reading user.email after commit() can
    # trigger a lazy reload of an expired instance, which an AsyncSession
    # cannot perform implicitly unless expire_on_commit is disabled.
    uri = pyotp.totp.TOTP(secret).provisioning_uri(user.email, issuer_name="DocuVault")

    user.totp_secret = secret
    await session.commit()
    return secret, uri
|
|
|
|
|
|
async def verify_totp(
    session: AsyncSession,
    user_id: uuid.UUID,
    code: str,
    redis_client,
) -> bool:
    """Verify a TOTP code with replay prevention (AUTH-08).

    ``valid_window=1`` accepts the adjacent time steps as well as the current
    one. A code that verifies is recorded in Redis under
    ``totp_used:{user_id}:{code}`` for 90 seconds, and a code already recorded
    there is rejected.

    The submitted code is NFKC-normalised once, and that normalised form is used
    for both the replay key and the verification. pyotp's comparison is
    understood to normalise its inputs the same way, so without this a Unicode
    variant of a used code (e.g. full-width digits) would verify again under a
    different replay key.

    Returns:
        ``True`` when the code is valid and unused; ``False`` when the user is
        missing, has no TOTP secret, or the code is invalid or already used.
    """
    import unicodedata  # local import: only needed by this function

    user = await session.get(User, user_id)
    if user is None or not user.totp_secret:
        return False

    normalized_code = unicodedata.normalize("NFKC", str(code))

    # Check replay prevention before verifying.
    # NOTE(review): get-then-set is not atomic; two concurrent requests with
    # the same code could both pass — consider SET ... NX if the client allows.
    replay_key = f"totp_used:{user_id}:{normalized_code}"
    if await redis_client.get(replay_key):
        return False  # Code already used within the validity window

    totp = pyotp.TOTP(user.totp_secret)
    if not totp.verify(normalized_code, valid_window=1):
        return False

    # Mark as used for 90s (three 30s steps accepted by valid_window=1).
    await redis_client.set(replay_key, "1", ex=90)
    return True
|
|
|
|
|
|
# ── Backup codes ────────────────────────────────────────────────────────────────
|
|
|
|
def generate_backup_codes(n: int = 10) -> list[str]:
    """Return *n* random backup codes of 8 uppercase hexadecimal characters.

    Each code is ``secrets.token_hex(4)`` (4 random bytes, 8 hex digits)
    upper-cased for readability, so the alphabet is ``0-9A-F``. A non-positive
    *n* yields an empty list.
    """
    return [secrets.token_hex(4).upper() for _ in range(n)]
|
|
|
|
|
|
async def store_backup_codes(
    session: AsyncSession, user_id: uuid.UUID, codes: list[str]
) -> None:
    """Replace the user's unused backup codes with *codes* and commit.

    Rows for *user_id* whose ``used_at`` is NULL are deleted first, so unused
    codes do not accumulate. Each new code is stored only as its
    ``hash_password`` (Argon2) hash, never in plaintext (T-02-03).
    """
    # Drop whatever unused codes the user still has.
    unused_query = select(BackupCode).where(
        BackupCode.user_id == user_id,
        BackupCode.used_at.is_(None),
    )
    stale_rows = (await session.execute(unused_query)).scalars().all()
    for stale in stale_rows:
        await session.delete(stale)
    await session.flush()

    # Add one hashed row per new code.
    for plain_code in codes:
        session.add(
            BackupCode(
                id=uuid.uuid4(),
                user_id=user_id,
                code_hash=hash_password(plain_code),
                used_at=None,
            )
        )

    await session.commit()
|
|
|
|
|
|
async def verify_backup_code(
    session: AsyncSession, user_id: uuid.UUID, code: str
) -> bool:
    """Check *code* against the user's unused backup codes and consume a match.

    Every unused code is verified, with no early exit on a hit, so the amount
    of work does not reveal which stored code matched (SEC-06). If several rows
    match, the last one is the one consumed.

    Returns:
        ``True`` after stamping the matched row's ``used_at`` with the current
        UTC time and committing; ``False`` when nothing matched.
    """
    unused_query = select(BackupCode).where(
        BackupCode.user_id == user_id,
        BackupCode.used_at.is_(None),
    )
    candidates = (await session.execute(unused_query)).scalars().all()

    # The comprehension visits every candidate; it never stops at the first hit.
    hits = [
        candidate
        for candidate in candidates
        if verify_password(code, candidate.code_hash)
    ]
    if not hits:
        return False

    # Consume the last matching row.
    hits[-1].used_at = datetime.now(timezone.utc)
    await session.commit()
    return True
|
|
|
|
|
|
# ── HaveIBeenPwned check ────────────────────────────────────────────────────────
|
|
|
|
async def check_hibp(password: str) -> bool:
    """Report whether *password* appears in the HaveIBeenPwned range data.

    Uses the k-anonymity scheme: only the first 5 characters of the upper-cased
    SHA-1 hex digest are sent (T-02-05); the remaining characters are matched
    locally against the returned ``SUFFIX:COUNT`` lines.

    Returns:
        ``True`` when the suffix is listed with a count above zero, ``False``
        otherwise. If the request fails for any reason, a warning is logged and
        ``False`` is returned (fail-open, T-02-06).
    """
    digest = hashlib.sha1(password.encode("utf-8")).hexdigest().upper()
    prefix = digest[:5]
    suffix = digest[5:]

    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            resp = await client.get(
                f"https://api.pwnedpasswords.com/range/{prefix}",
                headers={"Add-Padding": "true"},
            )
            resp.raise_for_status()
    except Exception as exc:
        logger.warning("HIBP check failed (fail-open): %s", exc)
        return False

    for entry in resp.text.splitlines():
        fields = entry.split(":")
        if len(fields) != 2:
            continue  # not a SUFFIX:COUNT line
        candidate_suffix, count_str = fields
        if not hmac.compare_digest(candidate_suffix.upper(), suffix):
            continue
        try:
            occurrences = int(count_str.strip())
        except ValueError:
            # An unparsable count on a matching suffix is treated as a hit.
            occurrences = 1
        if occurrences > 0:
            return True
        # A matching suffix with count 0 is not a hit; keep scanning.
    return False
|
|
|
|
|
|
# ── Admin bootstrap ─────────────────────────────────────────────────────────────
|
|
|
|
async def bootstrap_admin(session: AsyncSession) -> None:
    """Seed the first admin account when the users table is empty (D-04, D-05, D-06).

    Does nothing unless both ``settings.admin_email`` and
    ``settings.admin_password`` are set; in that case a WARNING is logged and
    the function returns (D-05). It also returns without changes when any user
    row already exists, which makes repeated calls idempotent (D-04).

    Otherwise an active ``admin`` user with handle ``admin`` is created along
    with a ``Quota`` row of 100 MB (D-06), and both are committed.
    """
    if not (settings.admin_email and settings.admin_password):
        logger.warning(
            "Admin bootstrap skipped: ADMIN_EMAIL and/or ADMIN_PASSWORD not set (D-05). "
            "Set both env vars to seed the first admin account on startup."
        )
        return

    # Any existing user means the bootstrap already happened (or is not needed).
    existing_user = (await session.execute(select(User).limit(1))).scalar_one_or_none()
    if existing_user is not None:
        return

    admin_id = uuid.uuid4()
    admin_user = User(
        id=admin_id,
        handle="admin",
        email=settings.admin_email,
        password_hash=hash_password(settings.admin_password),
        role="admin",
        is_active=True,
        password_must_change=False,
    )
    admin_quota = Quota(
        user_id=admin_id,
        limit_bytes=104857600,  # 100 MB default (D-06)
        used_bytes=0,
    )

    # Flush the user on its own first so the quota's foreign key is satisfied.
    session.add(admin_user)
    await session.flush()
    session.add(admin_quota)
    await session.commit()

    logger.info("Admin account bootstrapped for %s", settings.admin_email)
|