Files
kite/backend/services/auth.py
T
curo1305 a5994d9ff4 chore: commit pending phase-3 work and add TEST_ACCOUNTS.md
Includes planning artifacts (03-CONTEXT, 03-DISCUSSION-LOG, 03-02-SUMMARY),
integration test script, MinIO/auth/docker fixes, and local dev account reference.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-24 11:30:56 +02:00

430 lines
15 KiB
Python

"""
Auth service — pure Python, no FastAPI coupling.
Handles:
- Password hashing (Argon2 via pwdlib) and constant-time verification (SEC-06)
- JWT access token creation/decode (PyJWT)
- Refresh token lifecycle with family revocation on reuse (AUTH-07, RFC 9700)
- TOTP provisioning and verification with replay prevention (AUTH-08)
- Backup code generation, storage, and constant-time verification (AUTH-02)
- HaveIBeenPwned k-anonymity check (SEC-03)
- Admin account bootstrap (D-04, D-05, D-06)
Security invariants:
- All token/code comparisons use hmac.compare_digest (constant-time, SEC-06)
- No function raises HTTPException — callers (api/) map ValueError to HTTP errors
- refresh token family revocation enqueues send_security_alert_email.delay (AUTH-07)
"""
from __future__ import annotations
import hashlib
import hmac
import logging
import secrets
import uuid
from datetime import datetime, timezone, timedelta
from typing import Optional
import httpx
import jwt
import pyotp
from pwdlib import PasswordHash
from pwdlib.hashers.argon2 import Argon2Hasher
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from config import settings
from db.models import BackupCode, Quota, RefreshToken, User
# Module-level logger; used for the HIBP fail-open warning and admin bootstrap messages.
logger: logging.Logger = logging.getLogger(__name__)
# ── Password hashing ────────────────────────────────────────────────────────────
# Single shared PasswordHash instance; Argon2 is the only enabled hasher.
# Shared by hash_password/verify_password, which this module also uses for
# backup-code hashes (store_backup_codes / verify_backup_code).
_pwd: PasswordHash = PasswordHash([Argon2Hasher()])
def hash_password(plain: str) -> str:
    """Hash *plain* with the module's shared Argon2 hasher.

    The result is the string that :func:`verify_password` later checks against.
    """
    hashed = _pwd.hash(plain)
    return hashed
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against the stored hash *hashed* (SEC-06).

    The comparison is delegated to pwdlib's ``verify``, which does it in
    constant time. Any exception raised by the verifier is treated as a
    mismatch, so a broken or unexpected stored hash can never authenticate.
    """
    try:
        matched = _pwd.verify(plain, hashed)
    except Exception:  # deliberate: verifier errors mean "no match"
        matched = False
    return matched
# ── JWT helpers ─────────────────────────────────────────────────────────────────
def create_access_token(user_id: str, role: str) -> str:
    """Sign and return an HS256 JWT access token.

    Claims:
        sub  -- ``str(user_id)``
        role -- *role*
        typ  -- ``"access"`` (what decode_access_token insists on)
        iat  -- issue time, UTC
        exp  -- ``iat`` + ``settings.access_token_expire_minutes``
    """
    issued_at = datetime.now(timezone.utc)
    lifetime = timedelta(minutes=settings.access_token_expire_minutes)
    claims = dict(
        sub=str(user_id),
        role=role,
        typ="access",
        iat=issued_at,
        exp=issued_at + lifetime,
    )
    return jwt.encode(claims, settings.secret_key, algorithm="HS256")
def decode_access_token(token: str) -> dict:
    """Decode and validate an access token; raise ValueError on any failure.

    Verifies:
      * the HS256 signature against ``settings.secret_key``;
      * that the ``exp`` and ``sub`` claims are present. PyJWT only checks
        expiry when ``exp`` exists, so without this a token signed with the
        same key but lacking ``exp`` would be accepted forever;
      * expiry;
      * ``typ == 'access'`` (T-02-01 — password-reset tokens share the signing
        key and must not be usable as access tokens).

    Returns:
        The decoded claims dict.

    Raises:
        ValueError: expired, invalid, incomplete, or wrong-type token.
    """
    try:
        payload = jwt.decode(
            token,
            settings.secret_key,
            algorithms=["HS256"],
            options={"require": ["exp", "sub"]},
        )
    except jwt.ExpiredSignatureError as exc:
        raise ValueError("Token has expired") from exc
    except jwt.PyJWTError as exc:
        # Also covers MissingRequiredClaimError from the "require" option.
        raise ValueError(f"Invalid token: {exc}") from exc
    if payload.get("typ") != "access":
        raise ValueError("Token type mismatch: expected 'access'")
    return payload
def create_password_reset_token(user_id: str) -> str:
    """Sign and return a one-hour HS256 JWT for password reset.

    Claims: sub=str(user_id), typ='password-reset', iat=now, exp=now+3600s.
    Because ``typ`` is not ``'access'``, decode_access_token rejects it.
    """
    issued_at = datetime.now(timezone.utc)
    claims = {"sub": str(user_id), "typ": "password-reset", "iat": issued_at}
    claims["exp"] = issued_at + timedelta(seconds=3600)
    return jwt.encode(claims, settings.secret_key, algorithm="HS256")
def decode_password_reset_token(token: str) -> str:
    """Decode a password-reset token and return its subject (the user_id string).

    Raises:
        ValueError: the token is expired, badly signed, missing its ``exp`` or
            ``sub`` claim, or is not a ``typ='password-reset'`` token. No
            KeyError or PyJWT exception escapes.
    """
    try:
        payload = jwt.decode(
            token,
            settings.secret_key,
            algorithms=["HS256"],
            # Without "require", a validly signed token lacking "sub" would
            # reach payload["sub"] below and escape as KeyError, and one
            # lacking "exp" would never expire.
            options={"require": ["exp", "sub"]},
        )
    except jwt.ExpiredSignatureError as exc:
        raise ValueError("Reset token has expired") from exc
    except jwt.PyJWTError as exc:
        raise ValueError(f"Invalid reset token: {exc}") from exc
    if payload.get("typ") != "password-reset":
        raise ValueError("Token type mismatch: expected 'password-reset'")
    return payload["sub"]
# ── Refresh token lifecycle ─────────────────────────────────────────────────────
async def create_refresh_token(session: AsyncSession, user_id: uuid.UUID) -> str:
    """Persist a new refresh token for *user_id* and return its raw value.

    The raw value goes back to the caller (intended to be set as an httpOnly
    cookie). The database stores only its SHA-256 hex digest. The row expires
    ``settings.refresh_token_expire_days`` from now, and the session is
    committed before returning.
    """
    raw_token = secrets.token_urlsafe(32)
    digest = hashlib.sha256(raw_token.encode()).hexdigest()
    expires = datetime.now(timezone.utc) + timedelta(
        days=settings.refresh_token_expire_days
    )
    session.add(
        RefreshToken(
            id=uuid.uuid4(),
            user_id=user_id,
            token_hash=digest,
            expires_at=expires,
            revoked=False,
        )
    )
    await session.commit()
    return raw_token
async def rotate_refresh_token(
    session: AsyncSession, raw_token: str
) -> tuple[str, str]:
    """Rotate a refresh token: revoke the presented one and issue a new one.

    Returns:
        ``(new_raw_token, user_id_str)``.

    Family revocation (AUTH-07 / RFC 9700):
        An already-revoked token being presented again is treated as reuse of
        a stolen token. In that case:
          * ALL of the user's refresh tokens are revoked;
          * the revocation is committed;
          * send_security_alert_email.delay is enqueued;
          * ValueError("token_family_revoked") is raised.

    Raises:
        ValueError: the token is unknown, expired, or was reused (see above).
    """
    token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
    result = await session.execute(
        select(RefreshToken).where(RefreshToken.token_hash == token_hash)
    )
    row: Optional[RefreshToken] = result.scalar_one_or_none()
    if row is None:
        raise ValueError("Refresh token not found")
    # Read the id up front so nothing below touches row attributes after a
    # commit (which may expire them, depending on session configuration).
    user_id = row.user_id

    expires_at = row.expires_at
    # Naive values are assumed to be UTC, as before. Aware values are compared
    # as-is instead of having their offset overwritten by replace().
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    if expires_at < datetime.now(timezone.utc):
        raise ValueError("Refresh token has expired")

    if row.revoked:
        # T-02-02: reuse of revoked token — family revocation.
        await revoke_all_refresh_tokens(session, user_id)
        # revoke_all_refresh_tokens only flushes. Commit here: this function
        # then raises, and if the caller maps that to an HTTP error without
        # committing, the revocation would be lost with the session.
        await session.commit()
        # Enqueue the alert only once the revocation is durable
        # (deferred import to avoid circular dependency).
        from tasks.email_tasks import send_security_alert_email  # noqa: PLC0415

        send_security_alert_email.delay(str(user_id))
        raise ValueError("token_family_revoked")

    # Valid token: revoke it, then issue its successor. create_refresh_token
    # commits both changes together.
    row.revoked = True
    await session.flush()
    new_raw = await create_refresh_token(session, user_id)
    return new_raw, str(user_id)
async def revoke_all_refresh_tokens(
    session: AsyncSession, user_id: uuid.UUID
) -> int:
    """Revoke every refresh token of *user_id* that is not already revoked.

    Changes are flushed but not committed; committing is up to the caller.
    Returns how many tokens were revoked (supports sign-out-all-devices).
    """
    active_tokens = (
        await session.execute(
            select(RefreshToken).where(
                RefreshToken.user_id == user_id,
                RefreshToken.revoked.is_(False),
            )
        )
    ).scalars().all()
    for token in active_tokens:
        token.revoked = True
    await session.flush()
    return len(active_tokens)
# ── TOTP provisioning ───────────────────────────────────────────────────────────
async def provision_totp(
    session: AsyncSession, user_id: uuid.UUID
) -> tuple[str, str]:
    """Create and persist a fresh base32 TOTP secret for *user_id*.

    The secret is written to ``user.totp_secret`` and committed. This function
    does not itself mark TOTP as enabled.

    Returns:
        ``(secret, provisioning_uri)``. The URI uses issuer "DocuVault" and the
        user's email as the account name, and can be rendered as a QR code.

    Raises:
        ValueError: no user with *user_id* exists.
    """
    user = await session.get(User, user_id)
    if user is None:
        raise ValueError("User not found")
    new_secret = pyotp.random_base32()
    user.totp_secret = new_secret
    await session.commit()
    uri = pyotp.totp.TOTP(new_secret).provisioning_uri(
        user.email, issuer_name="DocuVault"
    )
    return new_secret, uri
async def verify_totp(
    session: AsyncSession,
    user_id: uuid.UUID,
    code: str,
    redis_client,
) -> bool:
    """Verify a TOTP code for *user_id* with replay prevention (AUTH-08).

    ``valid_window=1`` tolerates ±30 s of clock drift (per STATE.md).

    Replay prevention: a successfully verified code is recorded in Redis under
    ``totp_used:{user_id}:{code}`` with a 90 s TTL, which covers the ±30 s
    window. The record is written with ``SET ... NX EX``, so the "already
    used?" check and the "mark used" write are one atomic operation. The
    previous GET-then-SET let two concurrent requests carrying the same code
    both succeed.

    NOTE(review): assumes *redis_client* is a redis-py compatible async client
    whose ``set()`` accepts ``nx=`` and ``ex=`` — confirm for any test fakes.

    Returns:
        False if the user has no TOTP secret, the code is invalid, or the code
        was already used; True otherwise.
    """
    user = await session.get(User, user_id)
    if user is None or not user.totp_secret:
        return False
    if not pyotp.TOTP(user.totp_secret).verify(code, valid_window=1):
        return False
    replay_key = f"totp_used:{user_id}:{code}"
    # SET NX returns a falsy value when the key already exists, i.e. this code
    # was consumed by an earlier (or concurrent) request.
    first_use = await redis_client.set(replay_key, "1", ex=90, nx=True)
    return bool(first_use)
# ── Backup codes ────────────────────────────────────────────────────────────────
def generate_backup_codes(n: int = 10) -> list[str]:
    """Return *n* random backup codes.

    Each code is 8 uppercase hexadecimal characters (0-9, A-F) produced by
    ``secrets.token_hex(4)``, i.e. 32 random bits per code.
    """
    return [secrets.token_hex(4).upper() for _ in range(n)]
async def store_backup_codes(
    session: AsyncSession, user_id: uuid.UUID, codes: list[str]
) -> None:
    """Replace *user_id*'s unused backup codes with *codes*.

    Rows with ``used_at IS NULL`` are deleted first so codes don't accumulate
    across regenerations. Already-used rows are kept. Each new code is stored
    only as its Argon2 hash via hash_password (never plaintext, T-02-03), and
    the session is committed at the end.
    """
    stale = await session.execute(
        select(BackupCode).where(
            BackupCode.user_id == user_id,
            BackupCode.used_at.is_(None),
        )
    )
    for old_code in stale.scalars().all():
        await session.delete(old_code)
    await session.flush()

    fresh_rows = [
        BackupCode(
            id=uuid.uuid4(),
            user_id=user_id,
            code_hash=hash_password(plain_code),
            used_at=None,
        )
        for plain_code in codes
    ]
    session.add_all(fresh_rows)
    await session.commit()
async def verify_backup_code(
    session: AsyncSession, user_id: uuid.UUID, code: str
) -> bool:
    """Verify and consume one of *user_id*'s unused backup codes.

    Every unused code's hash is checked, with no early exit once a match is
    found, so timing does not reveal which stored code matched (SEC-06).

    A match is consumed with a conditional UPDATE
    (``... WHERE id = :id AND used_at IS NULL``) and committed. The condition
    keeps codes single-use under concurrency: if another request consumed the
    same code between our SELECT and this UPDATE, zero rows change and this
    call returns False instead of accepting the code a second time.

    Returns:
        True if a code matched and was consumed by this call, otherwise False.
    """
    result = await session.execute(
        select(BackupCode).where(
            BackupCode.user_id == user_id,
            BackupCode.used_at.is_(None),
        )
    )
    matched_row: Optional[BackupCode] = None
    for row in result.scalars().all():
        # verify_password runs for EVERY row: record the match, keep iterating.
        if verify_password(code, row.code_hash):
            matched_row = row
    if matched_row is None:
        return False

    consumed = await session.execute(
        update(BackupCode)
        .where(
            BackupCode.id == matched_row.id,
            BackupCode.used_at.is_(None),
        )
        .values(used_at=datetime.now(timezone.utc))
    )
    await session.commit()
    return consumed.rowcount == 1
# ── HaveIBeenPwned check ────────────────────────────────────────────────────────
def _hibp_range_contains(body: str, suffix: str) -> bool:
    """Return True if a HIBP range *body* lists *suffix* with a positive count."""
    # Compare bytes: hmac.compare_digest raises TypeError on non-ASCII str,
    # which would otherwise escape check_hibp's fail-open handling.
    wanted = suffix.encode("utf-8")
    for line in body.splitlines():
        parts = line.split(":")
        if len(parts) != 2:
            continue  # malformed line: ignore
        candidate_suffix, count_str = parts
        if not hmac.compare_digest(
            candidate_suffix.strip().upper().encode("utf-8"), wanted
        ):
            continue
        try:
            count = int(count_str.strip())
        except ValueError:
            count = 1  # unparseable count on a matching line: assume breached
        # Padding entries carry a count of 0 and are not real breaches.
        if count > 0:
            return True
    return False


async def check_hibp(password: str) -> bool:
    """Return True if *password* appears in HaveIBeenPwned (k-anonymity model).

    Only the first 5 hex characters of the password's SHA-1 are sent to the
    HIBP range API (T-02-05). The remaining 35 are matched locally against the
    response. A padded response is requested (``Add-Padding: true``).

    Fail-open (T-02-06): on a network or HTTP error a warning is logged and
    False is returned. Malformed response lines are skipped instead of raising.
    """
    sha1 = hashlib.sha1(password.encode("utf-8")).hexdigest().upper()
    prefix, suffix = sha1[:5], sha1[5:]
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            resp = await client.get(
                f"https://api.pwnedpasswords.com/range/{prefix}",
                headers={"Add-Padding": "true"},
            )
            resp.raise_for_status()
    except Exception as exc:
        logger.warning("HIBP check failed (fail-open): %s", exc)
        return False
    return _hibp_range_contains(resp.text, suffix)
# ── Admin bootstrap ─────────────────────────────────────────────────────────────
async def bootstrap_admin(session: AsyncSession) -> None:
    """Seed the first admin account on an empty database (D-04, D-05, D-06).

    * If ``settings.admin_email`` or ``settings.admin_password`` is empty, log
      a WARNING and return (D-05).
    * If any User row already exists, return without changes (idempotent, D-04).
    * Otherwise create an active admin User (handle "admin", no forced password
      change) plus a Quota row with a 100 MB limit (D-06), then commit.
    """
    if not (settings.admin_email and settings.admin_password):
        logger.warning(
            "Admin bootstrap skipped: ADMIN_EMAIL and/or ADMIN_PASSWORD not set (D-05). "
            "Set both env vars to seed the first admin account on startup."
        )
        return

    existing = await session.execute(select(User).limit(1))
    if existing.scalar_one_or_none() is not None:
        return  # D-04: users already exist, nothing to seed

    new_admin_id = uuid.uuid4()
    session.add(
        User(
            id=new_admin_id,
            handle="admin",
            email=settings.admin_email,
            password_hash=hash_password(settings.admin_password),
            role="admin",
            is_active=True,
            password_must_change=False,
        )
    )
    # Flush the User first so the Quota's user_id foreign key resolves.
    await session.flush()
    session.add(
        Quota(
            user_id=new_admin_id,
            limit_bytes=104857600,  # 100 MB default (D-06)
            used_bytes=0,
        )
    )
    await session.commit()
    logger.info("Admin account bootstrapped for %s", settings.admin_email)