Commit 49f43a45 authored by Vũ Hoàng Anh

Fix mock API routing and retriever alias

parent 566ee233
import asyncio
import json
import logging
import time
from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel
from agent.tools.data_retrieval_tool import SearchItem, data_retrieval_tool
logger = logging.getLogger(__name__)
router = APIRouter()
# --- HELPERS ---
async def retry_with_backoff(coro_fn, max_retries=3, backoff_factor=2):
    """Await ``coro_fn()`` until it succeeds, backing off exponentially between tries.

    Args:
        coro_fn: Zero-argument callable that returns a fresh awaitable on each call.
        max_retries: Total number of attempts (not additional retries); must be >= 1.
        backoff_factor: Base of the delay; after 0-based attempt ``n`` the wait is
            ``backoff_factor ** n`` seconds.

    Returns:
        The result of the first successful ``await coro_fn()``.

    Raises:
        ValueError: If ``max_retries`` is below 1 (previously this silently
            returned ``None`` without ever calling ``coro_fn``).
        Exception: Whatever the final attempt raised, re-raised unchanged.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(max_retries):
        try:
            return await coro_fn()
        except Exception as e:
            # Last attempt: surface the original error to the caller.
            if attempt == max_retries - 1:
                raise
            wait_time = backoff_factor**attempt
            logger.warning(f"⚠️ Attempt {attempt + 1} failed: {e!s}, retrying in {wait_time}s...")
            await asyncio.sleep(wait_time)
# --- MODELS ---
class MockQueryRequest(BaseModel):
    """Request body of the mock agent chat endpoint (`mock_chat`)."""

    # Free-text user query; the only required field.
    user_query: str
    # Caller identity; defaults to "test_user" when omitted.
    user_id: str | None = "test_user"
    # Optional session identifier; `mock_chat` echoes it back in the response.
    session_id: str | None = None
class MockDBRequest(BaseModel):
    """Request body of the mock DB search endpoint (`mock_db_search`)."""

    # Free-text search query; the handler substitutes a generic query when empty.
    query: str | None = None
    # Product reference code forwarded to SearchItem.magento_ref_code.
    magento_ref_code: str | None = None
    # Optional price bounds forwarded to SearchItem.
    price_min: float | None = None
    price_max: float | None = None
    # NOTE(review): not read by any handler visible here — confirm whether it
    # should be forwarded to the retrieval tool.
    top_k: int = 10
class MockRetrieverRequest(BaseModel):
    """Request body for the retriever mock endpoint.

    NOTE(review): the handler consuming this model is not visible in this
    part of the file; field meanings below are presumed from their names.
    """

    # Free-text user query; the only required field.
    user_query: str
    # Optional price bounds.
    price_min: float | None = None
    price_max: float | None = None
    # Optional product reference code.
    magento_ref_code: str | None = None
    # Caller identity; defaults to "test_user" when omitted.
    user_id: str | None = "test_user"
    # Optional session identifier.
    session_id: str | None = None
# --- MOCK LLM RESPONSES (no OpenAI call) ---
# Canned Vietnamese assistant replies. Not referenced by any code visible in
# this section; presumably consumed by the mock controller — verify.
MOCK_AI_RESPONSES = [
    "Dựa trên tìm kiếm của bạn, tôi tìm thấy các sản phẩm phù hợp với nhu cầu của bạn. Những mặt hàng này có chất lượng tốt và giá cả phải chăng.",
    "Tôi gợi ý cho bạn những sản phẩm sau. Chúng đều là những lựa chọn phổ biến và nhận được đánh giá cao từ khách hàng.",
    "Dựa trên tiêu chí tìm kiếm của bạn, đây là những sản phẩm tốt nhất mà tôi có thể giới thiệu.",
    "Những sản phẩm này hoàn toàn phù hợp với yêu cầu của bạn. Hãy xem chi tiết để chọn sản phẩm yêu thích nhất.",
    "Tôi đã tìm được các mặt hàng tuyệt vời cho bạn. Hãy kiểm tra chúng để tìm ra lựa chọn tốt nhất.",
]
# --- ENDPOINTS ---
from agent.mock_controller import mock_chat_controller
@router.post("/mock/agent/chat", summary="Mock Agent Chat (Real Tools + Fake LLM)")
async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks):
    """
    Mock Agent Chat using mock_chat_controller:
    - ✅ Real embedding + vector search (the real data_retrieval_tool)
    - ✅ Real products from StarRocks
    - ❌ Fake LLM response (no OpenAI cost)
    - Perfect for stress testing + end-to-end testing

    Returns the request echo (query, user_id, session_id) merged with the
    controller's result; keys in the controller result win on conflict.

    Raises:
        HTTPException: 500 on any unexpected failure; an HTTPException raised
            downstream is passed through unchanged.
    """
    try:
        logger.info(f"🚀 [Mock Agent Chat] Starting with query: {req.user_query}")
        result = await mock_chat_controller(
            query=req.user_query,
            user_id=req.user_id or "test_user",
            background_tasks=background_tasks,
        )
        return {
            "status": "success",
            "user_query": req.user_query,
            "user_id": req.user_id,
            "session_id": req.session_id,
            **result,  # Include status, ai_response, product_ids, etc.
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of rewrapping as 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in mock agent chat: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Mock Agent Chat Error: {e!s}") from e
@router.post("/mock/db/search", summary="Real Data Retrieval Tool (Agent Tool)")
async def mock_db_search(req: MockDBRequest):
    """
    Run the agent's real `data_retrieval_tool`:
    - If magento_ref_code is given → CODE SEARCH (no embedding needed)
    - If query is given → HYDE SEMANTIC SEARCH (embedding + vector search)
    - Filter by price when price_min/price_max are given
    - Return real products from StarRocks
    The input format mirrors the agent tool's SearchItem.

    Raises:
        HTTPException: 500 on any unexpected failure; an HTTPException raised
            downstream is passed through unchanged.
    """
    try:
        logger.info("📍 Data Retrieval Tool called")
        # perf_counter is monotonic, so the elapsed time cannot go negative
        # when the wall clock is adjusted.
        start_time = time.perf_counter()

        # Build the SearchItem from the request; fall back to a generic query.
        # NOTE(review): req.top_k is accepted but never forwarded — confirm.
        search_item = SearchItem(
            query=req.query or "sản phẩm",
            magento_ref_code=req.magento_ref_code,
            price_min=req.price_min,
            price_max=req.price_max,
            action="search",
        )
        search_params = search_item.dict(exclude_none=True)
        logger.info(f"🔧 Search params: {search_params}")

        # Call the real data_retrieval_tool, retrying on failure.
        result_json = await retry_with_backoff(
            lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}), max_retries=3
        )
        result = json.loads(result_json)

        elapsed_time = time.perf_counter() - start_time
        logger.info(f"✅ Data Retrieval completed in {elapsed_time:.3f}s")

        # A missing, null or empty "results" list yields no products instead of
        # an IndexError on [0].
        results = result.get("results") or [{}]
        products = results[0].get("products") or []

        return {
            "status": result.get("status", "success"),
            "search_params": search_params,
            "total_results": len(products),
            "products": products,
            "processing_time_ms": round(elapsed_time * 1000, 2),
            "raw_result": result,
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of rewrapping as 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in DB search: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"DB Search Error: {e!s}") from e
@router.post("/mock/retriverdb", summary="Real Embedding + Real DB Vector Search")
import asyncio
import json
import logging
import time
from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel
from agent.tools.data_retrieval_tool import SearchItem, data_retrieval_tool
logger = logging.getLogger(__name__)
router = APIRouter()
# --- HELPERS ---
async def retry_with_backoff(coro_fn, max_retries=3, backoff_factor=2):
    """Await ``coro_fn()`` until it succeeds, backing off exponentially between tries.

    Args:
        coro_fn: Zero-argument callable that returns a fresh awaitable on each call.
        max_retries: Total number of attempts (not additional retries); must be >= 1.
        backoff_factor: Base of the delay; after 0-based attempt ``n`` the wait is
            ``backoff_factor ** n`` seconds.

    Returns:
        The result of the first successful ``await coro_fn()``.

    Raises:
        ValueError: If ``max_retries`` is below 1 (previously this silently
            returned ``None`` without ever calling ``coro_fn``).
        Exception: Whatever the final attempt raised, re-raised unchanged.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(max_retries):
        try:
            return await coro_fn()
        except Exception as e:
            # Last attempt: surface the original error to the caller.
            if attempt == max_retries - 1:
                raise
            wait_time = backoff_factor**attempt
            logger.warning(f"⚠️ Attempt {attempt + 1} failed: {e!s}, retrying in {wait_time}s...")
            await asyncio.sleep(wait_time)
# --- MODELS ---
class MockQueryRequest(BaseModel):
    """Request body of the mock agent chat endpoint (`mock_chat`)."""

    # Free-text user query; the only required field.
    user_query: str
    # Caller identity; defaults to "test_user" when omitted.
    user_id: str | None = "test_user"
    # Optional session identifier; `mock_chat` echoes it back in the response.
    session_id: str | None = None
class MockDBRequest(BaseModel):
    """Request body of the mock DB search endpoint (`mock_db_search`)."""

    # Free-text search query; the handler substitutes a generic query when empty.
    query: str | None = None
    # Product reference code forwarded to SearchItem.magento_ref_code.
    magento_ref_code: str | None = None
    # Optional price bounds forwarded to SearchItem.
    price_min: float | None = None
    price_max: float | None = None
    # NOTE(review): not read by any handler visible here — confirm whether it
    # should be forwarded to the retrieval tool.
    top_k: int = 10
class MockRetrieverRequest(BaseModel):
    """Request body of the retriever mock endpoint (`mock_retriever_db`)."""

    # Free-text user query, passed to SearchItem.query; the only required field.
    user_query: str
    # Optional price bounds forwarded to SearchItem.
    price_min: float | None = None
    price_max: float | None = None
    # Optional product reference code forwarded to SearchItem.magento_ref_code.
    magento_ref_code: str | None = None
    # Caller identity, echoed back in the response; defaults to "test_user".
    user_id: str | None = "test_user"
    # Optional session identifier, echoed back in the response.
    session_id: str | None = None
# --- MOCK LLM RESPONSES (no OpenAI call) ---
# Canned Vietnamese assistant replies. Not referenced by any code visible in
# this section; presumably consumed by the mock controller — verify.
MOCK_AI_RESPONSES = [
    "Dựa trên tìm kiếm của bạn, tôi tìm thấy các sản phẩm phù hợp với nhu cầu của bạn. Những mặt hàng này có chất lượng tốt và giá cả phải chăng.",
    "Tôi gợi ý cho bạn những sản phẩm sau. Chúng đều là những lựa chọn phổ biến và nhận được đánh giá cao từ khách hàng.",
    "Dựa trên tiêu chí tìm kiếm của bạn, đây là những sản phẩm tốt nhất mà tôi có thể giới thiệu.",
    "Những sản phẩm này hoàn toàn phù hợp với yêu cầu của bạn. Hãy xem chi tiết để chọn sản phẩm yêu thích nhất.",
    "Tôi đã tìm được các mặt hàng tuyệt vời cho bạn. Hãy kiểm tra chúng để tìm ra lựa chọn tốt nhất.",
]
# --- ENDPOINTS ---
from agent.mock_controller import mock_chat_controller
@router.post("/api/mock/agent/chat", summary="Mock Agent Chat (Real Tools + Fake LLM)")
async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks):
    """
    Mock Agent Chat using mock_chat_controller:
    - ✅ Real embedding + vector search (the real data_retrieval_tool)
    - ✅ Real products from StarRocks
    - ❌ Fake LLM response (no OpenAI cost)
    - Perfect for stress testing + end-to-end testing

    Returns the request echo (query, user_id, session_id) merged with the
    controller's result; keys in the controller result win on conflict.

    Raises:
        HTTPException: 500 on any unexpected failure; an HTTPException raised
            downstream is passed through unchanged.
    """
    try:
        logger.info(f"🚀 [Mock Agent Chat] Starting with query: {req.user_query}")
        result = await mock_chat_controller(
            query=req.user_query,
            user_id=req.user_id or "test_user",
            background_tasks=background_tasks,
        )
        return {
            "status": "success",
            "user_query": req.user_query,
            "user_id": req.user_id,
            "session_id": req.session_id,
            **result,  # Include status, ai_response, product_ids, etc.
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of rewrapping as 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in mock agent chat: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Mock Agent Chat Error: {e!s}") from e
@router.post("/api/mock/db/search", summary="Real Data Retrieval Tool (Agent Tool)")
async def mock_db_search(req: MockDBRequest):
    """
    Run the agent's real `data_retrieval_tool`:
    - If magento_ref_code is given → CODE SEARCH (no embedding needed)
    - If query is given → HYDE SEMANTIC SEARCH (embedding + vector search)
    - Filter by price when price_min/price_max are given
    - Return real products from StarRocks
    The input format mirrors the agent tool's SearchItem.

    Raises:
        HTTPException: 500 on any unexpected failure; an HTTPException raised
            downstream is passed through unchanged.
    """
    try:
        logger.info("📍 Data Retrieval Tool called")
        # perf_counter is monotonic, so the elapsed time cannot go negative
        # when the wall clock is adjusted.
        start_time = time.perf_counter()

        # Build the SearchItem from the request; fall back to a generic query.
        # NOTE(review): req.top_k is accepted but never forwarded — confirm.
        search_item = SearchItem(
            query=req.query or "sản phẩm",
            magento_ref_code=req.magento_ref_code,
            price_min=req.price_min,
            price_max=req.price_max,
            action="search",
        )
        search_params = search_item.dict(exclude_none=True)
        logger.info(f"🔧 Search params: {search_params}")

        # Call the real data_retrieval_tool, retrying on failure.
        result_json = await retry_with_backoff(
            lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}), max_retries=3
        )
        result = json.loads(result_json)

        elapsed_time = time.perf_counter() - start_time
        logger.info(f"✅ Data Retrieval completed in {elapsed_time:.3f}s")

        # A missing, null or empty "results" list yields no products instead of
        # an IndexError on [0].
        results = result.get("results") or [{}]
        products = results[0].get("products") or []

        return {
            "status": result.get("status", "success"),
            "search_params": search_params,
            "total_results": len(products),
            "products": products,
            "processing_time_ms": round(elapsed_time * 1000, 2),
            "raw_result": result,
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of rewrapping as 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in DB search: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"DB Search Error: {e!s}") from e
@router.post("/api/mock/retrieverdb", summary="Real Embedding + Real DB Vector Search")
@router.post("/api/mock/retriverdb", summary="Real Embedding + Real DB Vector Search (Legacy)")
async def mock_retriever_db(req: MockRetrieverRequest):
    """
    Real API for testing Retriever + DB Search (via the agent tool):
    - Take the query from the user
    - REAL embedding (the tool calls the OpenAI embedding API)
    - REAL vector search in StarRocks
    - Return the real product results (the LLM is skipped)
    Used to measure embedding + vector search performance in isolation.

    Registered under both the corrected path and the legacy misspelled path.

    Raises:
        HTTPException: 500 on any unexpected failure; an HTTPException raised
            downstream is passed through unchanged.
    """
    try:
        logger.info(f"📍 Retriever DB started: {req.user_query}")
        # perf_counter is monotonic, so the elapsed time cannot go negative
        # when the wall clock is adjusted.
        start_time = time.perf_counter()

        # Build the SearchItem from the request.
        search_item = SearchItem(
            query=req.user_query,
            magento_ref_code=req.magento_ref_code,
            price_min=req.price_min,
            price_max=req.price_max,
            action="search",
        )
        search_params = search_item.dict(exclude_none=True)
        logger.info(f"🔧 Retriever params: {search_params}")

        # Call the real data_retrieval_tool (embedding + vector search) with retry.
        result_json = await retry_with_backoff(
            lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}), max_retries=3
        )
        result = json.loads(result_json)

        elapsed_time = time.perf_counter() - start_time
        logger.info(f"✅ Retriever completed in {elapsed_time:.3f}s")

        # A missing, null or empty "results" list yields no products instead of
        # an IndexError on [0].
        results = result.get("results") or [{}]
        products = results[0].get("products") or []

        return {
            "status": result.get("status", "success"),
            "user_query": req.user_query,
            "user_id": req.user_id,
            "session_id": req.session_id,
            "search_params": search_params,
            "total_results": len(products),
            "products": products,
            "processing_time_ms": round(elapsed_time * 1000, 2),
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of rewrapping as 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in retriever DB: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Retriever DB Error: {e!s}") from e
"""
API thực tế để test Retriever + DB Search (dùng agent tool):
- Lấy query từ user
- Embedding THẬT (gọi OpenAI embedding trong tool)
- Vector search THẬT trong StarRocks
- Trả về kết quả sản phẩm thực (bỏ qua LLM)
Dùng để test performance của embedding + vector search riêng biệt.
"""
try:
logger.info(f"📍 Retriever DB started: {req.user_query}")
start_time = time.time()
# Xây dựng SearchItem từ request
search_item = SearchItem(
query=req.user_query,
magento_ref_code=req.magento_ref_code,
price_min=req.price_min,
price_max=req.price_max,
action="search",
)
logger.info(f"🔧 Retriever params: {search_item.dict(exclude_none=True)}")
# Gọi data_retrieval_tool THẬT (embedding + vector search) với retry
result_json = await retry_with_backoff(
lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}), max_retries=3
)
result = json.loads(result_json)
elapsed_time = time.time() - start_time
logger.info(f"✅ Retriever completed in {elapsed_time:.3f}s")
# Parse kết quả
search_results = result.get("results", [{}])[0]
products = search_results.get("products", [])
return {
"status": result.get("status", "success"),
"user_query": req.user_query,
"user_id": req.user_id,
"session_id": req.session_id,
"search_params": search_item.dict(exclude_none=True),
"total_results": len(products),
"products": products,
"processing_time_ms": round(elapsed_time * 1000, 2),
}
except Exception as e:
logger.error(f"❌ Error in retriever DB: {e!s}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Retriever DB Error: {e!s}")
"""
Middleware Module - Gom Auth + Rate Limit
Singleton Pattern cho cả 2 services
"""
from __future__ import annotations
import json
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING
from fastapi import HTTPException, Request, status
from starlette.middleware.base import BaseHTTPMiddleware
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION
# =============================================================================
# Public endpoints - no auth required (exact path match)
PUBLIC_PATHS = {
    "/",
    "/health",
    "/docs",
    "/openapi.json",
    "/redoc",
}
# Public path prefixes - any path starting with one of these skips auth
PUBLIC_PATH_PREFIXES = [
    "/static",
    "/mock",
]
# =============================================================================
# AUTH + RATE LIMIT MIDDLEWARE CLASS
# =============================================================================
# Paths that need rate limit check (exact path match)
RATE_LIMITED_PATHS = [
    "/api/agent/chat",
]
class CanifaAuthMiddleware(BaseHTTPMiddleware):
    """
    Canifa Authentication + Rate Limit Middleware

    Flow:
    1. Frontend sends the request with Authorization: Bearer <canifa_token>
    2. Middleware verifies the token with the Canifa API → extracts customer_id
    3. Check message rate limit (Guest: 10, User: 100)
    4. Attach user info to request.state
    5. Routes read it directly from request.state
    """

    # Scheme prefix expected at the start of the Authorization header.
    _BEARER_PREFIX = "Bearer "

    @staticmethod
    def _set_guest_state(request: Request, device_id: str) -> None:
        """Mark the request as an unauthenticated guest identified by ``device_id``."""
        request.state.user = None
        request.state.user_id = None
        request.state.is_authenticated = False
        request.state.device_id = device_id

    @staticmethod
    async def _read_device_id(request: Request, method: str) -> str:
        """Resolve the device id: JSON body field, then header, then client IP."""
        device_id = ""
        if method in ["POST", "PUT", "PATCH"]:
            try:
                body_bytes = await request.body()

                # Re-install a receive callable that replays the bytes just
                # read, so the body is still available downstream.
                async def receive_wrapper():
                    return {"type": "http.request", "body": body_bytes}

                request._receive = receive_wrapper

                if body_bytes:
                    try:
                        body_json = json.loads(body_bytes)
                    except json.JSONDecodeError:
                        # Non-JSON bodies are expected; fall back to header/IP.
                        body_json = None
                    # Only a JSON object can carry a device_id field.
                    if isinstance(body_json, dict):
                        device_id = body_json.get("device_id", "")
            except Exception as e:
                logger.warning(f"Error reading device_id from body: {e}")

        # Fallback: if not in the body, look in the header -> IP
        if not device_id:
            device_id = request.headers.get("device_id", "")
        if not device_id:
            device_id = f"unknown_{request.client.host}" if request.client else "unknown"
        return device_id

    @staticmethod
    async def _rate_limit_response(request: Request):
        """Return a 429 JSONResponse when the caller is over the limit, else None.

        Any error during the check is logged and treated as "allowed", so a
        failing limiter never blocks requests.
        """
        try:
            from common.message_limit import message_limit_service
            from fastapi.responses import JSONResponse

            # Rate limit key: guests are keyed by device_id, users by user_id
            is_authenticated = request.state.is_authenticated
            if is_authenticated and request.state.user_id:
                rate_limit_key = request.state.user_id
            else:
                rate_limit_key = request.state.device_id

            if not rate_limit_key:
                logger.warning("⚠️ No identity_key for rate limiting")
                return None

            can_send, limit_info = await message_limit_service.check_limit(
                identity_key=rate_limit_key,
                is_authenticated=is_authenticated,
            )
            # Store limit_info on request.state so routes can use it
            request.state.limit_info = limit_info
            if can_send:
                return None

            logger.warning(
                f"⚠️ Rate Limit Exceeded: {rate_limit_key} | "
                f"used={limit_info['used']}/{limit_info['limit']}"
            )
            return JSONResponse(
                status_code=429,
                content={
                    "status": "error",
                    "error_code": limit_info.get("error_code") or "MESSAGE_LIMIT_EXCEEDED",
                    "message": limit_info["message"],
                    "require_login": limit_info["require_login"],
                    "limit_info": {
                        "limit": limit_info["limit"],
                        "used": limit_info["used"],
                        "remaining": limit_info["remaining"],
                        "reset_seconds": limit_info["reset_seconds"],
                    },
                },
            )
        except Exception as e:
            logger.error(f"❌ Rate Limit Check Error: {e}")
            # Let the request continue if the rate limit check fails
            return None

    async def dispatch(self, request: Request, call_next: Callable):
        """Authenticate the request, apply the message rate limit, then forward it.

        Public paths and OPTIONS requests are forwarded untouched. Otherwise
        ``request.state`` receives ``user``, ``user_id``, ``is_authenticated``
        and ``device_id`` (plus ``token`` on a verified login) before the
        request is passed to ``call_next``. Returns a 429 JSON response when
        the path is rate limited and the caller is over the limit.
        """
        path = request.url.path
        method = request.method

        # ✅ Allow OPTIONS requests (CORS preflight)
        if method == "OPTIONS":
            return await call_next(request)

        # Skip public endpoints
        if path in PUBLIC_PATHS:
            return await call_next(request)

        # Skip public path prefixes
        if any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES):
            return await call_next(request)

        # =====================================================================
        # STEP 1: AUTHENTICATION (Canifa API)
        # =====================================================================
        dev_bypass = False
        try:
            auth_header = request.headers.get("Authorization")
            device_id = await self._read_device_id(request, method)

            # ========== DEV MODE: Bypass auth ==========
            # SECURITY NOTE(review): this header is honored unconditionally in
            # the code visible here, letting any client act as any user and
            # skip the rate limit — confirm it is stripped/disabled in production.
            dev_user_id = request.headers.get("X-Dev-User-Id")
            if dev_user_id:
                logger.warning(f"⚠️ DEV MODE: Using X-Dev-User-Id={dev_user_id}")
                request.state.user = {"customer_id": dev_user_id}
                request.state.user_id = dev_user_id
                request.state.is_authenticated = True
                request.state.device_id = device_id or dev_user_id
                dev_bypass = True
            elif not auth_header or not auth_header.startswith(self._BEARER_PREFIX):
                # --- CASE 1: NO TOKEN -> GUEST ---
                self._set_guest_state(request, device_id)
            else:
                # --- CASE 2: TOKEN PRESENT -> VERIFY WITH CANIFA ---
                # Strip only the leading scheme; the token itself is untouched.
                token = auth_header[len(self._BEARER_PREFIX):]
                from common.canifa_api import verify_canifa_token, extract_user_id_from_canifa_response

                try:
                    user_data = await verify_canifa_token(token)
                    user_id = await extract_user_id_from_canifa_response(user_data)
                    if user_id:
                        request.state.user = user_data
                        request.state.user_id = user_id
                        request.state.token = token
                        request.state.is_authenticated = True
                        request.state.device_id = device_id
                        logger.debug(f"✅ Canifa Auth Success: User {user_id}")
                    else:
                        logger.warning("⚠️ Invalid Canifa Token -> Guest Mode")
                        self._set_guest_state(request, device_id)
                except Exception as e:
                    logger.error(f"❌ Canifa Auth Error: {e} -> Guest Mode")
                    self._set_guest_state(request, device_id)
        except Exception as e:
            logger.error(f"❌ Middleware Auth Error: {e}")
            self._set_guest_state(request, "")

        # The downstream call is made outside the auth try-block so an error
        # raised by the route is not mistaken for an auth failure and the
        # request is not dispatched a second time.
        if dev_bypass:
            return await call_next(request)

        # =====================================================================
        # STEP 2: RATE LIMIT CHECK (only for the paths that need a limit)
        # =====================================================================
        if path in RATE_LIMITED_PATHS:
            blocked = await self._rate_limit_response(request)
            if blocked is not None:
                return blocked

        return await call_next(request)
# =============================================================================
# MIDDLEWARE MANAGER - Singleton to manage all middlewares
# =============================================================================
class MiddlewareManager:
    """
    Middleware Manager - Singleton Pattern
    Manages and installs all middlewares for the FastAPI app.

    Usage:
        from common.middleware import middleware_manager
        app = FastAPI()
        middleware_manager.setup(app, enable_auth=True, enable_rate_limit=True)
    """

    # Shared instance returned by every construction.
    _instance: MiddlewareManager | None = None
    # Set after the first __init__ so later constructions keep existing state.
    _initialized: bool = False

    def __new__(cls) -> MiddlewareManager:
        """Create the single instance on first use and return it afterwards."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Initialize the enabled-flags once; later calls are no-ops."""
        if MiddlewareManager._initialized:
            return
        self._auth_enabled = False
        self._rate_limit_enabled = False
        MiddlewareManager._initialized = True
        logger.info("✅ MiddlewareManager initialized")

    def setup(
        self,
        app: FastAPI,
        *,
        enable_auth: bool = True,
        enable_rate_limit: bool = True,
        enable_cors: bool = True,
        cors_origins: list[str] | None = None,
    ) -> None:
        """
        Install all middlewares on the FastAPI app.

        Args:
            app: FastAPI application
            enable_auth: Enable the Canifa authentication middleware
            enable_rate_limit: Enable rate limiting
            enable_cors: Enable the CORS middleware
            cors_origins: Origins allowed by CORS (default: ["*"])

        Note:
            Middleware order matters! A middleware added later runs FIRST.
            Runtime order: CORS → Auth → RateLimit (SlowAPI)
        """
        # Registration goes innermost → outermost so the runtime order matches
        # the documented one. Registering CORS first (as before) made it the
        # innermost layer, so responses produced by the auth middleware itself
        # never passed through it.
        # NOTE(review): assumes rate_limit_service.setup() registers its
        # middleware via app.add_middleware — confirm.

        # 1. Rate Limit middleware (added first → runs last)
        if enable_rate_limit:
            self._setup_rate_limit(app)

        # 2. Auth middleware
        if enable_auth:
            self._setup_auth(app)

        # 3. CORS middleware (added last → runs first)
        if enable_cors:
            self._setup_cors(app, cors_origins or ["*"])

        logger.info(
            f"✅ Middlewares configured: "
            f"CORS={enable_cors}, Auth={enable_auth}, RateLimit={enable_rate_limit}"
        )

    def _setup_cors(self, app: FastAPI, origins: list[str]) -> None:
        """Setup CORS middleware."""
        from fastapi.middleware.cors import CORSMiddleware

        # NOTE(review): credentials are allowed together with the default
        # wildcard origin — confirm this is intended outside development.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        logger.info(f"✅ CORS middleware enabled (origins: {origins})")

    def _setup_auth(self, app: FastAPI) -> None:
        """Setup Canifa auth middleware."""
        app.add_middleware(CanifaAuthMiddleware)
        self._auth_enabled = True
        logger.info("✅ Canifa Auth middleware enabled")

    def _setup_rate_limit(self, app: FastAPI) -> None:
        """Setup rate limiting."""
        from common.rate_limit import rate_limit_service

        rate_limit_service.setup(app)
        self._rate_limit_enabled = True
        logger.info("✅ Rate Limit middleware enabled")

    @property
    def is_auth_enabled(self) -> bool:
        """Whether the auth middleware has been installed via setup()."""
        return self._auth_enabled

    @property
    def is_rate_limit_enabled(self) -> bool:
        """Whether rate limiting has been installed via setup()."""
        return self._rate_limit_enabled
# =============================================================================
# SINGLETON INSTANCE
# =============================================================================
# Module-level shared instance; constructing MiddlewareManager again returns
# this same object.
middleware_manager = MiddlewareManager()
"""
Middleware Module - Gom Auth + Rate Limit
Singleton Pattern cho cả 2 services
"""
from __future__ import annotations
import json
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING
from fastapi import HTTPException, Request, status
from starlette.middleware.base import BaseHTTPMiddleware
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION
# =============================================================================
# Public endpoints - no auth required (exact path match)
PUBLIC_PATHS = {
    "/",
    "/health",
    "/docs",
    "/openapi.json",
    "/redoc",
}
# Public path prefixes - any path starting with one of these skips auth
PUBLIC_PATH_PREFIXES = [
    "/static",
    "/mock",
    "/api/mock",
]
# =============================================================================
# AUTH + RATE LIMIT MIDDLEWARE CLASS
# =============================================================================
# Paths that need rate limit check (exact path match)
RATE_LIMITED_PATHS = [
    "/api/agent/chat",
]
class CanifaAuthMiddleware(BaseHTTPMiddleware):
    """
    Canifa Authentication + Rate Limit Middleware

    Flow:
    1. Frontend sends the request with Authorization: Bearer <canifa_token>
    2. Middleware verifies the token with the Canifa API → extracts customer_id
    3. Check message rate limit (Guest: 10, User: 100)
    4. Attach user info to request.state
    5. Routes read it directly from request.state
    """

    # Scheme prefix expected at the start of the Authorization header.
    _BEARER_PREFIX = "Bearer "

    @staticmethod
    def _set_guest_state(request: Request, device_id: str) -> None:
        """Mark the request as an unauthenticated guest identified by ``device_id``."""
        request.state.user = None
        request.state.user_id = None
        request.state.is_authenticated = False
        request.state.device_id = device_id

    @staticmethod
    async def _read_device_id(request: Request, method: str) -> str:
        """Resolve the device id: JSON body field, then header, then client IP."""
        device_id = ""
        if method in ["POST", "PUT", "PATCH"]:
            try:
                body_bytes = await request.body()

                # Re-install a receive callable that replays the bytes just
                # read, so the body is still available downstream.
                async def receive_wrapper():
                    return {"type": "http.request", "body": body_bytes}

                request._receive = receive_wrapper

                if body_bytes:
                    try:
                        body_json = json.loads(body_bytes)
                    except json.JSONDecodeError:
                        # Non-JSON bodies are expected; fall back to header/IP.
                        body_json = None
                    # Only a JSON object can carry a device_id field.
                    if isinstance(body_json, dict):
                        device_id = body_json.get("device_id", "")
            except Exception as e:
                logger.warning(f"Error reading device_id from body: {e}")

        # Fallback: if not in the body, look in the header -> IP
        if not device_id:
            device_id = request.headers.get("device_id", "")
        if not device_id:
            device_id = f"unknown_{request.client.host}" if request.client else "unknown"
        return device_id

    @staticmethod
    async def _rate_limit_response(request: Request):
        """Return a 429 JSONResponse when the caller is over the limit, else None.

        Any error during the check is logged and treated as "allowed", so a
        failing limiter never blocks requests.
        """
        try:
            from common.message_limit import message_limit_service
            from fastapi.responses import JSONResponse

            # Rate limit key: guests are keyed by device_id, users by user_id
            is_authenticated = request.state.is_authenticated
            if is_authenticated and request.state.user_id:
                rate_limit_key = request.state.user_id
            else:
                rate_limit_key = request.state.device_id

            if not rate_limit_key:
                logger.warning("⚠️ No identity_key for rate limiting")
                return None

            can_send, limit_info = await message_limit_service.check_limit(
                identity_key=rate_limit_key,
                is_authenticated=is_authenticated,
            )
            # Store limit_info on request.state so routes can use it
            request.state.limit_info = limit_info
            if can_send:
                return None

            logger.warning(
                f"⚠️ Rate Limit Exceeded: {rate_limit_key} | "
                f"used={limit_info['used']}/{limit_info['limit']}"
            )
            return JSONResponse(
                status_code=429,
                content={
                    "status": "error",
                    "error_code": limit_info.get("error_code") or "MESSAGE_LIMIT_EXCEEDED",
                    "message": limit_info["message"],
                    "require_login": limit_info["require_login"],
                    "limit_info": {
                        "limit": limit_info["limit"],
                        "used": limit_info["used"],
                        "remaining": limit_info["remaining"],
                        "reset_seconds": limit_info["reset_seconds"],
                    },
                },
            )
        except Exception as e:
            logger.error(f"❌ Rate Limit Check Error: {e}")
            # Let the request continue if the rate limit check fails
            return None

    async def dispatch(self, request: Request, call_next: Callable):
        """Authenticate the request, apply the message rate limit, then forward it.

        Public paths and OPTIONS requests are forwarded untouched. Otherwise
        ``request.state`` receives ``user``, ``user_id``, ``is_authenticated``
        and ``device_id`` (plus ``token`` on a verified login) before the
        request is passed to ``call_next``. Returns a 429 JSON response when
        the path is rate limited and the caller is over the limit.
        """
        path = request.url.path
        method = request.method

        # ✅ Allow OPTIONS requests (CORS preflight)
        if method == "OPTIONS":
            return await call_next(request)

        # Skip public endpoints
        if path in PUBLIC_PATHS:
            return await call_next(request)

        # Skip public path prefixes
        if any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES):
            return await call_next(request)

        # =====================================================================
        # STEP 1: AUTHENTICATION (Canifa API)
        # =====================================================================
        dev_bypass = False
        try:
            auth_header = request.headers.get("Authorization")
            device_id = await self._read_device_id(request, method)

            # ========== DEV MODE: Bypass auth ==========
            # SECURITY NOTE(review): this header is honored unconditionally in
            # the code visible here, letting any client act as any user and
            # skip the rate limit — confirm it is stripped/disabled in production.
            dev_user_id = request.headers.get("X-Dev-User-Id")
            if dev_user_id:
                logger.warning(f"⚠️ DEV MODE: Using X-Dev-User-Id={dev_user_id}")
                request.state.user = {"customer_id": dev_user_id}
                request.state.user_id = dev_user_id
                request.state.is_authenticated = True
                request.state.device_id = device_id or dev_user_id
                dev_bypass = True
            elif not auth_header or not auth_header.startswith(self._BEARER_PREFIX):
                # --- CASE 1: NO TOKEN -> GUEST ---
                self._set_guest_state(request, device_id)
            else:
                # --- CASE 2: TOKEN PRESENT -> VERIFY WITH CANIFA ---
                # Strip only the leading scheme; the token itself is untouched.
                token = auth_header[len(self._BEARER_PREFIX):]
                from common.canifa_api import verify_canifa_token, extract_user_id_from_canifa_response

                try:
                    user_data = await verify_canifa_token(token)
                    user_id = await extract_user_id_from_canifa_response(user_data)
                    if user_id:
                        request.state.user = user_data
                        request.state.user_id = user_id
                        request.state.token = token
                        request.state.is_authenticated = True
                        request.state.device_id = device_id
                        logger.debug(f"✅ Canifa Auth Success: User {user_id}")
                    else:
                        logger.warning("⚠️ Invalid Canifa Token -> Guest Mode")
                        self._set_guest_state(request, device_id)
                except Exception as e:
                    logger.error(f"❌ Canifa Auth Error: {e} -> Guest Mode")
                    self._set_guest_state(request, device_id)
        except Exception as e:
            logger.error(f"❌ Middleware Auth Error: {e}")
            self._set_guest_state(request, "")

        # The downstream call is made outside the auth try-block so an error
        # raised by the route is not mistaken for an auth failure and the
        # request is not dispatched a second time.
        if dev_bypass:
            return await call_next(request)

        # =====================================================================
        # STEP 2: RATE LIMIT CHECK (only for the paths that need a limit)
        # =====================================================================
        if path in RATE_LIMITED_PATHS:
            blocked = await self._rate_limit_response(request)
            if blocked is not None:
                return blocked

        return await call_next(request)
# =============================================================================
# MIDDLEWARE MANAGER - Singleton to manage all middlewares
# =============================================================================
class MiddlewareManager:
    """
    Middleware Manager - Singleton Pattern
    Manages and installs all middlewares for the FastAPI app.

    Usage:
        from common.middleware import middleware_manager
        app = FastAPI()
        middleware_manager.setup(app, enable_auth=True, enable_rate_limit=True)
    """

    # Shared instance returned by every construction.
    _instance: MiddlewareManager | None = None
    # Set after the first __init__ so later constructions keep existing state.
    _initialized: bool = False

    def __new__(cls) -> MiddlewareManager:
        """Create the single instance on first use and return it afterwards."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Initialize the enabled-flags once; later calls are no-ops."""
        if MiddlewareManager._initialized:
            return
        self._auth_enabled = False
        self._rate_limit_enabled = False
        MiddlewareManager._initialized = True
        logger.info("✅ MiddlewareManager initialized")

    def setup(
        self,
        app: FastAPI,
        *,
        enable_auth: bool = True,
        enable_rate_limit: bool = True,
        enable_cors: bool = True,
        cors_origins: list[str] | None = None,
    ) -> None:
        """
        Install all middlewares on the FastAPI app.

        Args:
            app: FastAPI application
            enable_auth: Enable the Canifa authentication middleware
            enable_rate_limit: Enable rate limiting
            enable_cors: Enable the CORS middleware
            cors_origins: Origins allowed by CORS (default: ["*"])

        Note:
            Middleware order matters! A middleware added later runs FIRST.
            Runtime order: CORS → Auth → RateLimit (SlowAPI)
        """
        # Registration goes innermost → outermost so the runtime order matches
        # the documented one. Registering CORS first (as before) made it the
        # innermost layer, so responses produced by the auth middleware itself
        # never passed through it.
        # NOTE(review): assumes rate_limit_service.setup() registers its
        # middleware via app.add_middleware — confirm.

        # 1. Rate Limit middleware (added first → runs last)
        if enable_rate_limit:
            self._setup_rate_limit(app)

        # 2. Auth middleware
        if enable_auth:
            self._setup_auth(app)

        # 3. CORS middleware (added last → runs first)
        if enable_cors:
            self._setup_cors(app, cors_origins or ["*"])

        logger.info(
            f"✅ Middlewares configured: "
            f"CORS={enable_cors}, Auth={enable_auth}, RateLimit={enable_rate_limit}"
        )

    def _setup_cors(self, app: FastAPI, origins: list[str]) -> None:
        """Setup CORS middleware."""
        from fastapi.middleware.cors import CORSMiddleware

        # NOTE(review): credentials are allowed together with the default
        # wildcard origin — confirm this is intended outside development.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        logger.info(f"✅ CORS middleware enabled (origins: {origins})")

    def _setup_auth(self, app: FastAPI) -> None:
        """Setup Canifa auth middleware."""
        app.add_middleware(CanifaAuthMiddleware)
        self._auth_enabled = True
        logger.info("✅ Canifa Auth middleware enabled")

    def _setup_rate_limit(self, app: FastAPI) -> None:
        """Setup rate limiting."""
        from common.rate_limit import rate_limit_service

        rate_limit_service.setup(app)
        self._rate_limit_enabled = True
        logger.info("✅ Rate Limit middleware enabled")

    @property
    def is_auth_enabled(self) -> bool:
        """Whether the auth middleware has been installed via setup()."""
        return self._auth_enabled

    @property
    def is_rate_limit_enabled(self) -> bool:
        """Whether rate limiting has been installed via setup()."""
        return self._rate_limit_enabled
# =============================================================================
# SINGLETON INSTANCE
# =============================================================================
# Module-level shared instance; constructing MiddlewareManager again returns
# this same object.
middleware_manager = MiddlewareManager()
"""
Rate Limiting Service - Singleton Pattern
Sử dụng SlowAPI với Redis backend (production) hoặc Memory (dev)
"""
from __future__ import annotations
import logging
import os
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from fastapi import Request
from fastapi.responses import JSONResponse
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
class RateLimitService:
    """
    Rate limiting service (singleton).

    Combines a SlowAPI ``Limiter`` with an in-memory blocklist: when a client
    exceeds a limit, the exception handler registered by ``setup`` blocks that
    client for ``block_duration_minutes``; the block middleware then answers
    its requests with HTTP 429 until the block expires.

    Usage:
        # In server.py
        from common.rate_limit import RateLimitService
        rate_limiter = RateLimitService()
        rate_limiter.setup(app)

        # In a route
        from common.rate_limit import RateLimitService

        @router.post("/chat")
        @RateLimitService().limiter.limit("10/minute")
        async def chat(request: Request):
            ...
    """

    _instance: RateLimitService | None = None
    _initialized: bool = False

    # =========================================================================
    # SINGLETON PATTERN
    # =========================================================================
    def __new__(cls) -> RateLimitService:
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Read configuration and build the limiter (first construction only)."""
        # __init__ runs on every RateLimitService() call; initialise only once.
        if RateLimitService._initialized:
            return

        # Configuration
        self.storage_uri = os.getenv("RATE_STORAGE_URI", "memory://")
        self.default_limits = ["100/hour", "30/minute"]
        self.block_duration_minutes = int(os.getenv("RATE_LIMIT_BLOCK_MINUTES", "5"))

        # Paths that are exempt from the blocklist check
        self.exempt_paths = {
            "/",
            "/health",
            "/docs",
            "/openapi.json",
            "/redoc",
        }
        self.exempt_prefixes = ["/static", "/mock"]

        # In-memory blocklist: client key -> naive UTC expiry time
        # (could be moved to Redis)
        self._blocklist: dict[str, datetime] = {}

        # Create limiter instance
        self.limiter = Limiter(
            key_func=self._get_client_identifier,
            storage_uri=self.storage_uri,
            default_limits=self.default_limits,
        )

        RateLimitService._initialized = True
        logger.info(f"✅ RateLimitService initialized (storage: {self.storage_uri})")

    # =========================================================================
    # CLIENT IDENTIFIER
    # =========================================================================
    @staticmethod
    def _get_client_identifier(request: Request) -> str:
        """
        Build the client key used for rate limiting.

        Priority: user_id (authenticated) > device_id header > IP address.

        Returns:
            ``"user:<id>"``, ``"device:<id>"``, ``"ip:<addr>"`` or ``"unknown"``.
        """
        # 1. Authenticated request -> use the user_id stored on request.state
        user_id = getattr(request.state, "user_id", None)
        if user_id:
            return f"user:{user_id}"

        # 2. A device_id header is present -> use it
        device_id = request.headers.get("device_id")
        if device_id:
            return f"device:{device_id}"

        # 3. Fallback -> IP address
        try:
            return f"ip:{get_remote_address(request)}"
        except Exception:
            if request.client:
                return f"ip:{request.client.host}"
            return "unknown"

    # =========================================================================
    # BLOCKLIST MANAGEMENT
    # =========================================================================
    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value as the deprecated ``datetime.utcnow()``, so entries already
        stored in the blocklist stay comparable.
        """
        # Local import: the module header only imports datetime and timedelta.
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def is_blocked(self, key: str) -> tuple[bool, int]:
        """
        Check whether a client is currently blocked.

        Expired entries are removed as a side effect.

        Returns:
            ``(is_blocked, retry_after_seconds)``; ``(False, 0)`` when the
            client is not blocked.
        """
        blocked_until = self._blocklist.get(key)
        if not blocked_until:
            return False, 0

        remaining = blocked_until - self._utcnow()
        if remaining <= timedelta(0):
            # Block expired
            self._blocklist.pop(key, None)
            return False, 0

        # Round up to whole seconds (ceiling division) so a client that is
        # still blocked is never told to retry after 0 seconds.
        retry_after = -(-remaining // timedelta(seconds=1))
        return True, retry_after

    def block_client(self, key: str) -> int:
        """
        Block a client for the configured duration.

        Returns:
            The block duration in seconds.
        """
        self._blocklist[key] = self._utcnow() + timedelta(minutes=self.block_duration_minutes)
        return self.block_duration_minutes * 60

    def unblock_client(self, key: str) -> None:
        """Remove a client from the blocklist; unknown keys are ignored."""
        self._blocklist.pop(key, None)

    # =========================================================================
    # PATH CHECKING
    # =========================================================================
    def is_exempt(self, path: str) -> bool:
        """Return True when ``path`` is an exempt path or starts with an exempt prefix."""
        if path in self.exempt_paths:
            return True
        # NOTE(review): the prefix match is purely textual, so "/mock" also
        # matches e.g. "/mockup" - confirm that this is intended.
        return path.startswith(tuple(self.exempt_prefixes))

    # =========================================================================
    # SETUP FOR FASTAPI APP
    # =========================================================================
    def setup(self, app: FastAPI) -> None:
        """
        Install rate limiting on a FastAPI app.

        Call this in server.py after the app has been created.
        """
        # Attach limiter to app state
        app.state.limiter = self.limiter
        app.state.rate_limit_service = self

        # Register middleware
        self._register_block_middleware(app)
        self._register_exception_handler(app)

        # Add SlowAPI middleware (MUST be added AFTER the custom middlewares)
        app.add_middleware(SlowAPIMiddleware)
        logger.info("✅ Rate limiting middleware registered")

    def _register_block_middleware(self, app: FastAPI) -> None:
        """Register the HTTP middleware that rejects blocklisted clients."""

        @app.middleware("http")
        async def rate_limit_block_middleware(request: Request, call_next):
            path = request.url.path

            # Skip exempt paths
            if self.is_exempt(path):
                return await call_next(request)

            # Bypass header for testing.
            # NOTE(review): any client can send this header and skip the
            # blocklist check - consider disabling it outside test environments.
            if request.headers.get("X-Bypass-RateLimit") == "1":
                return await call_next(request)

            # Check blocklist
            key = self._get_client_identifier(request)
            is_blocked, retry_after = self.is_blocked(key)
            if is_blocked:
                return JSONResponse(
                    status_code=429,
                    content={
                        "detail": "Quá số lượt cho phép. Vui lòng thử lại sau.",
                        "retry_after_seconds": retry_after,
                    },
                    headers={"Retry-After": str(retry_after)},
                )
            return await call_next(request)

    def _register_exception_handler(self, app: FastAPI) -> None:
        """Register the handler that blocks a client once it exceeds a limit."""

        @app.exception_handler(RateLimitExceeded)
        async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
            key = self._get_client_identifier(request)
            retry_after = self.block_client(key)
            logger.warning(f"⚠️ Rate limit exceeded for {key}, blocked for {self.block_duration_minutes} minutes")
            return JSONResponse(
                status_code=429,
                content={
                    "detail": "Quá số lượt cho phép. Vui lòng thử lại sau.",
                    "retry_after_seconds": retry_after,
                },
                headers={"Retry-After": str(retry_after)},
            )
# =============================================================================
# SINGLETON INSTANCE - import this object directly to use it
# =============================================================================
rate_limit_service = RateLimitService()
"""
Rate Limiting Service - Singleton Pattern
Sử dụng SlowAPI với Redis backend (production) hoặc Memory (dev)
"""
from __future__ import annotations
import logging
import os
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from fastapi import Request
from fastapi.responses import JSONResponse
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
class RateLimitService:
    """
    Rate limiting service (singleton).

    Combines a SlowAPI ``Limiter`` with an in-memory blocklist: when a client
    exceeds a limit, the exception handler registered by ``setup`` blocks that
    client for ``block_duration_minutes``; the block middleware then answers
    its requests with HTTP 429 until the block expires.

    Usage:
        # In server.py
        from common.rate_limit import RateLimitService
        rate_limiter = RateLimitService()
        rate_limiter.setup(app)

        # In a route
        from common.rate_limit import RateLimitService

        @router.post("/chat")
        @RateLimitService().limiter.limit("10/minute")
        async def chat(request: Request):
            ...
    """

    _instance: RateLimitService | None = None
    _initialized: bool = False

    # =========================================================================
    # SINGLETON PATTERN
    # =========================================================================
    def __new__(cls) -> RateLimitService:
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Read configuration and build the limiter (first construction only)."""
        # __init__ runs on every RateLimitService() call; initialise only once.
        if RateLimitService._initialized:
            return

        # Configuration
        self.storage_uri = os.getenv("RATE_STORAGE_URI", "memory://")
        self.default_limits = ["100/hour", "30/minute"]
        self.block_duration_minutes = int(os.getenv("RATE_LIMIT_BLOCK_MINUTES", "5"))

        # Paths that are exempt from the blocklist check
        self.exempt_paths = {
            "/",
            "/health",
            "/docs",
            "/openapi.json",
            "/redoc",
        }
        self.exempt_prefixes = ["/static", "/mock", "/api/mock"]

        # In-memory blocklist: client key -> naive UTC expiry time
        # (could be moved to Redis)
        self._blocklist: dict[str, datetime] = {}

        # Create limiter instance
        self.limiter = Limiter(
            key_func=self._get_client_identifier,
            storage_uri=self.storage_uri,
            default_limits=self.default_limits,
        )

        RateLimitService._initialized = True
        logger.info(f"✅ RateLimitService initialized (storage: {self.storage_uri})")

    # =========================================================================
    # CLIENT IDENTIFIER
    # =========================================================================
    @staticmethod
    def _get_client_identifier(request: Request) -> str:
        """
        Build the client key used for rate limiting.

        Priority: user_id (authenticated) > device_id header > IP address.

        Returns:
            ``"user:<id>"``, ``"device:<id>"``, ``"ip:<addr>"`` or ``"unknown"``.
        """
        # 1. Authenticated request -> use the user_id stored on request.state
        user_id = getattr(request.state, "user_id", None)
        if user_id:
            return f"user:{user_id}"

        # 2. A device_id header is present -> use it
        device_id = request.headers.get("device_id")
        if device_id:
            return f"device:{device_id}"

        # 3. Fallback -> IP address
        try:
            return f"ip:{get_remote_address(request)}"
        except Exception:
            if request.client:
                return f"ip:{request.client.host}"
            return "unknown"

    # =========================================================================
    # BLOCKLIST MANAGEMENT
    # =========================================================================
    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value as the deprecated ``datetime.utcnow()``, so entries already
        stored in the blocklist stay comparable.
        """
        # Local import: the module header only imports datetime and timedelta.
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def is_blocked(self, key: str) -> tuple[bool, int]:
        """
        Check whether a client is currently blocked.

        Expired entries are removed as a side effect.

        Returns:
            ``(is_blocked, retry_after_seconds)``; ``(False, 0)`` when the
            client is not blocked.
        """
        blocked_until = self._blocklist.get(key)
        if not blocked_until:
            return False, 0

        remaining = blocked_until - self._utcnow()
        if remaining <= timedelta(0):
            # Block expired
            self._blocklist.pop(key, None)
            return False, 0

        # Round up to whole seconds (ceiling division) so a client that is
        # still blocked is never told to retry after 0 seconds.
        retry_after = -(-remaining // timedelta(seconds=1))
        return True, retry_after

    def block_client(self, key: str) -> int:
        """
        Block a client for the configured duration.

        Returns:
            The block duration in seconds.
        """
        self._blocklist[key] = self._utcnow() + timedelta(minutes=self.block_duration_minutes)
        return self.block_duration_minutes * 60

    def unblock_client(self, key: str) -> None:
        """Remove a client from the blocklist; unknown keys are ignored."""
        self._blocklist.pop(key, None)

    # =========================================================================
    # PATH CHECKING
    # =========================================================================
    def is_exempt(self, path: str) -> bool:
        """Return True when ``path`` is an exempt path or starts with an exempt prefix."""
        if path in self.exempt_paths:
            return True
        # NOTE(review): the prefix match is purely textual, so "/mock" also
        # matches e.g. "/mockup" - confirm that this is intended.
        return path.startswith(tuple(self.exempt_prefixes))

    # =========================================================================
    # SETUP FOR FASTAPI APP
    # =========================================================================
    def setup(self, app: FastAPI) -> None:
        """
        Install rate limiting on a FastAPI app.

        Call this in server.py after the app has been created.
        """
        # Attach limiter to app state
        app.state.limiter = self.limiter
        app.state.rate_limit_service = self

        # Register middleware
        self._register_block_middleware(app)
        self._register_exception_handler(app)

        # Add SlowAPI middleware (MUST be added AFTER the custom middlewares)
        app.add_middleware(SlowAPIMiddleware)
        logger.info("✅ Rate limiting middleware registered")

    def _register_block_middleware(self, app: FastAPI) -> None:
        """Register the HTTP middleware that rejects blocklisted clients."""

        @app.middleware("http")
        async def rate_limit_block_middleware(request: Request, call_next):
            path = request.url.path

            # Skip exempt paths
            if self.is_exempt(path):
                return await call_next(request)

            # Bypass header for testing.
            # NOTE(review): any client can send this header and skip the
            # blocklist check - consider disabling it outside test environments.
            if request.headers.get("X-Bypass-RateLimit") == "1":
                return await call_next(request)

            # Check blocklist
            key = self._get_client_identifier(request)
            is_blocked, retry_after = self.is_blocked(key)
            if is_blocked:
                return JSONResponse(
                    status_code=429,
                    content={
                        "detail": "Quá số lượt cho phép. Vui lòng thử lại sau.",
                        "retry_after_seconds": retry_after,
                    },
                    headers={"Retry-After": str(retry_after)},
                )
            return await call_next(request)

    def _register_exception_handler(self, app: FastAPI) -> None:
        """Register the handler that blocks a client once it exceeds a limit."""

        @app.exception_handler(RateLimitExceeded)
        async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
            key = self._get_client_identifier(request)
            retry_after = self.block_client(key)
            logger.warning(f"⚠️ Rate limit exceeded for {key}, blocked for {self.block_duration_minutes} minutes")
            return JSONResponse(
                status_code=429,
                content={
                    "detail": "Quá số lượt cho phép. Vui lòng thử lại sau.",
                    "retry_after_seconds": retry_after,
                },
                headers={"Retry-After": str(retry_after)},
            )
# =============================================================================
# SINGLETON INSTANCE - import this object directly to use it
# =============================================================================
rate_limit_service = RateLimitService()
import asyncio
import os
import platform
# On Windows, switch asyncio to the selector-based event loop policy. This runs
# before the imports that follow it, so the policy is in place when they load.
# NOTE(review): presumably needed by an async dependency that does not work with
# the default Windows event loop - confirm which one.
if platform.system() == "Windows":
    print("🔧 Windows detected: Applying SelectorEventLoopPolicy globally...")
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import logging
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from api.chatbot_route import router as chatbot_router
from api.conservation_route import router as conservation_router
from api.prompt_route import router as prompt_router
from common.cache import redis_cache
from common.langfuse_client import get_langfuse_client
from common.middleware import middleware_manager
from config import PORT
# Configure logging: INFO level, timestamped records, emitted through a StreamHandler
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Fetch the Langfuse client once at import time. A falsy result is only logged
# as a warning; the module keeps loading either way.
langfuse_client = get_langfuse_client()
if langfuse_client:
    logger.info("✅ Langfuse client ready (lazy loading)")
else:
    logger.warning("⚠️ Langfuse client not available (missing keys or disabled)")

# The application object that the startup hook, middleware and routers below attach to
app = FastAPI(
    title="Contract AI Service",
    description="API for Contract AI Service",
    version="1.0.0",
)
# =============================================================================
# STARTUP EVENT - Initialize Redis Cache
# =============================================================================
@app.on_event("startup")
async def startup_event():
    """Startup hook: bring up the shared Redis cache, then log that it is ready."""
    # NOTE(review): recent FastAPI releases deprecate on_event in favour of
    # lifespan handlers - consider migrating when the app factory is touched.
    await redis_cache.initialize()
    logger.info("✅ Redis cache initialized for message limit")
# =============================================================================
# MIDDLEWARE SETUP - Auth + RateLimit + CORS gathered in one place
# =============================================================================
middleware_manager.setup(
    app,
    enable_auth=True, # 👈 Auth re-enabled to test the Guest/User logic
    enable_rate_limit=True, # 👈 SlowAPI re-enabled as requested
    enable_cors=True, # 👈 Enable CORS
    cors_origins=["*"], # 👈 Origins should be restricted in production
)
# Register the API routers on the app (after the middleware setup above)
app.include_router(conservation_router)
app.include_router(chatbot_router)
app.include_router(prompt_router)
# --- MOCK API FOR LOAD TESTING ---
try:
from api.mock_api_route import router as mock_router
import asyncio
import os
import platform
# On Windows, switch asyncio to the selector-based event loop policy. This runs
# before the imports that follow it, so the policy is in place when they load.
# NOTE(review): presumably needed by an async dependency that does not work with
# the default Windows event loop - confirm which one.
if platform.system() == "Windows":
    print("🔧 Windows detected: Applying SelectorEventLoopPolicy globally...")
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import logging
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from api.chatbot_route import router as chatbot_router
from api.conservation_route import router as conservation_router
from api.prompt_route import router as prompt_router
from common.cache import redis_cache
from common.langfuse_client import get_langfuse_client
from common.middleware import middleware_manager
from config import PORT
# Configure logging: INFO level, timestamped records, emitted through a StreamHandler
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Fetch the Langfuse client once at import time. A falsy result is only logged
# as a warning; the module keeps loading either way.
langfuse_client = get_langfuse_client()
if langfuse_client:
    logger.info("✅ Langfuse client ready (lazy loading)")
else:
    logger.warning("⚠️ Langfuse client not available (missing keys or disabled)")

# The application object that the startup hook, middleware and routers below attach to
app = FastAPI(
    title="Contract AI Service",
    description="API for Contract AI Service",
    version="1.0.0",
)
# =============================================================================
# STARTUP EVENT - Initialize Redis Cache
# =============================================================================
@app.on_event("startup")
async def startup_event():
    """Startup hook: bring up the shared Redis cache, then log that it is ready."""
    # NOTE(review): recent FastAPI releases deprecate on_event in favour of
    # lifespan handlers - consider migrating when the app factory is touched.
    await redis_cache.initialize()
    logger.info("✅ Redis cache initialized for message limit")
# =============================================================================
# MIDDLEWARE SETUP - Auth + RateLimit + CORS gathered in one place
# =============================================================================
middleware_manager.setup(
    app,
    enable_auth=True, # 👈 Auth re-enabled to test the Guest/User logic
    enable_rate_limit=True, # 👈 SlowAPI re-enabled as requested
    enable_cors=True, # 👈 Enable CORS
    cors_origins=["*"], # 👈 Origins should be restricted in production
)
# Register the API routers on the app (after the middleware setup above)
app.include_router(conservation_router)
app.include_router(chatbot_router)
app.include_router(prompt_router)
# --- MOCK API FOR LOAD TESTING ---
# The mock router is optional. Only the import is guarded, so an error raised
# while registering the router is not misreported as "router not found".
try:
    from api.mock_api_route import router as mock_router
except ImportError:
    print("⚠️ Mock Router not found, skipping...")
else:
    app.include_router(mock_router)
    print("✅ Mock API Router mounted at /mock")
# ==========================================
# 🟢 STATIC HTML MOUNT SECTION 🟢
# ==========================================
# Best-effort: any failure here is reported on stdout and the app keeps loading.
try:
    static_dir = os.path.join(os.path.dirname(__file__), "static")
    # exist_ok=True creates the directory when missing without the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(static_dir, exist_ok=True)
    # Mount the static directory so index.html can be served
    app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
    print(f"✅ Static files mounted at /static (Dir: {static_dir})")
except Exception as e:
    print(f"⚠️ Failed to mount static files: {e}")
from fastapi.responses import RedirectResponse
@app.get("/")
async def root():
    """Redirect requests for the site root to the static chatbot page."""
    landing_page = "/static/index.html"
    return RedirectResponse(url=landing_page)
if __name__ == "__main__":
    # Startup banner: a 60-character rule around the title, then the local URLs.
    rule = "=" * 60
    for banner_line in (
        rule,
        "🚀 Contract AI Service Starting...",
        rule,
        f"📡 REST API: http://localhost:{PORT}",
        f"📡 Test Chatbot: http://localhost:{PORT}/static/index.html",
        f"📚 API Docs: http://localhost:{PORT}/docs",
        rule,
    ):
        print(banner_line)

    # Hot reload is switched off here; flip the flag to enable it.
    ENABLE_RELOAD = False
    print(f"⚠️ Hot reload: {ENABLE_RELOAD}")
    reload_dirs = ["common", "api", "agent"]
    if ENABLE_RELOAD:
        os.environ["PYTHONUNBUFFERED"] = "1"

    # Serve the app by import string, listening on all interfaces.
    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=PORT,
        reload=ENABLE_RELOAD,
        reload_dirs=reload_dirs,
        log_level="info",
    )
print("✅ Mock API Router mounted at /api/mock")
except ImportError:
print("⚠️ Mock Router not found, skipping...")
# ==========================================
# 🟢 STATIC HTML MOUNT SECTION 🟢
# ==========================================
# Best-effort: any failure here is reported on stdout and the app keeps loading.
try:
    static_dir = os.path.join(os.path.dirname(__file__), "static")
    # exist_ok=True creates the directory when missing without the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(static_dir, exist_ok=True)
    # Mount the static directory so index.html can be served
    app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
    print(f"✅ Static files mounted at /static (Dir: {static_dir})")
except Exception as e:
    print(f"⚠️ Failed to mount static files: {e}")
from fastapi.responses import RedirectResponse
@app.get("/")
async def root():
    """Redirect requests for the site root to the static chatbot page."""
    landing_page = "/static/index.html"
    return RedirectResponse(url=landing_page)
if __name__ == "__main__":
    # Startup banner: a 60-character rule around the title, then the local URLs.
    rule = "=" * 60
    for banner_line in (
        rule,
        "🚀 Contract AI Service Starting...",
        rule,
        f"📡 REST API: http://localhost:{PORT}",
        f"📡 Test Chatbot: http://localhost:{PORT}/static/index.html",
        f"📚 API Docs: http://localhost:{PORT}/docs",
        rule,
    ):
        print(banner_line)

    # Hot reload is switched off here; flip the flag to enable it.
    ENABLE_RELOAD = False
    print(f"⚠️ Hot reload: {ENABLE_RELOAD}")
    reload_dirs = ["common", "api", "agent"]
    if ENABLE_RELOAD:
        os.environ["PYTHONUNBUFFERED"] = "1"

    # Serve the app by import string, listening on all interfaces.
    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=PORT,
        reload=ENABLE_RELOAD,
        reload_dirs=reload_dirs,
        log_level="info",
    )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment