refactor(backend): extract shared helper modules per architecture rules
- Add backend/ai/utils.py — parse_classification, parse_suggestions, strip_code_fences shared by all AI providers; removes duplicated private functions from anthropic_provider.py and openai_provider.py - Add backend/deps/utils.py — get_client_ip, parse_uuid request-parsing helpers; removes local _ip() variants from admin.py, auth.py, shares.py, folders.py - Add backend/storage/exceptions.py — canonical CloudConnectionError definition; all routers and backends import from here instead of redefining - Move validate_password_strength to backend/services/auth.py; removes duplicated _validate_password_strength from admin.py and auth.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import re
|
||||
import anthropic
|
||||
from ai.base import AIProvider, ClassificationResult
|
||||
from ai.utils import parse_classification, parse_suggestions
|
||||
|
||||
MAX_AI_CHARS = 8_000
|
||||
|
||||
@@ -33,7 +32,7 @@ class AnthropicProvider(AIProvider):
|
||||
messages=[{"role": "user", "content": user_msg}],
|
||||
)
|
||||
raw = response.content[0].text
|
||||
return _parse_classification(raw)
|
||||
return parse_classification(raw)
|
||||
|
||||
async def suggest_topics(
|
||||
self,
|
||||
@@ -53,7 +52,7 @@ class AnthropicProvider(AIProvider):
|
||||
messages=[{"role": "user", "content": user_msg}],
|
||||
)
|
||||
raw = response.content[0].text
|
||||
return _parse_suggestions(raw)
|
||||
return parse_suggestions(raw)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
@@ -68,36 +67,3 @@ class AnthropicProvider(AIProvider):
|
||||
return False
|
||||
|
||||
|
||||
def _strip_code_fences(text: str) -> str:
|
||||
text = re.sub(r"```(?:json)?\s*", "", text)
|
||||
text = re.sub(r"```", "", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _parse_classification(raw: str) -> ClassificationResult:
|
||||
raw = _strip_code_fences(raw)
|
||||
# Try to find JSON object
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return ClassificationResult(
|
||||
topics=data.get("assigned_topics", []),
|
||||
suggested_new_topics=data.get("new_topic_suggestions", []),
|
||||
reasoning=data.get("reasoning", ""),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return ClassificationResult()
|
||||
|
||||
|
||||
def _parse_suggestions(raw: str) -> list[str]:
|
||||
raw = _strip_code_fences(raw)
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return data.get("suggested_topics", [])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return []
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import re
|
||||
from openai import AsyncOpenAI
|
||||
from ai.base import AIProvider, ClassificationResult
|
||||
from ai.utils import parse_classification, parse_suggestions
|
||||
|
||||
MAX_AI_CHARS = 8_000
|
||||
|
||||
@@ -35,7 +34,7 @@ class OpenAIProvider(AIProvider):
|
||||
],
|
||||
)
|
||||
raw = response.choices[0].message.content or ""
|
||||
return _parse_classification(raw)
|
||||
return parse_classification(raw)
|
||||
|
||||
async def suggest_topics(
|
||||
self,
|
||||
@@ -56,7 +55,7 @@ class OpenAIProvider(AIProvider):
|
||||
],
|
||||
)
|
||||
raw = response.choices[0].message.content or ""
|
||||
return _parse_suggestions(raw)
|
||||
return parse_suggestions(raw)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
@@ -70,35 +69,3 @@ class OpenAIProvider(AIProvider):
|
||||
return False
|
||||
|
||||
|
||||
def _strip_code_fences(text: str) -> str:
|
||||
text = re.sub(r"```(?:json)?\s*", "", text)
|
||||
text = re.sub(r"```", "", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _parse_classification(raw: str) -> ClassificationResult:
|
||||
raw = _strip_code_fences(raw)
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return ClassificationResult(
|
||||
topics=data.get("assigned_topics", []),
|
||||
suggested_new_topics=data.get("new_topic_suggestions", []),
|
||||
reasoning=data.get("reasoning", ""),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return ClassificationResult()
|
||||
|
||||
|
||||
def _parse_suggestions(raw: str) -> list[str]:
|
||||
raw = _strip_code_fences(raw)
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return data.get("suggested_topics", [])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return []
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Shared AI response parsing utilities — used by all provider implementations."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ai.base import ClassificationResult
|
||||
|
||||
|
||||
def strip_code_fences(text: str) -> str:
|
||||
"""Remove markdown code fences (```json ... ```) from *text*."""
|
||||
text = re.sub(r"```(?:json)?\s*", "", text)
|
||||
text = re.sub(r"```", "", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def parse_classification(raw: str) -> ClassificationResult:
|
||||
"""Parse a classification JSON response into a ClassificationResult.
|
||||
|
||||
Tolerates markdown code fences and extracts the first JSON object found.
|
||||
Returns an empty ClassificationResult on any parse failure.
|
||||
"""
|
||||
raw = strip_code_fences(raw)
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return ClassificationResult(
|
||||
topics=data.get("assigned_topics", []),
|
||||
suggested_new_topics=data.get("new_topic_suggestions", []),
|
||||
reasoning=data.get("reasoning", ""),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return ClassificationResult()
|
||||
|
||||
|
||||
def parse_suggestions(raw: str) -> list[str]:
|
||||
"""Parse a topic-suggestion JSON response into a list of topic name strings.
|
||||
|
||||
Tolerates markdown code fences. Returns an empty list on parse failure.
|
||||
"""
|
||||
raw = strip_code_fences(raw)
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return data.get("suggested_topics", [])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return []
|
||||
+8
-52
@@ -23,7 +23,6 @@ Security invariants:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
@@ -36,8 +35,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from db.models import CloudConnection, Document, Quota, RefreshToken, Topic, User
|
||||
from deps.auth import get_current_admin
|
||||
from deps.db import get_db
|
||||
from deps.utils import get_client_ip
|
||||
from services.audit import write_audit_log
|
||||
from services.auth import hash_password, revoke_all_refresh_tokens, verify_password
|
||||
from services.auth import hash_password, revoke_all_refresh_tokens, validate_password_strength, verify_password
|
||||
from storage import get_storage_backend, get_storage_backend_for_document
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||
@@ -46,28 +46,6 @@ router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||
|
||||
_DEFAULT_QUOTA_BYTES = 104857600 # 100 MB free-tier default (D-06)
|
||||
|
||||
_PASSWORD_DETAIL = (
|
||||
"Password must be at least 12 characters and include uppercase, "
|
||||
"lowercase, a number, and a special character."
|
||||
)
|
||||
|
||||
|
||||
# ── IP extraction helper ──────────────────────────────────────────────────────
|
||||
|
||||
def _ip(request: Request) -> Optional[str]:
|
||||
"""Extract best-effort client IP from request for audit logging.
|
||||
|
||||
TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be
|
||||
forged by any caller. This value is used for forensic audit logging only —
|
||||
not for authentication or access control decisions. In production, deploy
|
||||
behind a trusted reverse proxy (e.g. nginx with
|
||||
`proxy_set_header X-Forwarded-For $remote_addr;`) which overwrites this
|
||||
header with the real remote IP before it reaches FastAPI, or use a
|
||||
trusted-proxy middleware that validates the source CIDR.
|
||||
"""
|
||||
return request.headers.get("X-Forwarded-For") or (
|
||||
request.client.host if request.client else None
|
||||
)
|
||||
|
||||
|
||||
# ── Safe response helper ──────────────────────────────────────────────────────
|
||||
@@ -90,25 +68,6 @@ def _user_to_dict(user: User) -> dict:
|
||||
}
|
||||
|
||||
|
||||
# ── Password strength helper ──────────────────────────────────────────────────
|
||||
|
||||
def _validate_password_strength(password: str) -> None:
|
||||
"""Raise ValueError with the spec message if password fails any strength rule.
|
||||
|
||||
Rules (AUTH-01): min 12 chars, has uppercase, has lowercase, has digit,
|
||||
has special char (non-alphanumeric).
|
||||
"""
|
||||
if len(password) < 12:
|
||||
raise ValueError(_PASSWORD_DETAIL)
|
||||
if not re.search(r"[A-Z]", password):
|
||||
raise ValueError(_PASSWORD_DETAIL)
|
||||
if not re.search(r"[a-z]", password):
|
||||
raise ValueError(_PASSWORD_DETAIL)
|
||||
if not re.search(r"[0-9]", password):
|
||||
raise ValueError(_PASSWORD_DETAIL)
|
||||
if not re.search(r"[^A-Za-z0-9]", password):
|
||||
raise ValueError(_PASSWORD_DETAIL)
|
||||
|
||||
|
||||
# ── Request models ────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -121,10 +80,7 @@ class UserCreate(BaseModel):
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
try:
|
||||
_validate_password_strength(v)
|
||||
except ValueError as exc:
|
||||
raise ValueError(str(exc)) from exc
|
||||
validate_password_strength(v)
|
||||
return v
|
||||
|
||||
|
||||
@@ -264,7 +220,7 @@ async def create_user(
|
||||
session.add(quota)
|
||||
await session.flush() # persist User + Quota before audit_log FK references them
|
||||
# D-13: admin user created event
|
||||
_ip_addr = _ip(request)
|
||||
_ip_addr = get_client_ip(request)
|
||||
await write_audit_log(
|
||||
session,
|
||||
event_type="admin.user_created",
|
||||
@@ -316,7 +272,7 @@ async def update_user_status(
|
||||
detail="Cannot deactivate the only admin",
|
||||
)
|
||||
|
||||
_ip_addr = _ip(request)
|
||||
_ip_addr = get_client_ip(request)
|
||||
user.is_active = body.is_active
|
||||
|
||||
if not body.is_active:
|
||||
@@ -426,7 +382,7 @@ async def update_user_quota(
|
||||
else None
|
||||
)
|
||||
|
||||
_ip_addr = _ip(request)
|
||||
_ip_addr = get_client_ip(request)
|
||||
old_limit = quota.limit_bytes
|
||||
quota.limit_bytes = body.limit_bytes
|
||||
session.add(quota)
|
||||
@@ -471,7 +427,7 @@ async def update_ai_config(
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
_ip_addr = _ip(request)
|
||||
_ip_addr = get_client_ip(request)
|
||||
user.ai_provider = body.ai_provider
|
||||
user.ai_model = body.ai_model
|
||||
session.add(user)
|
||||
@@ -532,7 +488,7 @@ async def delete_user(
|
||||
detail="Cannot delete admin accounts",
|
||||
)
|
||||
|
||||
_ip_addr = _ip(request)
|
||||
_ip_addr = get_client_ip(request)
|
||||
|
||||
# SEC-09 (cloud): purge cloud-stored documents and credentials BEFORE DB delete.
|
||||
# Must run before MinIO cleanup so that credentials are still available to build
|
||||
|
||||
+19
-46
@@ -19,7 +19,6 @@ Security invariants:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import Literal, Optional
|
||||
|
||||
@@ -32,6 +31,7 @@ from config import settings
|
||||
from db.models import BackupCode, Quota, RefreshToken, User
|
||||
from deps.auth import get_current_user
|
||||
from deps.db import get_db
|
||||
from deps.utils import get_client_ip
|
||||
from services import auth as auth_service
|
||||
from services.audit import write_audit_log
|
||||
from slowapi import Limiter
|
||||
@@ -43,30 +43,6 @@ router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
# IP-level rate limiter (SEC-02 — 10 req/min on register/login/refresh)
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# ── Password strength validation ─────────────────────────────────────────────
|
||||
_PASSWORD_DETAIL = (
|
||||
"Password must be at least 12 characters and include uppercase, "
|
||||
"lowercase, a number, and a special character."
|
||||
)
|
||||
|
||||
|
||||
def _validate_password_strength(password: str) -> bool:
|
||||
"""Return True if password passes all strength rules (AUTH-01).
|
||||
|
||||
Rules: min 12 chars, has uppercase, has lowercase, has digit, has special char.
|
||||
"""
|
||||
if len(password) < 12:
|
||||
return False
|
||||
if not re.search(r"[A-Z]", password):
|
||||
return False
|
||||
if not re.search(r"[a-z]", password):
|
||||
return False
|
||||
if not re.search(r"[0-9]", password):
|
||||
return False
|
||||
if not re.search(r"[^A-Za-z0-9]", password):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# ── Request models ────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -132,11 +108,10 @@ async def register(
|
||||
- Inserts User + Quota rows in a single transaction
|
||||
"""
|
||||
# Password strength check
|
||||
if not _validate_password_strength(body.password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=_PASSWORD_DETAIL,
|
||||
)
|
||||
try:
|
||||
auth_service.validate_password_strength(body.password)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc))
|
||||
|
||||
# HIBP breach check
|
||||
if await auth_service.check_hibp(body.password):
|
||||
@@ -228,7 +203,7 @@ async def login(
|
||||
user: Optional[User] = result.scalar_one_or_none()
|
||||
|
||||
# IP extraction for audit log (used in both success and failure paths)
|
||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
||||
_ip = get_client_ip(request)
|
||||
|
||||
# Verify password (anti-enumeration: same error regardless of whether user exists)
|
||||
if user is None or not auth_service.verify_password(body.password, user.password_hash):
|
||||
@@ -385,7 +360,7 @@ async def logout(request: Request, response: Response, session: AsyncSession = D
|
||||
"""Revoke current refresh token and clear the cookie."""
|
||||
import hashlib as _hashlib
|
||||
|
||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
||||
_ip = get_client_ip(request)
|
||||
|
||||
raw_token = request.cookies.get("refresh_token")
|
||||
_logout_user_id = None
|
||||
@@ -423,7 +398,7 @@ async def logout_all(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Sign out of all devices: revoke all refresh tokens for current user."""
|
||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
||||
_ip = get_client_ip(request)
|
||||
count = await auth_service.revoke_all_refresh_tokens(session, current_user.id)
|
||||
# D-13: sign-out-all event
|
||||
await write_audit_log(
|
||||
@@ -497,14 +472,13 @@ async def change_password(
|
||||
)
|
||||
|
||||
# Password strength check
|
||||
if not _validate_password_strength(body.new_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=_PASSWORD_DETAIL,
|
||||
)
|
||||
try:
|
||||
auth_service.validate_password_strength(body.new_password)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc))
|
||||
|
||||
# Update password
|
||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
||||
_ip = get_client_ip(request)
|
||||
user = await session.get(User, current_user.id)
|
||||
user.password_hash = auth_service.hash_password(body.new_password)
|
||||
# D-13: password changed event (flush within same transaction before commit)
|
||||
@@ -594,7 +568,7 @@ async def enable_totp(
|
||||
await auth_service.store_backup_codes(session, current_user.id, plain_codes)
|
||||
|
||||
# D-13: TOTP enrolled event
|
||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
||||
_ip = get_client_ip(request)
|
||||
await write_audit_log(
|
||||
session,
|
||||
event_type="auth.totp_enrolled",
|
||||
@@ -620,7 +594,7 @@ async def disable_totp(
|
||||
|
||||
Clears totp_secret, sets totp_enabled=False, and deletes all backup codes.
|
||||
"""
|
||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
||||
_ip = get_client_ip(request)
|
||||
user = await session.get(User, current_user.id)
|
||||
user.totp_enabled = False
|
||||
user.totp_secret = None
|
||||
@@ -699,11 +673,10 @@ async def password_reset_confirm(
|
||||
)
|
||||
|
||||
# Password strength validation
|
||||
if not _validate_password_strength(body.new_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=_PASSWORD_DETAIL,
|
||||
)
|
||||
try:
|
||||
auth_service.validate_password_strength(body.new_password)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc))
|
||||
|
||||
# HIBP breach check (SEC-03)
|
||||
if await auth_service.check_hibp(body.new_password):
|
||||
|
||||
@@ -48,14 +48,7 @@ except ImportError:
|
||||
# Fallback for test environments where minio is not installed
|
||||
S3Error = Exception # type: ignore[assignment,misc]
|
||||
|
||||
try:
|
||||
from storage.google_drive_backend import CloudConnectionError
|
||||
except ImportError:
|
||||
# Fallback: define a stub so the except clause compiles even if google deps absent
|
||||
class CloudConnectionError(Exception): # type: ignore[no-redef]
|
||||
def __init__(self, msg: str = "", *, reason: str = "") -> None:
|
||||
super().__init__(msg)
|
||||
self.reason = reason
|
||||
from storage.exceptions import CloudConnectionError
|
||||
|
||||
# Valid cloud backend slugs (T-05-06-01: validated against allowlist, not user-supplied string)
|
||||
_CLOUD_PROVIDERS = frozenset({"google_drive", "onedrive", "nextcloud", "webdav"})
|
||||
|
||||
+4
-11
@@ -30,6 +30,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from db.models import Document, Folder, Quota, Share, User
|
||||
from deps.auth import get_regular_user
|
||||
from deps.db import get_db
|
||||
from deps.utils import get_client_ip
|
||||
from services.audit import write_audit_log
|
||||
from storage import get_storage_backend
|
||||
|
||||
@@ -51,14 +52,6 @@ class DocumentMove(BaseModel):
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
# ── Helper: extract IP address ────────────────────────────────────────────────
|
||||
|
||||
def _get_ip(request: Request) -> Optional[str]:
|
||||
"""Extract client IP, honouring X-Forwarded-For for reverse proxy setups (Pitfall 5)."""
|
||||
return request.headers.get("X-Forwarded-For") or (
|
||||
request.client.host if request.client else None
|
||||
)
|
||||
|
||||
|
||||
# ── Helper: folder serialization ──────────────────────────────────────────────
|
||||
|
||||
@@ -148,7 +141,7 @@ async def create_folder(
|
||||
user_id=current_user.id,
|
||||
actor_id=current_user.id,
|
||||
resource_id=folder.id,
|
||||
ip_address=_get_ip(request),
|
||||
ip_address=get_client_ip(request),
|
||||
metadata_={"name": folder.name, "parent_id": str(parent_uuid) if parent_uuid else None},
|
||||
)
|
||||
|
||||
@@ -316,7 +309,7 @@ async def rename_folder(
|
||||
user_id=current_user.id,
|
||||
actor_id=current_user.id,
|
||||
resource_id=folder.id,
|
||||
ip_address=_get_ip(request),
|
||||
ip_address=get_client_ip(request),
|
||||
metadata_={"old_name": old_name, "new_name": folder.name},
|
||||
)
|
||||
|
||||
@@ -436,7 +429,7 @@ async def delete_folder(
|
||||
user_id=current_user.id,
|
||||
actor_id=current_user.id,
|
||||
resource_id=uid,
|
||||
ip_address=_get_ip(request),
|
||||
ip_address=get_client_ip(request),
|
||||
metadata_={"name": folder_name, "doc_count": len(docs)},
|
||||
)
|
||||
|
||||
|
||||
+4
-18
@@ -27,6 +27,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from db.models import Document, Share, User
|
||||
from deps.auth import get_regular_user
|
||||
from deps.db import get_db
|
||||
from deps.utils import get_client_ip
|
||||
from services.audit import write_audit_log
|
||||
|
||||
router = APIRouter(prefix="/api/shares", tags=["shares"])
|
||||
@@ -62,21 +63,6 @@ class SharePermissionPatch(BaseModel):
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _ip(request: Request) -> Optional[str]:
|
||||
"""Extract best-effort client IP from request (behind proxy or direct).
|
||||
|
||||
TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be
|
||||
forged by any caller. This value is used for forensic audit logging only —
|
||||
not for authentication or access control decisions. In production, deploy
|
||||
behind a trusted reverse proxy (e.g. nginx with
|
||||
`proxy_set_header X-Forwarded-For $remote_addr;`) which overwrites this
|
||||
header with the real remote IP before it reaches FastAPI, or use a
|
||||
trusted-proxy middleware that validates the source CIDR.
|
||||
"""
|
||||
return request.headers.get("X-Forwarded-For") or (
|
||||
request.client.host if request.client else None
|
||||
)
|
||||
|
||||
|
||||
# ── POST /api/shares ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -141,7 +127,7 @@ async def grant_share(
|
||||
user_id=current_user.id,
|
||||
actor_id=current_user.id,
|
||||
resource_id=uid,
|
||||
ip_address=_ip(request),
|
||||
ip_address=get_client_ip(request),
|
||||
metadata_={"recipient_id": str(recipient.id)},
|
||||
)
|
||||
|
||||
@@ -283,7 +269,7 @@ async def update_share_permission(
|
||||
user_id=current_user.id,
|
||||
actor_id=current_user.id,
|
||||
resource_id=share.document_id,
|
||||
ip_address=_ip(request),
|
||||
ip_address=get_client_ip(request),
|
||||
metadata_={"share_id": str(share.id), "new_permission": body.permission},
|
||||
)
|
||||
await session.commit()
|
||||
@@ -328,7 +314,7 @@ async def revoke_share(
|
||||
user_id=current_user.id,
|
||||
actor_id=current_user.id,
|
||||
resource_id=document_id,
|
||||
ip_address=_ip(request),
|
||||
ip_address=get_client_ip(request),
|
||||
metadata_={"recipient_id": str(recipient_id)},
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Shared dependency utilities — request parsing helpers used across all API routers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> Optional[str]:
|
||||
"""Extract best-effort client IP from request for audit logging.
|
||||
|
||||
TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be
|
||||
forged by any caller. This value is used for forensic audit logging only —
|
||||
not for authentication or access control decisions. In production, deploy
|
||||
behind a trusted reverse proxy (e.g. nginx with
|
||||
``proxy_set_header X-Forwarded-For $remote_addr;``) which overwrites this
|
||||
header with the real remote IP before it reaches FastAPI.
|
||||
"""
|
||||
return request.headers.get("X-Forwarded-For") or (
|
||||
request.client.host if request.client else None
|
||||
)
|
||||
|
||||
|
||||
def parse_uuid(value: str, detail: str = "Not found") -> uuid.UUID:
|
||||
"""Parse *value* as a UUID, raising HTTP 404 with *detail* on failure.
|
||||
|
||||
Use at API boundaries to convert path/body string IDs to UUID objects.
|
||||
Returns the parsed UUID so callers can use it directly without a try/except.
|
||||
"""
|
||||
try:
|
||||
return uuid.UUID(value)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail=detail)
|
||||
@@ -20,6 +20,7 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
@@ -36,6 +37,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from config import settings
|
||||
from db.models import BackupCode, Quota, RefreshToken, User
|
||||
|
||||
_PASSWORD_DETAIL = (
|
||||
"Password must be at least 12 characters and include uppercase, "
|
||||
"lowercase, a number, and a special character."
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Password hashing ────────────────────────────────────────────────────────────
|
||||
@@ -59,6 +65,22 @@ def verify_password(plain: str, hashed: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def validate_password_strength(password: str) -> None:
|
||||
"""Raise ValueError with a descriptive message if *password* fails any strength rule.
|
||||
|
||||
Rules (AUTH-01): min 12 chars, uppercase, lowercase, digit, special char.
|
||||
Callers at the API boundary should catch ValueError and map it to HTTP 422.
|
||||
"""
|
||||
if (
|
||||
len(password) < 12
|
||||
or not re.search(r"[A-Z]", password)
|
||||
or not re.search(r"[a-z]", password)
|
||||
or not re.search(r"[0-9]", password)
|
||||
or not re.search(r"[^A-Za-z0-9]", password)
|
||||
):
|
||||
raise ValueError(_PASSWORD_DETAIL)
|
||||
|
||||
|
||||
# ── JWT helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def create_access_token(user_id: str, role: str) -> str:
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Storage exception types — import from here, never redefine elsewhere."""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class CloudConnectionError(Exception):
|
||||
"""Raised when a cloud provider signals a non-retryable connection problem.
|
||||
|
||||
Attributes:
|
||||
reason: "token_expired" — access token expired; API layer can refresh and retry.
|
||||
"invalid_grant" — refresh token revoked; user must reconnect.
|
||||
|
||||
The backend never updates the DB. The API layer (_call_cloud_op in cloud.py)
|
||||
catches this exception, performs the DB state transition, and decides whether
|
||||
to retry or surface a 503 to the client (B2 design, D-05/D-06).
|
||||
"""
|
||||
|
||||
def __init__(self, msg: str = "", *, reason: str = "") -> None:
|
||||
super().__init__(msg)
|
||||
self.reason = reason # "token_expired" | "invalid_grant"
|
||||
@@ -37,23 +37,7 @@ from googleapiclient.errors import HttpError
|
||||
from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
|
||||
|
||||
from storage.base import StorageBackend
|
||||
|
||||
|
||||
class CloudConnectionError(Exception):
|
||||
"""Raised when a cloud provider signals a non-retryable connection problem.
|
||||
|
||||
Attributes:
|
||||
reason: "token_expired" — access token expired; API layer can refresh and retry.
|
||||
"invalid_grant" — refresh token revoked; user must reconnect.
|
||||
|
||||
The backend never updates the DB. The API layer (_call_cloud_op in cloud.py)
|
||||
catches this exception, performs the DB state transition, and decides whether
|
||||
to retry or surface a 503 to the client (B2 design, D-05/D-06).
|
||||
"""
|
||||
|
||||
def __init__(self, msg: str = "", *, reason: str = "") -> None:
|
||||
super().__init__(msg)
|
||||
self.reason = reason # "token_expired" | "invalid_grant"
|
||||
from storage.exceptions import CloudConnectionError # noqa: F401 re-exported for import compatibility
|
||||
|
||||
|
||||
class GoogleDriveBackend(StorageBackend):
|
||||
|
||||
@@ -76,12 +76,7 @@ async def _run(document_id: str) -> dict:
|
||||
if user is None:
|
||||
return {"document_id": document_id, "status": "missing_user"}
|
||||
|
||||
try:
|
||||
from storage.google_drive_backend import CloudConnectionError
|
||||
except ImportError:
|
||||
class CloudConnectionError(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
from storage.exceptions import CloudConnectionError
|
||||
try:
|
||||
backend = await get_storage_backend_for_document(doc, user, session)
|
||||
file_bytes = await backend.get_object(doc.object_key)
|
||||
|
||||
@@ -4,7 +4,7 @@ Uses a mock provider — no real AI calls made.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
from ai.openai_provider import _parse_classification, _parse_suggestions, _strip_code_fences
|
||||
from ai.utils import parse_classification as _parse_classification, parse_suggestions as _parse_suggestions, strip_code_fences as _strip_code_fences
|
||||
from ai.base import ClassificationResult
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user