refactor(backend): extract shared helper modules per architecture rules
- Add backend/ai/utils.py — parse_classification, parse_suggestions, strip_code_fences shared by all AI providers; removes duplicated private functions from anthropic_provider.py and openai_provider.py
- Add backend/deps/utils.py — get_client_ip, parse_uuid request-parsing helpers; removes local _ip() variants from admin.py, auth.py, shares.py, folders.py
- Add backend/storage/exceptions.py — canonical CloudConnectionError definition; all routers and backends import from here instead of redefining
- Move validate_password_strength to backend/services/auth.py; removes duplicated _validate_password_strength from admin.py and auth.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
import json
|
|
||||||
import re
|
|
||||||
import anthropic
|
import anthropic
|
||||||
from ai.base import AIProvider, ClassificationResult
|
from ai.base import AIProvider, ClassificationResult
|
||||||
|
from ai.utils import parse_classification, parse_suggestions
|
||||||
|
|
||||||
MAX_AI_CHARS = 8_000
|
MAX_AI_CHARS = 8_000
|
||||||
|
|
||||||
@@ -33,7 +32,7 @@ class AnthropicProvider(AIProvider):
|
|||||||
messages=[{"role": "user", "content": user_msg}],
|
messages=[{"role": "user", "content": user_msg}],
|
||||||
)
|
)
|
||||||
raw = response.content[0].text
|
raw = response.content[0].text
|
||||||
return _parse_classification(raw)
|
return parse_classification(raw)
|
||||||
|
|
||||||
async def suggest_topics(
|
async def suggest_topics(
|
||||||
self,
|
self,
|
||||||
@@ -53,7 +52,7 @@ class AnthropicProvider(AIProvider):
|
|||||||
messages=[{"role": "user", "content": user_msg}],
|
messages=[{"role": "user", "content": user_msg}],
|
||||||
)
|
)
|
||||||
raw = response.content[0].text
|
raw = response.content[0].text
|
||||||
return _parse_suggestions(raw)
|
return parse_suggestions(raw)
|
||||||
|
|
||||||
async def health_check(self) -> bool:
|
async def health_check(self) -> bool:
|
||||||
try:
|
try:
|
||||||
@@ -68,36 +67,3 @@ class AnthropicProvider(AIProvider):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _strip_code_fences(text: str) -> str:
|
|
||||||
text = re.sub(r"```(?:json)?\s*", "", text)
|
|
||||||
text = re.sub(r"```", "", text)
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_classification(raw: str) -> ClassificationResult:
|
|
||||||
raw = _strip_code_fences(raw)
|
|
||||||
# Try to find JSON object
|
|
||||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
|
||||||
if match:
|
|
||||||
try:
|
|
||||||
data = json.loads(match.group())
|
|
||||||
return ClassificationResult(
|
|
||||||
topics=data.get("assigned_topics", []),
|
|
||||||
suggested_new_topics=data.get("new_topic_suggestions", []),
|
|
||||||
reasoning=data.get("reasoning", ""),
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
return ClassificationResult()
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_suggestions(raw: str) -> list[str]:
|
|
||||||
raw = _strip_code_fences(raw)
|
|
||||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
|
||||||
if match:
|
|
||||||
try:
|
|
||||||
data = json.loads(match.group())
|
|
||||||
return data.get("suggested_topics", [])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
return []
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import json
|
|
||||||
import re
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from ai.base import AIProvider, ClassificationResult
|
from ai.base import AIProvider, ClassificationResult
|
||||||
|
from ai.utils import parse_classification, parse_suggestions
|
||||||
|
|
||||||
MAX_AI_CHARS = 8_000
|
MAX_AI_CHARS = 8_000
|
||||||
|
|
||||||
@@ -35,7 +34,7 @@ class OpenAIProvider(AIProvider):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
raw = response.choices[0].message.content or ""
|
raw = response.choices[0].message.content or ""
|
||||||
return _parse_classification(raw)
|
return parse_classification(raw)
|
||||||
|
|
||||||
async def suggest_topics(
|
async def suggest_topics(
|
||||||
self,
|
self,
|
||||||
@@ -56,7 +55,7 @@ class OpenAIProvider(AIProvider):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
raw = response.choices[0].message.content or ""
|
raw = response.choices[0].message.content or ""
|
||||||
return _parse_suggestions(raw)
|
return parse_suggestions(raw)
|
||||||
|
|
||||||
async def health_check(self) -> bool:
|
async def health_check(self) -> bool:
|
||||||
try:
|
try:
|
||||||
@@ -70,35 +69,3 @@ class OpenAIProvider(AIProvider):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _strip_code_fences(text: str) -> str:
|
|
||||||
text = re.sub(r"```(?:json)?\s*", "", text)
|
|
||||||
text = re.sub(r"```", "", text)
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_classification(raw: str) -> ClassificationResult:
|
|
||||||
raw = _strip_code_fences(raw)
|
|
||||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
|
||||||
if match:
|
|
||||||
try:
|
|
||||||
data = json.loads(match.group())
|
|
||||||
return ClassificationResult(
|
|
||||||
topics=data.get("assigned_topics", []),
|
|
||||||
suggested_new_topics=data.get("new_topic_suggestions", []),
|
|
||||||
reasoning=data.get("reasoning", ""),
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
return ClassificationResult()
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_suggestions(raw: str) -> list[str]:
|
|
||||||
raw = _strip_code_fences(raw)
|
|
||||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
|
||||||
if match:
|
|
||||||
try:
|
|
||||||
data = json.loads(match.group())
|
|
||||||
return data.get("suggested_topics", [])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
return []
|
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
"""Shared AI response parsing utilities — used by all provider implementations."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from ai.base import ClassificationResult
|
||||||
|
|
||||||
|
|
||||||
|
def strip_code_fences(text: str) -> str:
    """Remove markdown code-fence markers from *text* and trim the ends.

    Every run of three backticks is deleted, together with an optional
    ``json`` language tag (matched in any letter case) and the whitespace
    that follows it. The content between the fences is left untouched.

    Args:
        text: Raw model output that may wrap its payload in code fences.

    Returns:
        The text without fence markers, stripped of leading and trailing
        whitespace.
    """
    # One pass is enough: the pattern matches opening and closing fences
    # alike, so no triple-backtick can survive the substitution.
    # IGNORECASE keeps a "```JSON" opener from leaving a stray tag behind.
    without_fences = re.sub(r"```(?:json)?\s*", "", text, flags=re.IGNORECASE)
    return without_fences.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_classification(raw: str) -> ClassificationResult:
    """Parse a classification JSON response into a ClassificationResult.

    Tolerates markdown code fences and surrounding prose: the text is scanned
    for the first ``{`` from which a complete JSON object can be decoded, and
    anything after that object is ignored.

    Args:
        raw: Raw text returned by the AI provider.

    Returns:
        A ClassificationResult built from the ``assigned_topics``,
        ``new_topic_suggestions`` and ``reasoning`` keys of the decoded
        object. Missing or ``null`` values fall back to an empty list or an
        empty string. An empty ClassificationResult is returned when no JSON
        object can be decoded.
    """
    text = strip_code_fences(raw)
    decoder = json.JSONDecoder()

    # raw_decode parses exactly one JSON value starting at the given index,
    # so braces in trailing prose can no longer break the parse the way a
    # greedy "first { to last }" match did.
    start = text.find("{")
    while start != -1:
        try:
            data, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            data = None
        if isinstance(data, dict):
            return ClassificationResult(
                # "or" also replaces an explicit JSON null with the default.
                topics=data.get("assigned_topics") or [],
                suggested_new_topics=data.get("new_topic_suggestions") or [],
                reasoning=data.get("reasoning") or "",
            )
        start = text.find("{", start + 1)

    return ClassificationResult()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_suggestions(raw: str) -> list[str]:
    """Parse a topic-suggestion JSON response into a list of topic names.

    Tolerates markdown code fences and surrounding prose: the text is scanned
    for the first ``{`` from which a complete JSON object can be decoded, and
    anything after that object is ignored.

    Args:
        raw: Raw text returned by the AI provider.

    Returns:
        The value of the ``suggested_topics`` key when it is a list. An empty
        list is returned when the key is missing, is not a list, or no JSON
        object can be decoded.
    """
    text = strip_code_fences(raw)
    decoder = json.JSONDecoder()

    # raw_decode parses exactly one JSON value starting at the given index,
    # so braces in trailing prose can no longer break the parse the way a
    # greedy "first { to last }" match did.
    start = text.find("{")
    while start != -1:
        try:
            data, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            data = None
        if isinstance(data, dict):
            topics = data.get("suggested_topics", [])
            # Guard the declared return type against null / scalar values.
            return topics if isinstance(topics, list) else []
        start = text.find("{", start + 1)

    return []
|
||||||
+8
-52
@@ -23,7 +23,6 @@ Security invariants:
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -36,8 +35,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from db.models import CloudConnection, Document, Quota, RefreshToken, Topic, User
|
from db.models import CloudConnection, Document, Quota, RefreshToken, Topic, User
|
||||||
from deps.auth import get_current_admin
|
from deps.auth import get_current_admin
|
||||||
from deps.db import get_db
|
from deps.db import get_db
|
||||||
|
from deps.utils import get_client_ip
|
||||||
from services.audit import write_audit_log
|
from services.audit import write_audit_log
|
||||||
from services.auth import hash_password, revoke_all_refresh_tokens, verify_password
|
from services.auth import hash_password, revoke_all_refresh_tokens, validate_password_strength, verify_password
|
||||||
from storage import get_storage_backend, get_storage_backend_for_document
|
from storage import get_storage_backend, get_storage_backend_for_document
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/admin", tags=["admin"])
|
router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||||
@@ -46,28 +46,6 @@ router = APIRouter(prefix="/api/admin", tags=["admin"])
|
|||||||
|
|
||||||
_DEFAULT_QUOTA_BYTES = 104857600 # 100 MB free-tier default (D-06)
|
_DEFAULT_QUOTA_BYTES = 104857600 # 100 MB free-tier default (D-06)
|
||||||
|
|
||||||
_PASSWORD_DETAIL = (
|
|
||||||
"Password must be at least 12 characters and include uppercase, "
|
|
||||||
"lowercase, a number, and a special character."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── IP extraction helper ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _ip(request: Request) -> Optional[str]:
|
|
||||||
"""Extract best-effort client IP from request for audit logging.
|
|
||||||
|
|
||||||
TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be
|
|
||||||
forged by any caller. This value is used for forensic audit logging only —
|
|
||||||
not for authentication or access control decisions. In production, deploy
|
|
||||||
behind a trusted reverse proxy (e.g. nginx with
|
|
||||||
`proxy_set_header X-Forwarded-For $remote_addr;`) which overwrites this
|
|
||||||
header with the real remote IP before it reaches FastAPI, or use a
|
|
||||||
trusted-proxy middleware that validates the source CIDR.
|
|
||||||
"""
|
|
||||||
return request.headers.get("X-Forwarded-For") or (
|
|
||||||
request.client.host if request.client else None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Safe response helper ──────────────────────────────────────────────────────
|
# ── Safe response helper ──────────────────────────────────────────────────────
|
||||||
@@ -90,25 +68,6 @@ def _user_to_dict(user: User) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ── Password strength helper ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _validate_password_strength(password: str) -> None:
|
|
||||||
"""Raise ValueError with the spec message if password fails any strength rule.
|
|
||||||
|
|
||||||
Rules (AUTH-01): min 12 chars, has uppercase, has lowercase, has digit,
|
|
||||||
has special char (non-alphanumeric).
|
|
||||||
"""
|
|
||||||
if len(password) < 12:
|
|
||||||
raise ValueError(_PASSWORD_DETAIL)
|
|
||||||
if not re.search(r"[A-Z]", password):
|
|
||||||
raise ValueError(_PASSWORD_DETAIL)
|
|
||||||
if not re.search(r"[a-z]", password):
|
|
||||||
raise ValueError(_PASSWORD_DETAIL)
|
|
||||||
if not re.search(r"[0-9]", password):
|
|
||||||
raise ValueError(_PASSWORD_DETAIL)
|
|
||||||
if not re.search(r"[^A-Za-z0-9]", password):
|
|
||||||
raise ValueError(_PASSWORD_DETAIL)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request models ────────────────────────────────────────────────────────────
|
# ── Request models ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -121,10 +80,7 @@ class UserCreate(BaseModel):
|
|||||||
@field_validator("password")
|
@field_validator("password")
|
||||||
@classmethod
|
@classmethod
|
||||||
def password_strength(cls, v: str) -> str:
|
def password_strength(cls, v: str) -> str:
|
||||||
try:
|
validate_password_strength(v)
|
||||||
_validate_password_strength(v)
|
|
||||||
except ValueError as exc:
|
|
||||||
raise ValueError(str(exc)) from exc
|
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
@@ -264,7 +220,7 @@ async def create_user(
|
|||||||
session.add(quota)
|
session.add(quota)
|
||||||
await session.flush() # persist User + Quota before audit_log FK references them
|
await session.flush() # persist User + Quota before audit_log FK references them
|
||||||
# D-13: admin user created event
|
# D-13: admin user created event
|
||||||
_ip_addr = _ip(request)
|
_ip_addr = get_client_ip(request)
|
||||||
await write_audit_log(
|
await write_audit_log(
|
||||||
session,
|
session,
|
||||||
event_type="admin.user_created",
|
event_type="admin.user_created",
|
||||||
@@ -316,7 +272,7 @@ async def update_user_status(
|
|||||||
detail="Cannot deactivate the only admin",
|
detail="Cannot deactivate the only admin",
|
||||||
)
|
)
|
||||||
|
|
||||||
_ip_addr = _ip(request)
|
_ip_addr = get_client_ip(request)
|
||||||
user.is_active = body.is_active
|
user.is_active = body.is_active
|
||||||
|
|
||||||
if not body.is_active:
|
if not body.is_active:
|
||||||
@@ -426,7 +382,7 @@ async def update_user_quota(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
_ip_addr = _ip(request)
|
_ip_addr = get_client_ip(request)
|
||||||
old_limit = quota.limit_bytes
|
old_limit = quota.limit_bytes
|
||||||
quota.limit_bytes = body.limit_bytes
|
quota.limit_bytes = body.limit_bytes
|
||||||
session.add(quota)
|
session.add(quota)
|
||||||
@@ -471,7 +427,7 @@ async def update_ai_config(
|
|||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||||
|
|
||||||
_ip_addr = _ip(request)
|
_ip_addr = get_client_ip(request)
|
||||||
user.ai_provider = body.ai_provider
|
user.ai_provider = body.ai_provider
|
||||||
user.ai_model = body.ai_model
|
user.ai_model = body.ai_model
|
||||||
session.add(user)
|
session.add(user)
|
||||||
@@ -532,7 +488,7 @@ async def delete_user(
|
|||||||
detail="Cannot delete admin accounts",
|
detail="Cannot delete admin accounts",
|
||||||
)
|
)
|
||||||
|
|
||||||
_ip_addr = _ip(request)
|
_ip_addr = get_client_ip(request)
|
||||||
|
|
||||||
# SEC-09 (cloud): purge cloud-stored documents and credentials BEFORE DB delete.
|
# SEC-09 (cloud): purge cloud-stored documents and credentials BEFORE DB delete.
|
||||||
# Must run before MinIO cleanup so that credentials are still available to build
|
# Must run before MinIO cleanup so that credentials are still available to build
|
||||||
|
|||||||
+19
-46
@@ -19,7 +19,6 @@ Security invariants:
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
@@ -32,6 +31,7 @@ from config import settings
|
|||||||
from db.models import BackupCode, Quota, RefreshToken, User
|
from db.models import BackupCode, Quota, RefreshToken, User
|
||||||
from deps.auth import get_current_user
|
from deps.auth import get_current_user
|
||||||
from deps.db import get_db
|
from deps.db import get_db
|
||||||
|
from deps.utils import get_client_ip
|
||||||
from services import auth as auth_service
|
from services import auth as auth_service
|
||||||
from services.audit import write_audit_log
|
from services.audit import write_audit_log
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
@@ -43,30 +43,6 @@ router = APIRouter(prefix="/api/auth", tags=["auth"])
|
|||||||
# IP-level rate limiter (SEC-02 — 10 req/min on register/login/refresh)
|
# IP-level rate limiter (SEC-02 — 10 req/min on register/login/refresh)
|
||||||
limiter = Limiter(key_func=get_remote_address)
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
# ── Password strength validation ─────────────────────────────────────────────
|
|
||||||
_PASSWORD_DETAIL = (
|
|
||||||
"Password must be at least 12 characters and include uppercase, "
|
|
||||||
"lowercase, a number, and a special character."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_password_strength(password: str) -> bool:
|
|
||||||
"""Return True if password passes all strength rules (AUTH-01).
|
|
||||||
|
|
||||||
Rules: min 12 chars, has uppercase, has lowercase, has digit, has special char.
|
|
||||||
"""
|
|
||||||
if len(password) < 12:
|
|
||||||
return False
|
|
||||||
if not re.search(r"[A-Z]", password):
|
|
||||||
return False
|
|
||||||
if not re.search(r"[a-z]", password):
|
|
||||||
return False
|
|
||||||
if not re.search(r"[0-9]", password):
|
|
||||||
return False
|
|
||||||
if not re.search(r"[^A-Za-z0-9]", password):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request models ────────────────────────────────────────────────────────────
|
# ── Request models ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -132,11 +108,10 @@ async def register(
|
|||||||
- Inserts User + Quota rows in a single transaction
|
- Inserts User + Quota rows in a single transaction
|
||||||
"""
|
"""
|
||||||
# Password strength check
|
# Password strength check
|
||||||
if not _validate_password_strength(body.password):
|
try:
|
||||||
raise HTTPException(
|
auth_service.validate_password_strength(body.password)
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
except ValueError as exc:
|
||||||
detail=_PASSWORD_DETAIL,
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc))
|
||||||
)
|
|
||||||
|
|
||||||
# HIBP breach check
|
# HIBP breach check
|
||||||
if await auth_service.check_hibp(body.password):
|
if await auth_service.check_hibp(body.password):
|
||||||
@@ -228,7 +203,7 @@ async def login(
|
|||||||
user: Optional[User] = result.scalar_one_or_none()
|
user: Optional[User] = result.scalar_one_or_none()
|
||||||
|
|
||||||
# IP extraction for audit log (used in both success and failure paths)
|
# IP extraction for audit log (used in both success and failure paths)
|
||||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
_ip = get_client_ip(request)
|
||||||
|
|
||||||
# Verify password (anti-enumeration: same error regardless of whether user exists)
|
# Verify password (anti-enumeration: same error regardless of whether user exists)
|
||||||
if user is None or not auth_service.verify_password(body.password, user.password_hash):
|
if user is None or not auth_service.verify_password(body.password, user.password_hash):
|
||||||
@@ -385,7 +360,7 @@ async def logout(request: Request, response: Response, session: AsyncSession = D
|
|||||||
"""Revoke current refresh token and clear the cookie."""
|
"""Revoke current refresh token and clear the cookie."""
|
||||||
import hashlib as _hashlib
|
import hashlib as _hashlib
|
||||||
|
|
||||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
_ip = get_client_ip(request)
|
||||||
|
|
||||||
raw_token = request.cookies.get("refresh_token")
|
raw_token = request.cookies.get("refresh_token")
|
||||||
_logout_user_id = None
|
_logout_user_id = None
|
||||||
@@ -423,7 +398,7 @@ async def logout_all(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Sign out of all devices: revoke all refresh tokens for current user."""
|
"""Sign out of all devices: revoke all refresh tokens for current user."""
|
||||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
_ip = get_client_ip(request)
|
||||||
count = await auth_service.revoke_all_refresh_tokens(session, current_user.id)
|
count = await auth_service.revoke_all_refresh_tokens(session, current_user.id)
|
||||||
# D-13: sign-out-all event
|
# D-13: sign-out-all event
|
||||||
await write_audit_log(
|
await write_audit_log(
|
||||||
@@ -497,14 +472,13 @@ async def change_password(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Password strength check
|
# Password strength check
|
||||||
if not _validate_password_strength(body.new_password):
|
try:
|
||||||
raise HTTPException(
|
auth_service.validate_password_strength(body.new_password)
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
except ValueError as exc:
|
||||||
detail=_PASSWORD_DETAIL,
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc))
|
||||||
)
|
|
||||||
|
|
||||||
# Update password
|
# Update password
|
||||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
_ip = get_client_ip(request)
|
||||||
user = await session.get(User, current_user.id)
|
user = await session.get(User, current_user.id)
|
||||||
user.password_hash = auth_service.hash_password(body.new_password)
|
user.password_hash = auth_service.hash_password(body.new_password)
|
||||||
# D-13: password changed event (flush within same transaction before commit)
|
# D-13: password changed event (flush within same transaction before commit)
|
||||||
@@ -594,7 +568,7 @@ async def enable_totp(
|
|||||||
await auth_service.store_backup_codes(session, current_user.id, plain_codes)
|
await auth_service.store_backup_codes(session, current_user.id, plain_codes)
|
||||||
|
|
||||||
# D-13: TOTP enrolled event
|
# D-13: TOTP enrolled event
|
||||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
_ip = get_client_ip(request)
|
||||||
await write_audit_log(
|
await write_audit_log(
|
||||||
session,
|
session,
|
||||||
event_type="auth.totp_enrolled",
|
event_type="auth.totp_enrolled",
|
||||||
@@ -620,7 +594,7 @@ async def disable_totp(
|
|||||||
|
|
||||||
Clears totp_secret, sets totp_enabled=False, and deletes all backup codes.
|
Clears totp_secret, sets totp_enabled=False, and deletes all backup codes.
|
||||||
"""
|
"""
|
||||||
_ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None)
|
_ip = get_client_ip(request)
|
||||||
user = await session.get(User, current_user.id)
|
user = await session.get(User, current_user.id)
|
||||||
user.totp_enabled = False
|
user.totp_enabled = False
|
||||||
user.totp_secret = None
|
user.totp_secret = None
|
||||||
@@ -699,11 +673,10 @@ async def password_reset_confirm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Password strength validation
|
# Password strength validation
|
||||||
if not _validate_password_strength(body.new_password):
|
try:
|
||||||
raise HTTPException(
|
auth_service.validate_password_strength(body.new_password)
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
except ValueError as exc:
|
||||||
detail=_PASSWORD_DETAIL,
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc))
|
||||||
)
|
|
||||||
|
|
||||||
# HIBP breach check (SEC-03)
|
# HIBP breach check (SEC-03)
|
||||||
if await auth_service.check_hibp(body.new_password):
|
if await auth_service.check_hibp(body.new_password):
|
||||||
|
|||||||
@@ -48,14 +48,7 @@ except ImportError:
|
|||||||
# Fallback for test environments where minio is not installed
|
# Fallback for test environments where minio is not installed
|
||||||
S3Error = Exception # type: ignore[assignment,misc]
|
S3Error = Exception # type: ignore[assignment,misc]
|
||||||
|
|
||||||
try:
|
from storage.exceptions import CloudConnectionError
|
||||||
from storage.google_drive_backend import CloudConnectionError
|
|
||||||
except ImportError:
|
|
||||||
# Fallback: define a stub so the except clause compiles even if google deps absent
|
|
||||||
class CloudConnectionError(Exception): # type: ignore[no-redef]
|
|
||||||
def __init__(self, msg: str = "", *, reason: str = "") -> None:
|
|
||||||
super().__init__(msg)
|
|
||||||
self.reason = reason
|
|
||||||
|
|
||||||
# Valid cloud backend slugs (T-05-06-01: validated against allowlist, not user-supplied string)
|
# Valid cloud backend slugs (T-05-06-01: validated against allowlist, not user-supplied string)
|
||||||
_CLOUD_PROVIDERS = frozenset({"google_drive", "onedrive", "nextcloud", "webdav"})
|
_CLOUD_PROVIDERS = frozenset({"google_drive", "onedrive", "nextcloud", "webdav"})
|
||||||
|
|||||||
+4
-11
@@ -30,6 +30,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from db.models import Document, Folder, Quota, Share, User
|
from db.models import Document, Folder, Quota, Share, User
|
||||||
from deps.auth import get_regular_user
|
from deps.auth import get_regular_user
|
||||||
from deps.db import get_db
|
from deps.db import get_db
|
||||||
|
from deps.utils import get_client_ip
|
||||||
from services.audit import write_audit_log
|
from services.audit import write_audit_log
|
||||||
from storage import get_storage_backend
|
from storage import get_storage_backend
|
||||||
|
|
||||||
@@ -51,14 +52,6 @@ class DocumentMove(BaseModel):
|
|||||||
folder_id: Optional[str] = None
|
folder_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# ── Helper: extract IP address ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _get_ip(request: Request) -> Optional[str]:
|
|
||||||
"""Extract client IP, honouring X-Forwarded-For for reverse proxy setups (Pitfall 5)."""
|
|
||||||
return request.headers.get("X-Forwarded-For") or (
|
|
||||||
request.client.host if request.client else None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helper: folder serialization ──────────────────────────────────────────────
|
# ── Helper: folder serialization ──────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -148,7 +141,7 @@ async def create_folder(
|
|||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
actor_id=current_user.id,
|
actor_id=current_user.id,
|
||||||
resource_id=folder.id,
|
resource_id=folder.id,
|
||||||
ip_address=_get_ip(request),
|
ip_address=get_client_ip(request),
|
||||||
metadata_={"name": folder.name, "parent_id": str(parent_uuid) if parent_uuid else None},
|
metadata_={"name": folder.name, "parent_id": str(parent_uuid) if parent_uuid else None},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -316,7 +309,7 @@ async def rename_folder(
|
|||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
actor_id=current_user.id,
|
actor_id=current_user.id,
|
||||||
resource_id=folder.id,
|
resource_id=folder.id,
|
||||||
ip_address=_get_ip(request),
|
ip_address=get_client_ip(request),
|
||||||
metadata_={"old_name": old_name, "new_name": folder.name},
|
metadata_={"old_name": old_name, "new_name": folder.name},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -436,7 +429,7 @@ async def delete_folder(
|
|||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
actor_id=current_user.id,
|
actor_id=current_user.id,
|
||||||
resource_id=uid,
|
resource_id=uid,
|
||||||
ip_address=_get_ip(request),
|
ip_address=get_client_ip(request),
|
||||||
metadata_={"name": folder_name, "doc_count": len(docs)},
|
metadata_={"name": folder_name, "doc_count": len(docs)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+4
-18
@@ -27,6 +27,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from db.models import Document, Share, User
|
from db.models import Document, Share, User
|
||||||
from deps.auth import get_regular_user
|
from deps.auth import get_regular_user
|
||||||
from deps.db import get_db
|
from deps.db import get_db
|
||||||
|
from deps.utils import get_client_ip
|
||||||
from services.audit import write_audit_log
|
from services.audit import write_audit_log
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/shares", tags=["shares"])
|
router = APIRouter(prefix="/api/shares", tags=["shares"])
|
||||||
@@ -62,21 +63,6 @@ class SharePermissionPatch(BaseModel):
|
|||||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _ip(request: Request) -> Optional[str]:
|
|
||||||
"""Extract best-effort client IP from request (behind proxy or direct).
|
|
||||||
|
|
||||||
TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be
|
|
||||||
forged by any caller. This value is used for forensic audit logging only —
|
|
||||||
not for authentication or access control decisions. In production, deploy
|
|
||||||
behind a trusted reverse proxy (e.g. nginx with
|
|
||||||
`proxy_set_header X-Forwarded-For $remote_addr;`) which overwrites this
|
|
||||||
header with the real remote IP before it reaches FastAPI, or use a
|
|
||||||
trusted-proxy middleware that validates the source CIDR.
|
|
||||||
"""
|
|
||||||
return request.headers.get("X-Forwarded-For") or (
|
|
||||||
request.client.host if request.client else None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── POST /api/shares ──────────────────────────────────────────────────────────
|
# ── POST /api/shares ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -141,7 +127,7 @@ async def grant_share(
|
|||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
actor_id=current_user.id,
|
actor_id=current_user.id,
|
||||||
resource_id=uid,
|
resource_id=uid,
|
||||||
ip_address=_ip(request),
|
ip_address=get_client_ip(request),
|
||||||
metadata_={"recipient_id": str(recipient.id)},
|
metadata_={"recipient_id": str(recipient.id)},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -283,7 +269,7 @@ async def update_share_permission(
|
|||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
actor_id=current_user.id,
|
actor_id=current_user.id,
|
||||||
resource_id=share.document_id,
|
resource_id=share.document_id,
|
||||||
ip_address=_ip(request),
|
ip_address=get_client_ip(request),
|
||||||
metadata_={"share_id": str(share.id), "new_permission": body.permission},
|
metadata_={"share_id": str(share.id), "new_permission": body.permission},
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -328,7 +314,7 @@ async def revoke_share(
|
|||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
actor_id=current_user.id,
|
actor_id=current_user.id,
|
||||||
resource_id=document_id,
|
resource_id=document_id,
|
||||||
ip_address=_ip(request),
|
ip_address=get_client_ip(request),
|
||||||
metadata_={"recipient_id": str(recipient_id)},
|
metadata_={"recipient_id": str(recipient_id)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,34 @@
|
|||||||
|
"""Shared dependency utilities — request parsing helpers used across all API routers."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_ip(request: Request) -> Optional[str]:
    """Extract a best-effort client IP from *request* for audit logging.

    TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be
    forged by any caller. This value is used for forensic audit logging only —
    not for authentication or access control decisions. In production, deploy
    behind a trusted reverse proxy (e.g. nginx with
    ``proxy_set_header X-Forwarded-For $remote_addr;``) which overwrites this
    header with the real remote IP before it reaches FastAPI.

    Args:
        request: The incoming request; only ``headers`` and ``client`` are read.

    Returns:
        The first entry of ``X-Forwarded-For`` when the header carries one,
        otherwise the host of the direct peer, or ``None`` when neither is
        available.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        # Proxies append to the header ("client, proxy1, proxy2"); keep only
        # the first hop so a single address is logged, not the whole chain.
        first_hop = forwarded.split(",", 1)[0].strip()
        if first_hop:
            return first_hop
    return request.client.host if request.client else None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_uuid(value: str, detail: str = "Not found") -> uuid.UUID:
    """Parse *value* as a UUID, raising HTTP 404 with *detail* on failure.

    Use at API boundaries to convert path/body string IDs to UUID objects.
    Returns the parsed UUID so callers can use it directly without a try/except.

    Args:
        value: Candidate identifier. A ``uuid.UUID`` instance is returned
            unchanged.
        detail: Message placed on the 404 response when parsing fails.

    Returns:
        The parsed ``uuid.UUID``.

    Raises:
        HTTPException: 404 when *value* is not a valid UUID, including
            non-string input such as ``None``.
    """
    if isinstance(value, uuid.UUID):
        return value
    try:
        return uuid.UUID(value)
    except (ValueError, TypeError, AttributeError):
        # uuid.UUID raises TypeError for None and AttributeError for other
        # non-str input; treat every unparsable ID as "not found", not a 500.
        raise HTTPException(status_code=404, detail=detail) from None
|
||||||
@@ -20,6 +20,7 @@ from __future__ import annotations
|
|||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
@@ -36,6 +37,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from config import settings
|
from config import settings
|
||||||
from db.models import BackupCode, Quota, RefreshToken, User
|
from db.models import BackupCode, Quota, RefreshToken, User
|
||||||
|
|
||||||
|
_PASSWORD_DETAIL = (
|
||||||
|
"Password must be at least 12 characters and include uppercase, "
|
||||||
|
"lowercase, a number, and a special character."
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ── Password hashing ────────────────────────────────────────────────────────────
|
# ── Password hashing ────────────────────────────────────────────────────────────
|
||||||
@@ -59,6 +65,22 @@ def verify_password(plain: str, hashed: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def validate_password_strength(password: str) -> None:
    """Check *password* against the strength policy; raise ValueError if it fails.

    Policy (AUTH-01): at least 12 characters, containing an uppercase letter,
    a lowercase letter, a digit, and a special (non-alphanumeric) character.
    API-boundary callers are expected to catch the ValueError and translate it
    to HTTP 422.
    """
    # One pattern per required character class; every one must match somewhere.
    required_classes = (r"[A-Z]", r"[a-z]", r"[0-9]", r"[^A-Za-z0-9]")
    # Length is checked first so short input is rejected before any regex runs.
    if len(password) < 12 or not all(
        re.search(pattern, password) for pattern in required_classes
    ):
        raise ValueError(_PASSWORD_DETAIL)
|
||||||
|
|
||||||
|
|
||||||
# ── JWT helpers ─────────────────────────────────────────────────────────────────
|
# ── JWT helpers ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def create_access_token(user_id: str, role: str) -> str:
|
def create_access_token(user_id: str, role: str) -> str:
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
"""Storage exception types — import from here, never redefine elsewhere."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
class CloudConnectionError(Exception):
    """Signal a non-retryable connection problem reported by a cloud provider.

    Attributes:
        reason: Machine-readable cause of the failure.
            "token_expired" — access token expired; API layer can refresh and retry.
            "invalid_grant" — refresh token revoked; user must reconnect.

    The backend never updates the DB. The API layer (_call_cloud_op in cloud.py)
    catches this exception, performs the DB state transition, and decides whether
    to retry or surface a 503 to the client (B2 design, D-05/D-06).
    """

    def __init__(self, msg: str = "", *, reason: str = "") -> None:
        # reason is keyword-only: "token_expired" | "invalid_grant"
        self.reason = reason
        super().__init__(msg)
|
||||||
@@ -37,23 +37,7 @@ from googleapiclient.errors import HttpError
|
|||||||
from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
|
from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
|
||||||
|
|
||||||
from storage.base import StorageBackend
|
from storage.base import StorageBackend
|
||||||
|
from storage.exceptions import CloudConnectionError # noqa: F401 re-exported for import compatibility
|
||||||
|
|
||||||
class CloudConnectionError(Exception):
|
|
||||||
"""Raised when a cloud provider signals a non-retryable connection problem.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
reason: "token_expired" — access token expired; API layer can refresh and retry.
|
|
||||||
"invalid_grant" — refresh token revoked; user must reconnect.
|
|
||||||
|
|
||||||
The backend never updates the DB. The API layer (_call_cloud_op in cloud.py)
|
|
||||||
catches this exception, performs the DB state transition, and decides whether
|
|
||||||
to retry or surface a 503 to the client (B2 design, D-05/D-06).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, msg: str = "", *, reason: str = "") -> None:
|
|
||||||
super().__init__(msg)
|
|
||||||
self.reason = reason # "token_expired" | "invalid_grant"
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleDriveBackend(StorageBackend):
|
class GoogleDriveBackend(StorageBackend):
|
||||||
|
|||||||
@@ -76,12 +76,7 @@ async def _run(document_id: str) -> dict:
|
|||||||
if user is None:
|
if user is None:
|
||||||
return {"document_id": document_id, "status": "missing_user"}
|
return {"document_id": document_id, "status": "missing_user"}
|
||||||
|
|
||||||
try:
|
from storage.exceptions import CloudConnectionError
|
||||||
from storage.google_drive_backend import CloudConnectionError
|
|
||||||
except ImportError:
|
|
||||||
class CloudConnectionError(Exception): # type: ignore[no-redef]
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
backend = await get_storage_backend_for_document(doc, user, session)
|
backend = await get_storage_backend_for_document(doc, user, session)
|
||||||
file_bytes = await backend.get_object(doc.object_key)
|
file_bytes = await backend.get_object(doc.object_key)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Uses a mock provider — no real AI calls made.
|
|||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import pytest
|
import pytest
|
||||||
from ai.openai_provider import _parse_classification, _parse_suggestions, _strip_code_fences
|
from ai.utils import parse_classification as _parse_classification, parse_suggestions as _parse_suggestions, strip_code_fences as _strip_code_fences
|
||||||
from ai.base import ClassificationResult
|
from ai.base import ClassificationResult
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user