refactor(backend): extract shared helper modules per architecture rules

- Add backend/ai/utils.py — parse_classification, parse_suggestions, strip_code_fences
  shared by all AI providers; removes duplicated private functions from
  anthropic_provider.py and openai_provider.py
- Add backend/deps/utils.py — get_client_ip, parse_uuid request-parsing helpers;
  removes local _ip() variants from admin.py, auth.py, shares.py, folders.py
- Add backend/storage/exceptions.py — canonical CloudConnectionError definition;
  all routers and backends import from here instead of redefining
- Move validate_password_strength to backend/services/auth.py; removes duplicated
  _validate_password_strength from admin.py and auth.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
curo1305
2026-06-02 16:10:35 +02:00
parent 89f8d5a654
commit a548266461
14 changed files with 171 additions and 232 deletions
+3 -37
View File
@@ -1,7 +1,6 @@
import json
import re
import anthropic import anthropic
from ai.base import AIProvider, ClassificationResult from ai.base import AIProvider, ClassificationResult
from ai.utils import parse_classification, parse_suggestions
MAX_AI_CHARS = 8_000 MAX_AI_CHARS = 8_000
@@ -33,7 +32,7 @@ class AnthropicProvider(AIProvider):
messages=[{"role": "user", "content": user_msg}], messages=[{"role": "user", "content": user_msg}],
) )
raw = response.content[0].text raw = response.content[0].text
return _parse_classification(raw) return parse_classification(raw)
async def suggest_topics( async def suggest_topics(
self, self,
@@ -53,7 +52,7 @@ class AnthropicProvider(AIProvider):
messages=[{"role": "user", "content": user_msg}], messages=[{"role": "user", "content": user_msg}],
) )
raw = response.content[0].text raw = response.content[0].text
return _parse_suggestions(raw) return parse_suggestions(raw)
async def health_check(self) -> bool: async def health_check(self) -> bool:
try: try:
@@ -68,36 +67,3 @@ class AnthropicProvider(AIProvider):
return False return False
def _strip_code_fences(text: str) -> str:
text = re.sub(r"```(?:json)?\s*", "", text)
text = re.sub(r"```", "", text)
return text.strip()
def _parse_classification(raw: str) -> ClassificationResult:
raw = _strip_code_fences(raw)
# Try to find JSON object
match = re.search(r"\{.*\}", raw, re.DOTALL)
if match:
try:
data = json.loads(match.group())
return ClassificationResult(
topics=data.get("assigned_topics", []),
suggested_new_topics=data.get("new_topic_suggestions", []),
reasoning=data.get("reasoning", ""),
)
except json.JSONDecodeError:
pass
return ClassificationResult()
def _parse_suggestions(raw: str) -> list[str]:
raw = _strip_code_fences(raw)
match = re.search(r"\{.*\}", raw, re.DOTALL)
if match:
try:
data = json.loads(match.group())
return data.get("suggested_topics", [])
except json.JSONDecodeError:
pass
return []
+3 -36
View File
@@ -1,7 +1,6 @@
import json
import re
from openai import AsyncOpenAI from openai import AsyncOpenAI
from ai.base import AIProvider, ClassificationResult from ai.base import AIProvider, ClassificationResult
from ai.utils import parse_classification, parse_suggestions
MAX_AI_CHARS = 8_000 MAX_AI_CHARS = 8_000
@@ -35,7 +34,7 @@ class OpenAIProvider(AIProvider):
], ],
) )
raw = response.choices[0].message.content or "" raw = response.choices[0].message.content or ""
return _parse_classification(raw) return parse_classification(raw)
async def suggest_topics( async def suggest_topics(
self, self,
@@ -56,7 +55,7 @@ class OpenAIProvider(AIProvider):
], ],
) )
raw = response.choices[0].message.content or "" raw = response.choices[0].message.content or ""
return _parse_suggestions(raw) return parse_suggestions(raw)
async def health_check(self) -> bool: async def health_check(self) -> bool:
try: try:
@@ -70,35 +69,3 @@ class OpenAIProvider(AIProvider):
return False return False
def _strip_code_fences(text: str) -> str:
text = re.sub(r"```(?:json)?\s*", "", text)
text = re.sub(r"```", "", text)
return text.strip()
def _parse_classification(raw: str) -> ClassificationResult:
raw = _strip_code_fences(raw)
match = re.search(r"\{.*\}", raw, re.DOTALL)
if match:
try:
data = json.loads(match.group())
return ClassificationResult(
topics=data.get("assigned_topics", []),
suggested_new_topics=data.get("new_topic_suggestions", []),
reasoning=data.get("reasoning", ""),
)
except json.JSONDecodeError:
pass
return ClassificationResult()
def _parse_suggestions(raw: str) -> list[str]:
raw = _strip_code_fences(raw)
match = re.search(r"\{.*\}", raw, re.DOTALL)
if match:
try:
data = json.loads(match.group())
return data.get("suggested_topics", [])
except json.JSONDecodeError:
pass
return []
+51
View File
@@ -0,0 +1,51 @@
"""Shared AI response parsing utilities — used by all provider implementations."""
from __future__ import annotations
import json
import re
from ai.base import ClassificationResult
def strip_code_fences(text: str) -> str:
    """Remove markdown code-fence markers from *text*.

    Every run of three backticks is deleted together with an optional
    ``json`` language tag and any whitespace that directly follows it; the
    text between the fences is kept.  Leading and trailing whitespace of the
    result is stripped.
    """
    # A single pass is enough: the pattern also matches bare fences, so a
    # second substitution for plain triple backticks can never find anything.
    return re.sub(r"```(?:json)?\s*", "", text).strip()


def _first_json_object(text: str) -> dict | None:
    """Return the first JSON object embedded in *text*, or None if none parses."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the value, so trailing prose
            # (even prose containing braces) no longer breaks parsing.
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = text.find("{", start + 1)
    return None


def _string_list(value: object) -> list[str]:
    """Return the string items of *value* when it is a list, else an empty list."""
    if not isinstance(value, list):
        return []
    return [item for item in value if isinstance(item, str)]


def parse_classification(raw: str) -> ClassificationResult:
    """Parse a classification JSON response into a ClassificationResult.

    Code fences are stripped first, then the first JSON object found in the
    text is used.  The keys read are ``assigned_topics``,
    ``new_topic_suggestions`` and ``reasoning``; a missing key, a ``null``
    value or a value of the wrong type falls back to an empty list / empty
    string.

    Returns:
        The populated result, or an empty ``ClassificationResult()`` when no
        JSON object can be parsed from *raw*.
    """
    data = _first_json_object(strip_code_fences(raw))
    if data is None:
        return ClassificationResult()
    reasoning = data.get("reasoning", "")
    return ClassificationResult(
        topics=_string_list(data.get("assigned_topics")),
        suggested_new_topics=_string_list(data.get("new_topic_suggestions")),
        reasoning=reasoning if isinstance(reasoning, str) else "",
    )


def parse_suggestions(raw: str) -> list[str]:
    """Parse a topic-suggestion JSON response into a list of topic names.

    Code fences are stripped first, then the ``suggested_topics`` key of the
    first JSON object found in the text is read.

    Returns:
        The string entries of ``suggested_topics``; an empty list when no
        JSON object can be parsed or the key is missing, ``null`` or not a
        list.
    """
    data = _first_json_object(strip_code_fences(raw))
    if data is None:
        return []
    return _string_list(data.get("suggested_topics"))
+8 -52
View File
@@ -23,7 +23,6 @@ Security invariants:
""" """
from __future__ import annotations from __future__ import annotations
import re
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
@@ -36,8 +35,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from db.models import CloudConnection, Document, Quota, RefreshToken, Topic, User from db.models import CloudConnection, Document, Quota, RefreshToken, Topic, User
from deps.auth import get_current_admin from deps.auth import get_current_admin
from deps.db import get_db from deps.db import get_db
from deps.utils import get_client_ip
from services.audit import write_audit_log from services.audit import write_audit_log
from services.auth import hash_password, revoke_all_refresh_tokens, verify_password from services.auth import hash_password, revoke_all_refresh_tokens, validate_password_strength, verify_password
from storage import get_storage_backend, get_storage_backend_for_document from storage import get_storage_backend, get_storage_backend_for_document
router = APIRouter(prefix="/api/admin", tags=["admin"]) router = APIRouter(prefix="/api/admin", tags=["admin"])
@@ -46,28 +46,6 @@ router = APIRouter(prefix="/api/admin", tags=["admin"])
_DEFAULT_QUOTA_BYTES = 104857600 # 100 MB free-tier default (D-06) _DEFAULT_QUOTA_BYTES = 104857600 # 100 MB free-tier default (D-06)
_PASSWORD_DETAIL = (
"Password must be at least 12 characters and include uppercase, "
"lowercase, a number, and a special character."
)
# ── IP extraction helper ──────────────────────────────────────────────────────
def _ip(request: Request) -> Optional[str]:
"""Extract best-effort client IP from request for audit logging.
TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be
forged by any caller. This value is used for forensic audit logging only —
not for authentication or access control decisions. In production, deploy
behind a trusted reverse proxy (e.g. nginx with
`proxy_set_header X-Forwarded-For $remote_addr;`) which overwrites this
header with the real remote IP before it reaches FastAPI, or use a
trusted-proxy middleware that validates the source CIDR.
"""
return request.headers.get("X-Forwarded-For") or (
request.client.host if request.client else None
)
# ── Safe response helper ────────────────────────────────────────────────────── # ── Safe response helper ──────────────────────────────────────────────────────
@@ -90,25 +68,6 @@ def _user_to_dict(user: User) -> dict:
} }
# ── Password strength helper ──────────────────────────────────────────────────
def _validate_password_strength(password: str) -> None:
"""Raise ValueError with the spec message if password fails any strength rule.
Rules (AUTH-01): min 12 chars, has uppercase, has lowercase, has digit,
has special char (non-alphanumeric).
"""
if len(password) < 12:
raise ValueError(_PASSWORD_DETAIL)
if not re.search(r"[A-Z]", password):
raise ValueError(_PASSWORD_DETAIL)
if not re.search(r"[a-z]", password):
raise ValueError(_PASSWORD_DETAIL)
if not re.search(r"[0-9]", password):
raise ValueError(_PASSWORD_DETAIL)
if not re.search(r"[^A-Za-z0-9]", password):
raise ValueError(_PASSWORD_DETAIL)
# ── Request models ──────────────────────────────────────────────────────────── # ── Request models ────────────────────────────────────────────────────────────
@@ -121,10 +80,7 @@ class UserCreate(BaseModel):
@field_validator("password") @field_validator("password")
@classmethod @classmethod
def password_strength(cls, v: str) -> str: def password_strength(cls, v: str) -> str:
try: validate_password_strength(v)
_validate_password_strength(v)
except ValueError as exc:
raise ValueError(str(exc)) from exc
return v return v
@@ -264,7 +220,7 @@ async def create_user(
session.add(quota) session.add(quota)
await session.flush() # persist User + Quota before audit_log FK references them await session.flush() # persist User + Quota before audit_log FK references them
# D-13: admin user created event # D-13: admin user created event
_ip_addr = _ip(request) _ip_addr = get_client_ip(request)
await write_audit_log( await write_audit_log(
session, session,
event_type="admin.user_created", event_type="admin.user_created",
@@ -316,7 +272,7 @@ async def update_user_status(
detail="Cannot deactivate the only admin", detail="Cannot deactivate the only admin",
) )
_ip_addr = _ip(request) _ip_addr = get_client_ip(request)
user.is_active = body.is_active user.is_active = body.is_active
if not body.is_active: if not body.is_active:
@@ -426,7 +382,7 @@ async def update_user_quota(
else None else None
) )
_ip_addr = _ip(request) _ip_addr = get_client_ip(request)
old_limit = quota.limit_bytes old_limit = quota.limit_bytes
quota.limit_bytes = body.limit_bytes quota.limit_bytes = body.limit_bytes
session.add(quota) session.add(quota)
@@ -471,7 +427,7 @@ async def update_ai_config(
if user is None: if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
_ip_addr = _ip(request) _ip_addr = get_client_ip(request)
user.ai_provider = body.ai_provider user.ai_provider = body.ai_provider
user.ai_model = body.ai_model user.ai_model = body.ai_model
session.add(user) session.add(user)
@@ -532,7 +488,7 @@ async def delete_user(
detail="Cannot delete admin accounts", detail="Cannot delete admin accounts",
) )
_ip_addr = _ip(request) _ip_addr = get_client_ip(request)
# SEC-09 (cloud): purge cloud-stored documents and credentials BEFORE DB delete. # SEC-09 (cloud): purge cloud-stored documents and credentials BEFORE DB delete.
# Must run before MinIO cleanup so that credentials are still available to build # Must run before MinIO cleanup so that credentials are still available to build
+19 -46
View File
@@ -19,7 +19,6 @@ Security invariants:
""" """
from __future__ import annotations from __future__ import annotations
import re
import uuid import uuid
from typing import Literal, Optional from typing import Literal, Optional
@@ -32,6 +31,7 @@ from config import settings
from db.models import BackupCode, Quota, RefreshToken, User from db.models import BackupCode, Quota, RefreshToken, User
from deps.auth import get_current_user from deps.auth import get_current_user
from deps.db import get_db from deps.db import get_db
from deps.utils import get_client_ip
from services import auth as auth_service from services import auth as auth_service
from services.audit import write_audit_log from services.audit import write_audit_log
from slowapi import Limiter from slowapi import Limiter
@@ -43,30 +43,6 @@ router = APIRouter(prefix="/api/auth", tags=["auth"])
# IP-level rate limiter (SEC-02 — 10 req/min on register/login/refresh) # IP-level rate limiter (SEC-02 — 10 req/min on register/login/refresh)
limiter = Limiter(key_func=get_remote_address) limiter = Limiter(key_func=get_remote_address)
# ── Password strength validation ─────────────────────────────────────────────
_PASSWORD_DETAIL = (
"Password must be at least 12 characters and include uppercase, "
"lowercase, a number, and a special character."
)
def _validate_password_strength(password: str) -> bool:
"""Return True if password passes all strength rules (AUTH-01).
Rules: min 12 chars, has uppercase, has lowercase, has digit, has special char.
"""
if len(password) < 12:
return False
if not re.search(r"[A-Z]", password):
return False
if not re.search(r"[a-z]", password):
return False
if not re.search(r"[0-9]", password):
return False
if not re.search(r"[^A-Za-z0-9]", password):
return False
return True
# ── Request models ──────────────────────────────────────────────────────────── # ── Request models ────────────────────────────────────────────────────────────
@@ -132,11 +108,10 @@ async def register(
- Inserts User + Quota rows in a single transaction - Inserts User + Quota rows in a single transaction
""" """
# Password strength check # Password strength check
if not _validate_password_strength(body.password): try:
raise HTTPException( auth_service.validate_password_strength(body.password)
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, except ValueError as exc:
detail=_PASSWORD_DETAIL, raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc))
)
# HIBP breach check # HIBP breach check
if await auth_service.check_hibp(body.password): if await auth_service.check_hibp(body.password):
@@ -228,7 +203,7 @@ async def login(
user: Optional[User] = result.scalar_one_or_none() user: Optional[User] = result.scalar_one_or_none()
# IP extraction for audit log (used in both success and failure paths) # IP extraction for audit log (used in both success and failure paths)
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) _ip = get_client_ip(request)
# Verify password (anti-enumeration: same error regardless of whether user exists) # Verify password (anti-enumeration: same error regardless of whether user exists)
if user is None or not auth_service.verify_password(body.password, user.password_hash): if user is None or not auth_service.verify_password(body.password, user.password_hash):
@@ -385,7 +360,7 @@ async def logout(request: Request, response: Response, session: AsyncSession = D
"""Revoke current refresh token and clear the cookie.""" """Revoke current refresh token and clear the cookie."""
import hashlib as _hashlib import hashlib as _hashlib
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) _ip = get_client_ip(request)
raw_token = request.cookies.get("refresh_token") raw_token = request.cookies.get("refresh_token")
_logout_user_id = None _logout_user_id = None
@@ -423,7 +398,7 @@ async def logout_all(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Sign out of all devices: revoke all refresh tokens for current user.""" """Sign out of all devices: revoke all refresh tokens for current user."""
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) _ip = get_client_ip(request)
count = await auth_service.revoke_all_refresh_tokens(session, current_user.id) count = await auth_service.revoke_all_refresh_tokens(session, current_user.id)
# D-13: sign-out-all event # D-13: sign-out-all event
await write_audit_log( await write_audit_log(
@@ -497,14 +472,13 @@ async def change_password(
) )
# Password strength check # Password strength check
if not _validate_password_strength(body.new_password): try:
raise HTTPException( auth_service.validate_password_strength(body.new_password)
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, except ValueError as exc:
detail=_PASSWORD_DETAIL, raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc))
)
# Update password # Update password
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) _ip = get_client_ip(request)
user = await session.get(User, current_user.id) user = await session.get(User, current_user.id)
user.password_hash = auth_service.hash_password(body.new_password) user.password_hash = auth_service.hash_password(body.new_password)
# D-13: password changed event (flush within same transaction before commit) # D-13: password changed event (flush within same transaction before commit)
@@ -594,7 +568,7 @@ async def enable_totp(
await auth_service.store_backup_codes(session, current_user.id, plain_codes) await auth_service.store_backup_codes(session, current_user.id, plain_codes)
# D-13: TOTP enrolled event # D-13: TOTP enrolled event
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) _ip = get_client_ip(request)
await write_audit_log( await write_audit_log(
session, session,
event_type="auth.totp_enrolled", event_type="auth.totp_enrolled",
@@ -620,7 +594,7 @@ async def disable_totp(
Clears totp_secret, sets totp_enabled=False, and deletes all backup codes. Clears totp_secret, sets totp_enabled=False, and deletes all backup codes.
""" """
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) _ip = get_client_ip(request)
user = await session.get(User, current_user.id) user = await session.get(User, current_user.id)
user.totp_enabled = False user.totp_enabled = False
user.totp_secret = None user.totp_secret = None
@@ -699,11 +673,10 @@ async def password_reset_confirm(
) )
# Password strength validation # Password strength validation
if not _validate_password_strength(body.new_password): try:
raise HTTPException( auth_service.validate_password_strength(body.new_password)
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, except ValueError as exc:
detail=_PASSWORD_DETAIL, raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc))
)
# HIBP breach check (SEC-03) # HIBP breach check (SEC-03)
if await auth_service.check_hibp(body.new_password): if await auth_service.check_hibp(body.new_password):
+1 -8
View File
@@ -48,14 +48,7 @@ except ImportError:
# Fallback for test environments where minio is not installed # Fallback for test environments where minio is not installed
S3Error = Exception # type: ignore[assignment,misc] S3Error = Exception # type: ignore[assignment,misc]
try: from storage.exceptions import CloudConnectionError
from storage.google_drive_backend import CloudConnectionError
except ImportError:
# Fallback: define a stub so the except clause compiles even if google deps absent
class CloudConnectionError(Exception): # type: ignore[no-redef]
def __init__(self, msg: str = "", *, reason: str = "") -> None:
super().__init__(msg)
self.reason = reason
# Valid cloud backend slugs (T-05-06-01: validated against allowlist, not user-supplied string) # Valid cloud backend slugs (T-05-06-01: validated against allowlist, not user-supplied string)
_CLOUD_PROVIDERS = frozenset({"google_drive", "onedrive", "nextcloud", "webdav"}) _CLOUD_PROVIDERS = frozenset({"google_drive", "onedrive", "nextcloud", "webdav"})
+4 -11
View File
@@ -30,6 +30,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from db.models import Document, Folder, Quota, Share, User from db.models import Document, Folder, Quota, Share, User
from deps.auth import get_regular_user from deps.auth import get_regular_user
from deps.db import get_db from deps.db import get_db
from deps.utils import get_client_ip
from services.audit import write_audit_log from services.audit import write_audit_log
from storage import get_storage_backend from storage import get_storage_backend
@@ -51,14 +52,6 @@ class DocumentMove(BaseModel):
folder_id: Optional[str] = None folder_id: Optional[str] = None
# ── Helper: extract IP address ────────────────────────────────────────────────
def _get_ip(request: Request) -> Optional[str]:
"""Extract client IP, honouring X-Forwarded-For for reverse proxy setups (Pitfall 5)."""
return request.headers.get("X-Forwarded-For") or (
request.client.host if request.client else None
)
# ── Helper: folder serialization ────────────────────────────────────────────── # ── Helper: folder serialization ──────────────────────────────────────────────
@@ -148,7 +141,7 @@ async def create_folder(
user_id=current_user.id, user_id=current_user.id,
actor_id=current_user.id, actor_id=current_user.id,
resource_id=folder.id, resource_id=folder.id,
ip_address=_get_ip(request), ip_address=get_client_ip(request),
metadata_={"name": folder.name, "parent_id": str(parent_uuid) if parent_uuid else None}, metadata_={"name": folder.name, "parent_id": str(parent_uuid) if parent_uuid else None},
) )
@@ -316,7 +309,7 @@ async def rename_folder(
user_id=current_user.id, user_id=current_user.id,
actor_id=current_user.id, actor_id=current_user.id,
resource_id=folder.id, resource_id=folder.id,
ip_address=_get_ip(request), ip_address=get_client_ip(request),
metadata_={"old_name": old_name, "new_name": folder.name}, metadata_={"old_name": old_name, "new_name": folder.name},
) )
@@ -436,7 +429,7 @@ async def delete_folder(
user_id=current_user.id, user_id=current_user.id,
actor_id=current_user.id, actor_id=current_user.id,
resource_id=uid, resource_id=uid,
ip_address=_get_ip(request), ip_address=get_client_ip(request),
metadata_={"name": folder_name, "doc_count": len(docs)}, metadata_={"name": folder_name, "doc_count": len(docs)},
) )
+4 -18
View File
@@ -27,6 +27,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from db.models import Document, Share, User from db.models import Document, Share, User
from deps.auth import get_regular_user from deps.auth import get_regular_user
from deps.db import get_db from deps.db import get_db
from deps.utils import get_client_ip
from services.audit import write_audit_log from services.audit import write_audit_log
router = APIRouter(prefix="/api/shares", tags=["shares"]) router = APIRouter(prefix="/api/shares", tags=["shares"])
@@ -62,21 +63,6 @@ class SharePermissionPatch(BaseModel):
# ── Helpers ─────────────────────────────────────────────────────────────────── # ── Helpers ───────────────────────────────────────────────────────────────────
def _ip(request: Request) -> Optional[str]:
"""Extract best-effort client IP from request (behind proxy or direct).
TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be
forged by any caller. This value is used for forensic audit logging only —
not for authentication or access control decisions. In production, deploy
behind a trusted reverse proxy (e.g. nginx with
`proxy_set_header X-Forwarded-For $remote_addr;`) which overwrites this
header with the real remote IP before it reaches FastAPI, or use a
trusted-proxy middleware that validates the source CIDR.
"""
return request.headers.get("X-Forwarded-For") or (
request.client.host if request.client else None
)
# ── POST /api/shares ────────────────────────────────────────────────────────── # ── POST /api/shares ──────────────────────────────────────────────────────────
@@ -141,7 +127,7 @@ async def grant_share(
user_id=current_user.id, user_id=current_user.id,
actor_id=current_user.id, actor_id=current_user.id,
resource_id=uid, resource_id=uid,
ip_address=_ip(request), ip_address=get_client_ip(request),
metadata_={"recipient_id": str(recipient.id)}, metadata_={"recipient_id": str(recipient.id)},
) )
@@ -283,7 +269,7 @@ async def update_share_permission(
user_id=current_user.id, user_id=current_user.id,
actor_id=current_user.id, actor_id=current_user.id,
resource_id=share.document_id, resource_id=share.document_id,
ip_address=_ip(request), ip_address=get_client_ip(request),
metadata_={"share_id": str(share.id), "new_permission": body.permission}, metadata_={"share_id": str(share.id), "new_permission": body.permission},
) )
await session.commit() await session.commit()
@@ -328,7 +314,7 @@ async def revoke_share(
user_id=current_user.id, user_id=current_user.id,
actor_id=current_user.id, actor_id=current_user.id,
resource_id=document_id, resource_id=document_id,
ip_address=_ip(request), ip_address=get_client_ip(request),
metadata_={"recipient_id": str(recipient_id)}, metadata_={"recipient_id": str(recipient_id)},
) )
+34
View File
@@ -0,0 +1,34 @@
"""Shared dependency utilities — request parsing helpers used across all API routers."""
from __future__ import annotations
import uuid
from typing import Optional
from fastapi import HTTPException, Request
def get_client_ip(request: Request) -> Optional[str]:
    """Extract a best-effort client IP from *request* for audit logging.

    TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be
    forged by any caller. This value is used for forensic audit logging only —
    not for authentication or access control decisions. In production, deploy
    behind a trusted reverse proxy (e.g. nginx with
    ``proxy_set_header X-Forwarded-For $remote_addr;``) which overwrites this
    header with the real remote IP before it reaches FastAPI.

    Returns:
        The first entry of the X-Forwarded-For header when it is present and
        non-empty, otherwise ``request.client.host``, or None when the
        request carries no client information.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        # The header is a comma-separated chain ("client, proxy1, proxy2");
        # only the first entry names the originating client, and returning
        # the whole chain would store a non-address string as the IP.
        first_hop = forwarded.split(",", 1)[0].strip()
        if first_hop:
            return first_hop
    return request.client.host if request.client else None
def parse_uuid(value: str, detail: str = "Not found") -> uuid.UUID:
    """Parse *value* as a UUID, raising HTTP 404 with *detail* on failure.

    Use at API boundaries to convert path/body string IDs to UUID objects.

    Args:
        value: Textual UUID in any form ``uuid.UUID`` accepts.  A value that
            already is a ``uuid.UUID`` is returned unchanged.
        detail: Message placed in the 404 response.

    Returns:
        The parsed UUID, so callers can use it directly without a try/except.

    Raises:
        HTTPException: status 404 when *value* cannot be parsed as a UUID.
    """
    if isinstance(value, uuid.UUID):
        return value
    try:
        return uuid.UUID(value)
    except (ValueError, TypeError, AttributeError):
        # uuid.UUID raises TypeError for None and AttributeError for other
        # non-string input; those must map to 404 as well, not surface as a
        # 500.  The parsing traceback is of no use to the client, so the
        # exception chain is suppressed.
        raise HTTPException(status_code=404, detail=detail) from None
+22
View File
@@ -20,6 +20,7 @@ from __future__ import annotations
import hashlib import hashlib
import hmac import hmac
import logging import logging
import re
import secrets import secrets
import uuid import uuid
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
@@ -36,6 +37,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from config import settings from config import settings
from db.models import BackupCode, Quota, RefreshToken, User from db.models import BackupCode, Quota, RefreshToken, User
_PASSWORD_DETAIL = (
"Password must be at least 12 characters and include uppercase, "
"lowercase, a number, and a special character."
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ── Password hashing ──────────────────────────────────────────────────────────── # ── Password hashing ────────────────────────────────────────────────────────────
@@ -59,6 +65,22 @@ def verify_password(plain: str, hashed: str) -> bool:
return False return False
def validate_password_strength(password: str) -> None:
    """Check *password* against the AUTH-01 strength rules.

    A password passes when it is at least 12 characters long and contains an
    uppercase ASCII letter, a lowercase ASCII letter, a digit, and at least
    one character that is neither an ASCII letter nor a digit.

    Raises:
        ValueError: carrying the shared password-policy message when any rule
            fails.  Callers at the API boundary should catch it and map it to
            HTTP 422.
    """
    required_patterns = (r"[A-Z]", r"[a-z]", r"[0-9]", r"[^A-Za-z0-9]")
    long_enough = len(password) >= 12
    if long_enough and all(
        re.search(pattern, password) for pattern in required_patterns
    ):
        return
    raise ValueError(_PASSWORD_DETAIL)
# ── JWT helpers ───────────────────────────────────────────────────────────────── # ── JWT helpers ─────────────────────────────────────────────────────────────────
def create_access_token(user_id: str, role: str) -> str: def create_access_token(user_id: str, role: str) -> str:
+19
View File
@@ -0,0 +1,19 @@
"""Storage exception types — import from here, never redefine elsewhere."""
from __future__ import annotations
class CloudConnectionError(Exception):
    """Signal a non-retryable connection problem reported by a cloud provider.

    Attributes:
        reason: "token_expired" — the access token expired; the API layer can
                    refresh it and retry.
                "invalid_grant" — the refresh token was revoked; the user
                    must reconnect.

    Backends only raise this exception and never touch the DB.  The API layer
    (_call_cloud_op in cloud.py) catches it, performs the DB state
    transition, and decides whether to retry or surface a 503 to the client
    (B2 design, D-05/D-06).
    """

    def __init__(self, msg: str = "", *, reason: str = "") -> None:
        # Machine-readable cause: "token_expired" | "invalid_grant"
        self.reason = reason
        super().__init__(msg)
+1 -17
View File
@@ -37,23 +37,7 @@ from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
from storage.base import StorageBackend from storage.base import StorageBackend
from storage.exceptions import CloudConnectionError # noqa: F401 re-exported for import compatibility
class CloudConnectionError(Exception):
"""Raised when a cloud provider signals a non-retryable connection problem.
Attributes:
reason: "token_expired" — access token expired; API layer can refresh and retry.
"invalid_grant" — refresh token revoked; user must reconnect.
The backend never updates the DB. The API layer (_call_cloud_op in cloud.py)
catches this exception, performs the DB state transition, and decides whether
to retry or surface a 503 to the client (B2 design, D-05/D-06).
"""
def __init__(self, msg: str = "", *, reason: str = "") -> None:
super().__init__(msg)
self.reason = reason # "token_expired" | "invalid_grant"
class GoogleDriveBackend(StorageBackend): class GoogleDriveBackend(StorageBackend):
+1 -6
View File
@@ -76,12 +76,7 @@ async def _run(document_id: str) -> dict:
if user is None: if user is None:
return {"document_id": document_id, "status": "missing_user"} return {"document_id": document_id, "status": "missing_user"}
try: from storage.exceptions import CloudConnectionError
from storage.google_drive_backend import CloudConnectionError
except ImportError:
class CloudConnectionError(Exception): # type: ignore[no-redef]
pass
try: try:
backend = await get_storage_backend_for_document(doc, user, session) backend = await get_storage_backend_for_document(doc, user, session)
file_bytes = await backend.get_object(doc.object_key) file_bytes = await backend.get_object(doc.object_key)
+1 -1
View File
@@ -4,7 +4,7 @@ Uses a mock provider — no real AI calls made.
""" """
import json import json
import pytest import pytest
from ai.openai_provider import _parse_classification, _parse_suggestions, _strip_code_fences from ai.utils import parse_classification as _parse_classification, parse_suggestions as _parse_suggestions, strip_code_fences as _strip_code_fences
from ai.base import ClassificationResult from ai.base import ClassificationResult