diff --git a/backend/ai/anthropic_provider.py b/backend/ai/anthropic_provider.py index cf5859d..ebb8123 100644 --- a/backend/ai/anthropic_provider.py +++ b/backend/ai/anthropic_provider.py @@ -1,7 +1,6 @@ -import json -import re import anthropic from ai.base import AIProvider, ClassificationResult +from ai.utils import parse_classification, parse_suggestions MAX_AI_CHARS = 8_000 @@ -33,7 +32,7 @@ class AnthropicProvider(AIProvider): messages=[{"role": "user", "content": user_msg}], ) raw = response.content[0].text - return _parse_classification(raw) + return parse_classification(raw) async def suggest_topics( self, @@ -53,7 +52,7 @@ class AnthropicProvider(AIProvider): messages=[{"role": "user", "content": user_msg}], ) raw = response.content[0].text - return _parse_suggestions(raw) + return parse_suggestions(raw) async def health_check(self) -> bool: try: @@ -68,36 +67,3 @@ class AnthropicProvider(AIProvider): return False -def _strip_code_fences(text: str) -> str: - text = re.sub(r"```(?:json)?\s*", "", text) - text = re.sub(r"```", "", text) - return text.strip() - - -def _parse_classification(raw: str) -> ClassificationResult: - raw = _strip_code_fences(raw) - # Try to find JSON object - match = re.search(r"\{.*\}", raw, re.DOTALL) - if match: - try: - data = json.loads(match.group()) - return ClassificationResult( - topics=data.get("assigned_topics", []), - suggested_new_topics=data.get("new_topic_suggestions", []), - reasoning=data.get("reasoning", ""), - ) - except json.JSONDecodeError: - pass - return ClassificationResult() - - -def _parse_suggestions(raw: str) -> list[str]: - raw = _strip_code_fences(raw) - match = re.search(r"\{.*\}", raw, re.DOTALL) - if match: - try: - data = json.loads(match.group()) - return data.get("suggested_topics", []) - except json.JSONDecodeError: - pass - return [] diff --git a/backend/ai/openai_provider.py b/backend/ai/openai_provider.py index fe1cdc5..e12af79 100644 --- a/backend/ai/openai_provider.py +++ b/backend/ai/openai_provider.py @@ -1,7 +1,6 @@ -import json -import re from openai import AsyncOpenAI from ai.base import AIProvider, ClassificationResult +from ai.utils import parse_classification, parse_suggestions MAX_AI_CHARS = 8_000 @@ -35,7 +34,7 @@ class OpenAIProvider(AIProvider): ], ) raw = response.choices[0].message.content or "" - return _parse_classification(raw) + return parse_classification(raw) async def suggest_topics( self, @@ -56,7 +55,7 @@ class OpenAIProvider(AIProvider): ], ) raw = response.choices[0].message.content or "" - return _parse_suggestions(raw) + return parse_suggestions(raw) async def health_check(self) -> bool: try: @@ -70,35 +69,3 @@ class OpenAIProvider(AIProvider): return False -def _strip_code_fences(text: str) -> str: - text = re.sub(r"```(?:json)?\s*", "", text) - text = re.sub(r"```", "", text) - return text.strip() - - -def _parse_classification(raw: str) -> ClassificationResult: - raw = _strip_code_fences(raw) - match = re.search(r"\{.*\}", raw, re.DOTALL) - if match: - try: - data = json.loads(match.group()) - return ClassificationResult( - topics=data.get("assigned_topics", []), - suggested_new_topics=data.get("new_topic_suggestions", []), - reasoning=data.get("reasoning", ""), - ) - except json.JSONDecodeError: - pass - return ClassificationResult() - - -def _parse_suggestions(raw: str) -> list[str]: - raw = _strip_code_fences(raw) - match = re.search(r"\{.*\}", raw, re.DOTALL) - if match: - try: - data = json.loads(match.group()) - return data.get("suggested_topics", []) - except json.JSONDecodeError: - pass - return [] diff --git a/backend/ai/utils.py b/backend/ai/utils.py new file mode 100644 index 0000000..7ef6ce0 --- /dev/null +++ b/backend/ai/utils.py @@ -0,0 +1,51 @@ +"""Shared AI response parsing utilities — used by all provider implementations.""" +from __future__ import annotations + +import json +import re + +from ai.base import ClassificationResult + + +def strip_code_fences(text: str) -> str: + """Remove markdown code fences (```json ... ```) from *text*.""" + text = re.sub(r"```(?:json)?\s*", "", text) + text = re.sub(r"```", "", text) + return text.strip() + + +def parse_classification(raw: str) -> ClassificationResult: + """Parse a classification JSON response into a ClassificationResult. + + Tolerates markdown code fences and extracts the first JSON object found. + Returns an empty ClassificationResult on any parse failure. + """ + raw = strip_code_fences(raw) + match = re.search(r"\{.*\}", raw, re.DOTALL) + if match: + try: + data = json.loads(match.group()) + return ClassificationResult( + topics=data.get("assigned_topics", []), + suggested_new_topics=data.get("new_topic_suggestions", []), + reasoning=data.get("reasoning", ""), + ) + except json.JSONDecodeError: + pass + return ClassificationResult() + + +def parse_suggestions(raw: str) -> list[str]: + """Parse a topic-suggestion JSON response into a list of topic name strings. + + Tolerates markdown code fences. Returns an empty list on parse failure. + """ + raw = strip_code_fences(raw) + match = re.search(r"\{.*\}", raw, re.DOTALL) + if match: + try: + data = json.loads(match.group()) + return data.get("suggested_topics", []) + except json.JSONDecodeError: + pass + return [] diff --git a/backend/api/admin.py b/backend/api/admin.py index 2e5a627..137e598 100644 --- a/backend/api/admin.py +++ b/backend/api/admin.py @@ -23,7 +23,6 @@ Security invariants: """ from __future__ import annotations -import re import uuid from datetime import datetime from typing import Optional @@ -36,8 +35,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from db.models import CloudConnection, Document, Quota, RefreshToken, Topic, User from deps.auth import get_current_admin from deps.db import get_db +from deps.utils import get_client_ip from services.audit import write_audit_log -from services.auth import hash_password, revoke_all_refresh_tokens, verify_password +from services.auth import hash_password, revoke_all_refresh_tokens, validate_password_strength, verify_password from storage import get_storage_backend, get_storage_backend_for_document router = APIRouter(prefix="/api/admin", tags=["admin"]) @@ -46,28 +46,6 @@ router = APIRouter(prefix="/api/admin", tags=["admin"]) _DEFAULT_QUOTA_BYTES = 104857600 # 100 MB free-tier default (D-06) -_PASSWORD_DETAIL = ( - "Password must be at least 12 characters and include uppercase, " - "lowercase, a number, and a special character." -) - - -# ── IP extraction helper ────────────────────────────────────────────────────── - -def _ip(request: Request) -> Optional[str]: - """Extract best-effort client IP from request for audit logging. - - TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be - forged by any caller. This value is used for forensic audit logging only — - not for authentication or access control decisions. In production, deploy - behind a trusted reverse proxy (e.g. nginx with - `proxy_set_header X-Forwarded-For $remote_addr;`) which overwrites this - header with the real remote IP before it reaches FastAPI, or use a - trusted-proxy middleware that validates the source CIDR. - """ - return request.headers.get("X-Forwarded-For") or ( - request.client.host if request.client else None - ) # ── Safe response helper ────────────────────────────────────────────────────── @@ -90,25 +68,6 @@ def _user_to_dict(user: User) -> dict: } -# ── Password strength helper ────────────────────────────────────────────────── - -def _validate_password_strength(password: str) -> None: - """Raise ValueError with the spec message if password fails any strength rule. - - Rules (AUTH-01): min 12 chars, has uppercase, has lowercase, has digit, - has special char (non-alphanumeric). - """ - if len(password) < 12: - raise ValueError(_PASSWORD_DETAIL) - if not re.search(r"[A-Z]", password): - raise ValueError(_PASSWORD_DETAIL) - if not re.search(r"[a-z]", password): - raise ValueError(_PASSWORD_DETAIL) - if not re.search(r"[0-9]", password): - raise ValueError(_PASSWORD_DETAIL) - if not re.search(r"[^A-Za-z0-9]", password): - raise ValueError(_PASSWORD_DETAIL) - # ── Request models ──────────────────────────────────────────────────────────── @@ -121,10 +80,7 @@ class UserCreate(BaseModel): @field_validator("password") @classmethod def password_strength(cls, v: str) -> str: - try: - _validate_password_strength(v) - except ValueError as exc: - raise ValueError(str(exc)) from exc + validate_password_strength(v) return v @@ -264,7 +220,7 @@ async def create_user( session.add(quota) await session.flush() # persist User + Quota before audit_log FK references them # D-13: admin user created event - _ip_addr = _ip(request) + _ip_addr = get_client_ip(request) await write_audit_log( session, event_type="admin.user_created", @@ -316,7 +272,7 @@ async def update_user_status( detail="Cannot deactivate the only admin", ) - _ip_addr = _ip(request) + _ip_addr = get_client_ip(request) user.is_active = body.is_active if not body.is_active: @@ -426,7 +382,7 @@ async def update_user_quota( else None ) - _ip_addr = _ip(request) + _ip_addr = get_client_ip(request) old_limit = quota.limit_bytes quota.limit_bytes = body.limit_bytes session.add(quota) @@ -471,7 +427,7 @@ async def update_ai_config( if user is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - _ip_addr = _ip(request) + _ip_addr = get_client_ip(request) user.ai_provider = body.ai_provider user.ai_model = body.ai_model session.add(user) @@ -532,7 +488,7 @@ async def delete_user( detail="Cannot delete admin accounts", ) - _ip_addr = _ip(request) + _ip_addr = get_client_ip(request) # SEC-09 (cloud): purge cloud-stored documents and credentials BEFORE DB delete. # Must run before MinIO cleanup so that credentials are still available to build diff --git a/backend/api/auth.py b/backend/api/auth.py index b6556ed..076b03e 100644 --- a/backend/api/auth.py +++ b/backend/api/auth.py @@ -19,7 +19,6 @@ Security invariants: """ from __future__ import annotations -import re import uuid from typing import Literal, Optional @@ -32,6 +31,7 @@ from config import settings from db.models import BackupCode, Quota, RefreshToken, User from deps.auth import get_current_user from deps.db import get_db +from deps.utils import get_client_ip from services import auth as auth_service from services.audit import write_audit_log from slowapi import Limiter @@ -43,30 +43,6 @@ router = APIRouter(prefix="/api/auth", tags=["auth"]) # IP-level rate limiter (SEC-02 — 10 req/min on register/login/refresh) limiter = Limiter(key_func=get_remote_address) -# ── Password strength validation ───────────────────────────────────────────── -_PASSWORD_DETAIL = ( - "Password must be at least 12 characters and include uppercase, " - "lowercase, a number, and a special character." -) - - -def _validate_password_strength(password: str) -> bool: - """Return True if password passes all strength rules (AUTH-01). - - Rules: min 12 chars, has uppercase, has lowercase, has digit, has special char. - """ - if len(password) < 12: - return False - if not re.search(r"[A-Z]", password): - return False - if not re.search(r"[a-z]", password): - return False - if not re.search(r"[0-9]", password): - return False - if not re.search(r"[^A-Za-z0-9]", password): - return False - return True - # ── Request models ──────────────────────────────────────────────────────────── @@ -132,11 +108,10 @@ async def register( - Inserts User + Quota rows in a single transaction """ # Password strength check - if not _validate_password_strength(body.password): - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=_PASSWORD_DETAIL, - ) + try: + auth_service.validate_password_strength(body.password) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc)) # HIBP breach check if await auth_service.check_hibp(body.password): @@ -228,7 +203,7 @@ async def login( user: Optional[User] = result.scalar_one_or_none() # IP extraction for audit log (used in both success and failure paths) - _ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) + _ip = get_client_ip(request) # Verify password (anti-enumeration: same error regardless of whether user exists) if user is None or not auth_service.verify_password(body.password, user.password_hash): @@ -385,7 +360,7 @@ async def logout(request: Request, response: Response, session: AsyncSession = D """Revoke current refresh token and clear the cookie.""" import hashlib as _hashlib - _ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) + _ip = get_client_ip(request) raw_token = request.cookies.get("refresh_token") _logout_user_id = None @@ -423,7 +398,7 @@ async def logout_all( current_user: User = Depends(get_current_user), ): """Sign out of all devices: revoke all refresh tokens for current user.""" - _ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) + _ip = get_client_ip(request) count = await auth_service.revoke_all_refresh_tokens(session, current_user.id) # D-13: sign-out-all event await write_audit_log( @@ -497,14 +472,13 @@ async def change_password( ) # Password strength check - if not _validate_password_strength(body.new_password): - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=_PASSWORD_DETAIL, - ) + try: + auth_service.validate_password_strength(body.new_password) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc)) # Update password - _ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) + _ip = get_client_ip(request) user = await session.get(User, current_user.id) user.password_hash = auth_service.hash_password(body.new_password) # D-13: password changed event (flush within same transaction before commit) @@ -594,7 +568,7 @@ async def enable_totp( await auth_service.store_backup_codes(session, current_user.id, plain_codes) # D-13: TOTP enrolled event - _ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) + _ip = get_client_ip(request) await write_audit_log( session, event_type="auth.totp_enrolled", @@ -620,7 +594,7 @@ async def disable_totp( Clears totp_secret, sets totp_enabled=False, and deletes all backup codes. """ - _ip = request.headers.get("X-Forwarded-For") or (request.client.host if request.client else None) + _ip = get_client_ip(request) user = await session.get(User, current_user.id) user.totp_enabled = False user.totp_secret = None @@ -699,11 +673,10 @@ async def password_reset_confirm( ) # Password strength validation - if not _validate_password_strength(body.new_password): - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=_PASSWORD_DETAIL, - ) + try: + auth_service.validate_password_strength(body.new_password) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc)) # HIBP breach check (SEC-03) if await auth_service.check_hibp(body.new_password): diff --git a/backend/api/documents.py b/backend/api/documents.py index 8f46651..ac4cd27 100644 --- a/backend/api/documents.py +++ b/backend/api/documents.py @@ -48,14 +48,7 @@ except ImportError: # Fallback for test environments where minio is not installed S3Error = Exception # type: ignore[assignment,misc] -try: - from storage.google_drive_backend import CloudConnectionError -except ImportError: - # Fallback: define a stub so the except clause compiles even if google deps absent - class CloudConnectionError(Exception): # type: ignore[no-redef] - def __init__(self, msg: str = "", *, reason: str = "") -> None: - super().__init__(msg) - self.reason = reason +from storage.exceptions import CloudConnectionError # Valid cloud backend slugs (T-05-06-01: validated against allowlist, not user-supplied string) _CLOUD_PROVIDERS = frozenset({"google_drive", "onedrive", "nextcloud", "webdav"}) diff --git a/backend/api/folders.py b/backend/api/folders.py index eae2d37..e270bfa 100644 --- a/backend/api/folders.py +++ b/backend/api/folders.py @@ -30,6 +30,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from db.models import Document, Folder, Quota, Share, User from deps.auth import get_regular_user from deps.db import get_db +from deps.utils import get_client_ip from services.audit import write_audit_log from storage import get_storage_backend @@ -51,14 +52,6 @@ class DocumentMove(BaseModel): folder_id: Optional[str] = None -# ── Helper: extract IP address ──────────────────────────────────────────────── - -def _get_ip(request: Request) -> Optional[str]: - """Extract client IP, honouring X-Forwarded-For for reverse proxy setups (Pitfall 5).""" - return request.headers.get("X-Forwarded-For") or ( - request.client.host if request.client else None - ) - # ── Helper: folder serialization ────────────────────────────────────────────── @@ -148,7 +141,7 @@ async def create_folder( user_id=current_user.id, actor_id=current_user.id, resource_id=folder.id, - ip_address=_get_ip(request), + ip_address=get_client_ip(request), metadata_={"name": folder.name, "parent_id": str(parent_uuid) if parent_uuid else None}, ) @@ -316,7 +309,7 @@ async def rename_folder( user_id=current_user.id, actor_id=current_user.id, resource_id=folder.id, - ip_address=_get_ip(request), + ip_address=get_client_ip(request), metadata_={"old_name": old_name, "new_name": folder.name}, ) @@ -436,7 +429,7 @@ async def delete_folder( user_id=current_user.id, actor_id=current_user.id, resource_id=uid, - ip_address=_get_ip(request), + ip_address=get_client_ip(request), metadata_={"name": folder_name, "doc_count": len(docs)}, ) diff --git a/backend/api/shares.py b/backend/api/shares.py index e33ee88..4bfe09f 100644 --- a/backend/api/shares.py +++ b/backend/api/shares.py @@ -27,6 +27,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from db.models import Document, Share, User from deps.auth import get_regular_user from deps.db import get_db +from deps.utils import get_client_ip from services.audit import write_audit_log router = APIRouter(prefix="/api/shares", tags=["shares"]) @@ -62,21 +63,6 @@ class SharePermissionPatch(BaseModel): # ── Helpers ─────────────────────────────────────────────────────────────────── -def _ip(request: Request) -> Optional[str]: - """Extract best-effort client IP from request (behind proxy or direct). - - TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be - forged by any caller. This value is used for forensic audit logging only — - not for authentication or access control decisions. In production, deploy - behind a trusted reverse proxy (e.g. nginx with - `proxy_set_header X-Forwarded-For $remote_addr;`) which overwrites this - header with the real remote IP before it reaches FastAPI, or use a - trusted-proxy middleware that validates the source CIDR. - """ - return request.headers.get("X-Forwarded-For") or ( - request.client.host if request.client else None - ) - # ── POST /api/shares ────────────────────────────────────────────────────────── @@ -141,7 +127,7 @@ async def grant_share( user_id=current_user.id, actor_id=current_user.id, resource_id=uid, - ip_address=_ip(request), + ip_address=get_client_ip(request), metadata_={"recipient_id": str(recipient.id)}, ) @@ -283,7 +269,7 @@ async def update_share_permission( user_id=current_user.id, actor_id=current_user.id, resource_id=share.document_id, - ip_address=_ip(request), + ip_address=get_client_ip(request), metadata_={"share_id": str(share.id), "new_permission": body.permission}, ) await session.commit() @@ -328,7 +314,7 @@ async def revoke_share( user_id=current_user.id, actor_id=current_user.id, resource_id=document_id, - ip_address=_ip(request), + ip_address=get_client_ip(request), metadata_={"recipient_id": str(recipient_id)}, ) diff --git a/backend/deps/utils.py b/backend/deps/utils.py new file mode 100644 index 0000000..9f46df4 --- /dev/null +++ b/backend/deps/utils.py @@ -0,0 +1,34 @@ +"""Shared dependency utilities — request parsing helpers used across all API routers.""" +from __future__ import annotations + +import uuid +from typing import Optional + +from fastapi import HTTPException, Request + + +def get_client_ip(request: Request) -> Optional[str]: + """Extract best-effort client IP from request for audit logging. + + TRUST BOUNDARY: X-Forwarded-For is a client-controlled header and can be + forged by any caller. This value is used for forensic audit logging only — + not for authentication or access control decisions. In production, deploy + behind a trusted reverse proxy (e.g. nginx with + ``proxy_set_header X-Forwarded-For $remote_addr;``) which overwrites this + header with the real remote IP before it reaches FastAPI. + """ + return request.headers.get("X-Forwarded-For") or ( + request.client.host if request.client else None + ) + + +def parse_uuid(value: str, detail: str = "Not found") -> uuid.UUID: + """Parse *value* as a UUID, raising HTTP 404 with *detail* on failure. + + Use at API boundaries to convert path/body string IDs to UUID objects. + Returns the parsed UUID so callers can use it directly without a try/except. + """ + try: + return uuid.UUID(value) + except ValueError: + raise HTTPException(status_code=404, detail=detail) diff --git a/backend/services/auth.py b/backend/services/auth.py index 508a096..02c084c 100644 --- a/backend/services/auth.py +++ b/backend/services/auth.py @@ -20,6 +20,7 @@ from __future__ import annotations import hashlib import hmac import logging +import re import secrets import uuid from datetime import datetime, timezone, timedelta @@ -36,6 +37,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from config import settings from db.models import BackupCode, Quota, RefreshToken, User +_PASSWORD_DETAIL = ( + "Password must be at least 12 characters and include uppercase, " + "lowercase, a number, and a special character." +) + logger = logging.getLogger(__name__) # ── Password hashing ──────────────────────────────────────────────────────────── @@ -59,6 +65,22 @@ def verify_password(plain: str, hashed: str) -> bool: return False +def validate_password_strength(password: str) -> None: + """Raise ValueError with a descriptive message if *password* fails any strength rule. + + Rules (AUTH-01): min 12 chars, uppercase, lowercase, digit, special char. + Callers at the API boundary should catch ValueError and map it to HTTP 422. + """ + if ( + len(password) < 12 + or not re.search(r"[A-Z]", password) + or not re.search(r"[a-z]", password) + or not re.search(r"[0-9]", password) + or not re.search(r"[^A-Za-z0-9]", password) + ): + raise ValueError(_PASSWORD_DETAIL) + + # ── JWT helpers ───────────────────────────────────────────────────────────────── def create_access_token(user_id: str, role: str) -> str: diff --git a/backend/storage/exceptions.py b/backend/storage/exceptions.py new file mode 100644 index 0000000..18f0bf5 --- /dev/null +++ b/backend/storage/exceptions.py @@ -0,0 +1,19 @@ +"""Storage exception types — import from here, never redefine elsewhere.""" +from __future__ import annotations + + +class CloudConnectionError(Exception): + """Raised when a cloud provider signals a non-retryable connection problem. + + Attributes: + reason: "token_expired" — access token expired; API layer can refresh and retry. + "invalid_grant" — refresh token revoked; user must reconnect. + + The backend never updates the DB. The API layer (_call_cloud_op in cloud.py) + catches this exception, performs the DB state transition, and decides whether + to retry or surface a 503 to the client (B2 design, D-05/D-06). + """ + + def __init__(self, msg: str = "", *, reason: str = "") -> None: + super().__init__(msg) + self.reason = reason # "token_expired" | "invalid_grant" diff --git a/backend/storage/google_drive_backend.py b/backend/storage/google_drive_backend.py index a6d3344..eb9f13e 100644 --- a/backend/storage/google_drive_backend.py +++ b/backend/storage/google_drive_backend.py @@ -37,23 +37,7 @@ from googleapiclient.errors import HttpError from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload from storage.base import StorageBackend - - -class CloudConnectionError(Exception): - """Raised when a cloud provider signals a non-retryable connection problem. - - Attributes: - reason: "token_expired" — access token expired; API layer can refresh and retry. - "invalid_grant" — refresh token revoked; user must reconnect. - - The backend never updates the DB. The API layer (_call_cloud_op in cloud.py) - catches this exception, performs the DB state transition, and decides whether - to retry or surface a 503 to the client (B2 design, D-05/D-06). - """ - - def __init__(self, msg: str = "", *, reason: str = "") -> None: - super().__init__(msg) - self.reason = reason # "token_expired" | "invalid_grant" +from storage.exceptions import CloudConnectionError # noqa: F401 re-exported for import compatibility class GoogleDriveBackend(StorageBackend): diff --git a/backend/tasks/document_tasks.py b/backend/tasks/document_tasks.py index 661d1ee..4d89c33 100644 --- a/backend/tasks/document_tasks.py +++ b/backend/tasks/document_tasks.py @@ -76,12 +76,7 @@ async def _run(document_id: str) -> dict: if user is None: return {"document_id": document_id, "status": "missing_user"} - try: - from storage.google_drive_backend import CloudConnectionError - except ImportError: - class CloudConnectionError(Exception): # type: ignore[no-redef] - pass - + from storage.exceptions import CloudConnectionError try: backend = await get_storage_backend_for_document(doc, user, session) file_bytes = await backend.get_object(doc.object_key) diff --git a/backend/tests/test_classifier.py b/backend/tests/test_classifier.py index 06c7a87..7578067 100644 --- a/backend/tests/test_classifier.py +++ b/backend/tests/test_classifier.py @@ -4,7 +4,7 @@ Uses a mock provider — no real AI calls made. """ import json import pytest -from ai.openai_provider import _parse_classification, _parse_suggestions, _strip_code_fences +from ai.utils import parse_classification as _parse_classification, parse_suggestions as _parse_suggestions, strip_code_fences as _strip_code_fences from ai.base import ClassificationResult