Commit 49f43a45 authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

Fix mock API routing and retriever alias

parent 566ee233
import asyncio import asyncio
import json import json
import logging import logging
import time import time
from fastapi import APIRouter, BackgroundTasks, HTTPException from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from agent.tools.data_retrieval_tool import SearchItem, data_retrieval_tool from agent.tools.data_retrieval_tool import SearchItem, data_retrieval_tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
# --- HELPERS --- # --- HELPERS ---
async def retry_with_backoff(coro_fn, max_retries=3, backoff_factor=2): async def retry_with_backoff(coro_fn, max_retries=3, backoff_factor=2):
"""Retry async function with exponential backoff""" """Retry async function with exponential backoff"""
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
return await coro_fn() return await coro_fn()
except Exception as e: except Exception as e:
if attempt == max_retries - 1: if attempt == max_retries - 1:
raise raise
wait_time = backoff_factor**attempt wait_time = backoff_factor**attempt
logger.warning(f"⚠️ Attempt {attempt + 1} failed: {e!s}, retrying in {wait_time}s...") logger.warning(f"⚠️ Attempt {attempt + 1} failed: {e!s}, retrying in {wait_time}s...")
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
# --- MODELS --- # --- MODELS ---
class MockQueryRequest(BaseModel): class MockQueryRequest(BaseModel):
user_query: str user_query: str
user_id: str | None = "test_user" user_id: str | None = "test_user"
session_id: str | None = None session_id: str | None = None
class MockDBRequest(BaseModel): class MockDBRequest(BaseModel):
query: str | None = None query: str | None = None
magento_ref_code: str | None = None magento_ref_code: str | None = None
price_min: float | None = None price_min: float | None = None
price_max: float | None = None price_max: float | None = None
top_k: int = 10 top_k: int = 10
class MockRetrieverRequest(BaseModel): class MockRetrieverRequest(BaseModel):
user_query: str user_query: str
price_min: float | None = None price_min: float | None = None
price_max: float | None = None price_max: float | None = None
magento_ref_code: str | None = None magento_ref_code: str | None = None
user_id: str | None = "test_user" user_id: str | None = "test_user"
session_id: str | None = None session_id: str | None = None
# --- MOCK LLM RESPONSES (không gọi OpenAI) --- # --- MOCK LLM RESPONSES (không gọi OpenAI) ---
MOCK_AI_RESPONSES = [ MOCK_AI_RESPONSES = [
"Dựa trên tìm kiếm của bạn, tôi tìm thấy các sản phẩm phù hợp với nhu cầu của bạn. Những mặt hàng này có chất lượng tốt và giá cả phải chăng.", "Dựa trên tìm kiếm của bạn, tôi tìm thấy các sản phẩm phù hợp với nhu cầu của bạn. Những mặt hàng này có chất lượng tốt và giá cả phải chăng.",
"Tôi gợi ý cho bạn những sản phẩm sau. Chúng đều là những lựa chọn phổ biến và nhận được đánh giá cao từ khách hàng.", "Tôi gợi ý cho bạn những sản phẩm sau. Chúng đều là những lựa chọn phổ biến và nhận được đánh giá cao từ khách hàng.",
"Dựa trên tiêu chí tìm kiếm của bạn, đây là những sản phẩm tốt nhất mà tôi có thể giới thiệu.", "Dựa trên tiêu chí tìm kiếm của bạn, đây là những sản phẩm tốt nhất mà tôi có thể giới thiệu.",
"Những sản phẩm này hoàn toàn phù hợp với yêu cầu của bạn. Hãy xem chi tiết để chọn sản phẩm yêu thích nhất.", "Những sản phẩm này hoàn toàn phù hợp với yêu cầu của bạn. Hãy xem chi tiết để chọn sản phẩm yêu thích nhất.",
"Tôi đã tìm được các mặt hàng tuyệt vời cho bạn. Hãy kiểm tra chúng để tìm ra lựa chọn tốt nhất.", "Tôi đã tìm được các mặt hàng tuyệt vời cho bạn. Hãy kiểm tra chúng để tìm ra lựa chọn tốt nhất.",
] ]
# --- ENDPOINTS --- # --- ENDPOINTS ---
from agent.mock_controller import mock_chat_controller from agent.mock_controller import mock_chat_controller
@router.post("/mock/agent/chat", summary="Mock Agent Chat (Real Tools + Fake LLM)") @router.post("/api/mock/agent/chat", summary="Mock Agent Chat (Real Tools + Fake LLM)")
async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks): async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks):
""" """
Mock Agent Chat using mock_chat_controller: Mock Agent Chat using mock_chat_controller:
- ✅ Real embedding + vector search (data_retrieval_tool THẬT) - ✅ Real embedding + vector search (data_retrieval_tool THẬT)
- ✅ Real products from StarRocks - ✅ Real products from StarRocks
- ❌ Fake LLM response (no OpenAI cost) - ❌ Fake LLM response (no OpenAI cost)
- Perfect for stress testing + end-to-end testing - Perfect for stress testing + end-to-end testing
""" """
try: try:
logger.info(f"🚀 [Mock Agent Chat] Starting with query: {req.user_query}") logger.info(f"🚀 [Mock Agent Chat] Starting with query: {req.user_query}")
result = await mock_chat_controller( result = await mock_chat_controller(
query=req.user_query, query=req.user_query,
user_id=req.user_id or "test_user", user_id=req.user_id or "test_user",
background_tasks=background_tasks, background_tasks=background_tasks,
) )
return { return {
"status": "success", "status": "success",
"user_query": req.user_query, "user_query": req.user_query,
"user_id": req.user_id, "user_id": req.user_id,
"session_id": req.session_id, "session_id": req.session_id,
**result, # Include status, ai_response, product_ids, etc. **result, # Include status, ai_response, product_ids, etc.
} }
except Exception as e: except Exception as e:
logger.error(f"❌ Error in mock agent chat: {e!s}", exc_info=True) logger.error(f"❌ Error in mock agent chat: {e!s}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Mock Agent Chat Error: {e!s}") raise HTTPException(status_code=500, detail=f"Mock Agent Chat Error: {e!s}")
@router.post("/api/mock/db/search", summary="Real Data Retrieval Tool (Agent Tool)")
async def mock_db_search(req: MockDBRequest):
@router.post("/mock/db/search", summary="Real Data Retrieval Tool (Agent Tool)") """
async def mock_db_search(req: MockDBRequest): Dùng `data_retrieval_tool` THẬT từ Agent:
""" - Nếu có magento_ref_code → CODE SEARCH (không cần embedding)
Dùng `data_retrieval_tool` THẬT từ Agent: - Nếu có query → HYDE SEMANTIC SEARCH (embedding + vector search)
- Nếu có magento_ref_code → CODE SEARCH (không cần embedding) - Lọc theo giá nếu có price_min/price_max
- Nếu có query → HYDE SEMANTIC SEARCH (embedding + vector search) - Trả về sản phẩm thực từ StarRocks
- Lọc theo giá nếu có price_min/price_max
- Trả về sản phẩm thực từ StarRocks Format input giống SearchItem của agent tool.
"""
Format input giống SearchItem của agent tool. try:
""" logger.info("📍 Data Retrieval Tool called")
try: start_time = time.time()
logger.info("📍 Data Retrieval Tool called")
start_time = time.time() # Xây dựng SearchItem từ request
search_item = SearchItem(
# Xây dựng SearchItem từ request query=req.query or "sản phẩm",
search_item = SearchItem( magento_ref_code=req.magento_ref_code,
query=req.query or "sản phẩm", price_min=req.price_min,
magento_ref_code=req.magento_ref_code, price_max=req.price_max,
price_min=req.price_min, action="search",
price_max=req.price_max, )
action="search",
) logger.info(f"🔧 Search params: {search_item.dict(exclude_none=True)}")
logger.info(f"🔧 Search params: {search_item.dict(exclude_none=True)}") # Gọi data_retrieval_tool THẬT với retry
result_json = await retry_with_backoff(
# Gọi data_retrieval_tool THẬT với retry lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}), max_retries=3
result_json = await retry_with_backoff( )
lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}), max_retries=3 result = json.loads(result_json)
)
result = json.loads(result_json) elapsed_time = time.time() - start_time
logger.info(f"✅ Data Retrieval completed in {elapsed_time:.3f}s")
elapsed_time = time.time() - start_time
logger.info(f"✅ Data Retrieval completed in {elapsed_time:.3f}s") return {
"status": result.get("status", "success"),
return { "search_params": search_item.dict(exclude_none=True),
"status": result.get("status", "success"), "total_results": len(result.get("results", [{}])[0].get("products", [])),
"search_params": search_item.dict(exclude_none=True), "products": result.get("results", [{}])[0].get("products", []),
"total_results": len(result.get("results", [{}])[0].get("products", [])), "processing_time_ms": round(elapsed_time * 1000, 2),
"products": result.get("results", [{}])[0].get("products", []), "raw_result": result,
"processing_time_ms": round(elapsed_time * 1000, 2), }
"raw_result": result,
} except Exception as e:
logger.error(f"❌ Error in DB search: {e!s}", exc_info=True)
except Exception as e: raise HTTPException(status_code=500, detail=f"DB Search Error: {e!s}")
logger.error(f"❌ Error in DB search: {e!s}", exc_info=True)
raise HTTPException(status_code=500, detail=f"DB Search Error: {e!s}")
@router.post("/api/mock/retrieverdb", summary="Real Embedding + Real DB Vector Search")
@router.post("/api/mock/retriverdb", summary="Real Embedding + Real DB Vector Search (Legacy)")
@router.post("/mock/retriverdb", summary="Real Embedding + Real DB Vector Search")
async def mock_retriever_db(req: MockRetrieverRequest): async def mock_retriever_db(req: MockRetrieverRequest):
""" """
API thực tế để test Retriever + DB Search (dùng agent tool): API thực tế để test Retriever + DB Search (dùng agent tool):
- Lấy query từ user - Lấy query từ user
- Embedding THẬT (gọi OpenAI embedding trong tool) - Embedding THẬT (gọi OpenAI embedding trong tool)
- Vector search THẬT trong StarRocks - Vector search THẬT trong StarRocks
- Trả về kết quả sản phẩm thực (bỏ qua LLM) - Trả về kết quả sản phẩm thực (bỏ qua LLM)
Dùng để test performance của embedding + vector search riêng biệt. Dùng để test performance của embedding + vector search riêng biệt.
""" """
try: try:
logger.info(f"📍 Retriever DB started: {req.user_query}") logger.info(f"📍 Retriever DB started: {req.user_query}")
start_time = time.time() start_time = time.time()
# Xây dựng SearchItem từ request # Xây dựng SearchItem từ request
search_item = SearchItem( search_item = SearchItem(
query=req.user_query, query=req.user_query,
magento_ref_code=req.magento_ref_code, magento_ref_code=req.magento_ref_code,
price_min=req.price_min, price_min=req.price_min,
price_max=req.price_max, price_max=req.price_max,
action="search", action="search",
) )
logger.info(f"🔧 Retriever params: {search_item.dict(exclude_none=True)}") logger.info(f"🔧 Retriever params: {search_item.dict(exclude_none=True)}")
# Gọi data_retrieval_tool THẬT (embedding + vector search) với retry # Gọi data_retrieval_tool THẬT (embedding + vector search) với retry
result_json = await retry_with_backoff( result_json = await retry_with_backoff(
lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}), max_retries=3 lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}), max_retries=3
) )
result = json.loads(result_json) result = json.loads(result_json)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"✅ Retriever completed in {elapsed_time:.3f}s") logger.info(f"✅ Retriever completed in {elapsed_time:.3f}s")
# Parse kết quả # Parse kết quả
search_results = result.get("results", [{}])[0] search_results = result.get("results", [{}])[0]
products = search_results.get("products", []) products = search_results.get("products", [])
return { return {
"status": result.get("status", "success"), "status": result.get("status", "success"),
"user_query": req.user_query, "user_query": req.user_query,
"user_id": req.user_id, "user_id": req.user_id,
"session_id": req.session_id, "session_id": req.session_id,
"search_params": search_item.dict(exclude_none=True), "search_params": search_item.dict(exclude_none=True),
"total_results": len(products), "total_results": len(products),
"products": products, "products": products,
"processing_time_ms": round(elapsed_time * 1000, 2), "processing_time_ms": round(elapsed_time * 1000, 2),
} }
except Exception as e: except Exception as e:
logger.error(f"❌ Error in retriever DB: {e!s}", exc_info=True) logger.error(f"❌ Error in retriever DB: {e!s}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Retriever DB Error: {e!s}") raise HTTPException(status_code=500, detail=f"Retriever DB Error: {e!s}")
""" """
Middleware Module - Gom Auth + Rate Limit Middleware Module - Gom Auth + Rate Limit
Singleton Pattern cho cả 2 services Singleton Pattern cho cả 2 services
""" """
from __future__ import annotations from __future__ import annotations
import json import json
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from fastapi import HTTPException, Request, status from fastapi import HTTPException, Request, status
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi import FastAPI from fastapi import FastAPI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ============================================================================= # =============================================================================
# CONFIGURATION # CONFIGURATION
# ============================================================================= # =============================================================================
# Public endpoints - không cần auth # Public endpoints - không cần auth
PUBLIC_PATHS = { PUBLIC_PATHS = {
"/", "/",
"/health", "/health",
"/docs", "/docs",
"/openapi.json", "/openapi.json",
"/redoc", "/redoc",
} }
# Public path prefixes # Public path prefixes
PUBLIC_PATH_PREFIXES = [ PUBLIC_PATH_PREFIXES = [
"/static", "/static",
"/mock", "/mock",
] "/api/mock",
]
# =============================================================================
# AUTH + RATE LIMIT MIDDLEWARE CLASS # =============================================================================
# ============================================================================= # AUTH + RATE LIMIT MIDDLEWARE CLASS
# =============================================================================
# Paths that need rate limit check
RATE_LIMITED_PATHS = [ # Paths that need rate limit check
"/api/agent/chat", RATE_LIMITED_PATHS = [
] "/api/agent/chat",
]
class CanifaAuthMiddleware(BaseHTTPMiddleware):
""" class CanifaAuthMiddleware(BaseHTTPMiddleware):
Canifa Authentication + Rate Limit Middleware """
Canifa Authentication + Rate Limit Middleware
Flow:
1. Frontend gửi request với Authorization: Bearer <canifa_token> Flow:
2. Middleware verify token với Canifa API → extract customer_id 1. Frontend gửi request với Authorization: Bearer <canifa_token>
3. Check message rate limit (Guest: 10, User: 100) 2. Middleware verify token với Canifa API → extract customer_id
4. Attach user info vào request.state 3. Check message rate limit (Guest: 10, User: 100)
5. Routes lấy trực tiếp từ request.state 4. Attach user info vào request.state
""" 5. Routes lấy trực tiếp từ request.state
"""
async def dispatch(self, request: Request, call_next: Callable):
path = request.url.path async def dispatch(self, request: Request, call_next: Callable):
method = request.method path = request.url.path
method = request.method
# ✅ Allow OPTIONS requests (CORS preflight)
if method == "OPTIONS": # ✅ Allow OPTIONS requests (CORS preflight)
return await call_next(request) if method == "OPTIONS":
return await call_next(request)
# Skip public endpoints
if path in PUBLIC_PATHS: # Skip public endpoints
return await call_next(request) if path in PUBLIC_PATHS:
return await call_next(request)
# Skip public path prefixes
if any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES): # Skip public path prefixes
return await call_next(request) if any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES):
return await call_next(request)
# =====================================================================
# STEP 1: AUTHENTICATION (Canifa API) # =====================================================================
# ===================================================================== # STEP 1: AUTHENTICATION (Canifa API)
try: # =====================================================================
auth_header = request.headers.get("Authorization") try:
auth_header = request.headers.get("Authorization")
# --- Device ID from Body ---
device_id = "" # --- Device ID from Body ---
if method in ["POST", "PUT", "PATCH"]: device_id = ""
try: if method in ["POST", "PUT", "PATCH"]:
body_bytes = await request.body() try:
body_bytes = await request.body()
async def receive_wrapper():
return {"type": "http.request", "body": body_bytes} async def receive_wrapper():
request._receive = receive_wrapper return {"type": "http.request", "body": body_bytes}
request._receive = receive_wrapper
if body_bytes:
try: if body_bytes:
body_json = json.loads(body_bytes) try:
device_id = body_json.get("device_id", "") body_json = json.loads(body_bytes)
except json.JSONDecodeError: device_id = body_json.get("device_id", "")
pass except json.JSONDecodeError:
except Exception as e: pass
logger.warning(f"Error reading device_id from body: {e}") except Exception as e:
logger.warning(f"Error reading device_id from body: {e}")
# Fallback: Nếu không có trong body, tìm trong header -> IP
if not device_id: # Fallback: Nếu không có trong body, tìm trong header -> IP
device_id = request.headers.get("device_id", "") if not device_id:
device_id = request.headers.get("device_id", "")
if not device_id:
device_id = f"unknown_{request.client.host}" if request.client else "unknown" if not device_id:
device_id = f"unknown_{request.client.host}" if request.client else "unknown"
# ========== DEV MODE: Bypass auth ==========
dev_user_id = request.headers.get("X-Dev-User-Id") # ========== DEV MODE: Bypass auth ==========
if dev_user_id: dev_user_id = request.headers.get("X-Dev-User-Id")
logger.warning(f"⚠️ DEV MODE: Using X-Dev-User-Id={dev_user_id}") if dev_user_id:
request.state.user = {"customer_id": dev_user_id} logger.warning(f"⚠️ DEV MODE: Using X-Dev-User-Id={dev_user_id}")
request.state.user_id = dev_user_id request.state.user = {"customer_id": dev_user_id}
request.state.is_authenticated = True request.state.user_id = dev_user_id
request.state.device_id = device_id or dev_user_id request.state.is_authenticated = True
return await call_next(request) request.state.device_id = device_id or dev_user_id
return await call_next(request)
# --- TRƯỜNG HỢP 1: KHÔNG CÓ TOKEN -> GUEST ---
if not auth_header or not auth_header.startswith("Bearer "): # --- TRƯỜNG HỢP 1: KHÔNG CÓ TOKEN -> GUEST ---
request.state.user = None if not auth_header or not auth_header.startswith("Bearer "):
request.state.user_id = None request.state.user = None
request.state.is_authenticated = False request.state.user_id = None
request.state.device_id = device_id request.state.is_authenticated = False
else: request.state.device_id = device_id
# --- TRƯỜNG HỢP 2: CÓ TOKEN -> GỌI CANIFA VERIFY --- else:
token = auth_header.replace("Bearer ", "") # --- TRƯỜNG HỢP 2: CÓ TOKEN -> GỌI CANIFA VERIFY ---
token = auth_header.replace("Bearer ", "")
from common.canifa_api import verify_canifa_token, extract_user_id_from_canifa_response
from common.canifa_api import verify_canifa_token, extract_user_id_from_canifa_response
try:
user_data = await verify_canifa_token(token) try:
user_id = await extract_user_id_from_canifa_response(user_data) user_data = await verify_canifa_token(token)
user_id = await extract_user_id_from_canifa_response(user_data)
if user_id:
request.state.user = user_data if user_id:
request.state.user_id = user_id request.state.user = user_data
request.state.token = token request.state.user_id = user_id
request.state.is_authenticated = True request.state.token = token
request.state.device_id = device_id request.state.is_authenticated = True
logger.debug(f"✅ Canifa Auth Success: User {user_id}") request.state.device_id = device_id
else: logger.debug(f"✅ Canifa Auth Success: User {user_id}")
logger.warning(f"⚠️ Invalid Canifa Token -> Guest Mode") else:
request.state.user = None logger.warning(f"⚠️ Invalid Canifa Token -> Guest Mode")
request.state.user_id = None request.state.user = None
request.state.is_authenticated = False request.state.user_id = None
request.state.device_id = device_id request.state.is_authenticated = False
request.state.device_id = device_id
except Exception as e:
logger.error(f"❌ Canifa Auth Error: {e} -> Guest Mode") except Exception as e:
request.state.user = None logger.error(f"❌ Canifa Auth Error: {e} -> Guest Mode")
request.state.user_id = None request.state.user = None
request.state.is_authenticated = False request.state.user_id = None
request.state.device_id = device_id request.state.is_authenticated = False
request.state.device_id = device_id
except Exception as e:
logger.error(f"❌ Middleware Auth Error: {e}") except Exception as e:
request.state.user = None logger.error(f"❌ Middleware Auth Error: {e}")
request.state.user_id = None request.state.user = None
request.state.is_authenticated = False request.state.user_id = None
request.state.device_id = "" request.state.is_authenticated = False
request.state.device_id = ""
# =====================================================================
# STEP 2: RATE LIMIT CHECK (Chỉ cho các path cần limit) # =====================================================================
# ===================================================================== # STEP 2: RATE LIMIT CHECK (Chỉ cho các path cần limit)
if path in RATE_LIMITED_PATHS: # =====================================================================
try: if path in RATE_LIMITED_PATHS:
from common.message_limit import message_limit_service try:
from fastapi.responses import JSONResponse from common.message_limit import message_limit_service
from fastapi.responses import JSONResponse
# Lấy identity_key làm rate limit key
# Guest: device_id → limit 10 # Lấy identity_key làm rate limit key
# User: user_id → limit 100 # Guest: device_id → limit 10
is_authenticated = request.state.is_authenticated # User: user_id → limit 100
if is_authenticated and request.state.user_id: is_authenticated = request.state.is_authenticated
rate_limit_key = request.state.user_id if is_authenticated and request.state.user_id:
else: rate_limit_key = request.state.user_id
rate_limit_key = request.state.device_id else:
rate_limit_key = request.state.device_id
if rate_limit_key:
can_send, limit_info = await message_limit_service.check_limit( if rate_limit_key:
identity_key=rate_limit_key, can_send, limit_info = await message_limit_service.check_limit(
is_authenticated=is_authenticated, identity_key=rate_limit_key,
) is_authenticated=is_authenticated,
)
# Lưu limit_info vào request.state để route có thể dùng
request.state.limit_info = limit_info # Lưu limit_info vào request.state để route có thể dùng
request.state.limit_info = limit_info
if not can_send:
logger.warning( if not can_send:
f"⚠️ Rate Limit Exceeded: {rate_limit_key} | " logger.warning(
f"used={limit_info['used']}/{limit_info['limit']}" f"⚠️ Rate Limit Exceeded: {rate_limit_key} | "
) f"used={limit_info['used']}/{limit_info['limit']}"
return JSONResponse( )
status_code=429, return JSONResponse(
content={ status_code=429,
"status": "error", content={
"error_code": limit_info.get("error_code") or "MESSAGE_LIMIT_EXCEEDED", "status": "error",
"message": limit_info["message"], "error_code": limit_info.get("error_code") or "MESSAGE_LIMIT_EXCEEDED",
"require_login": limit_info["require_login"], "message": limit_info["message"],
"limit_info": { "require_login": limit_info["require_login"],
"limit": limit_info["limit"], "limit_info": {
"used": limit_info["used"], "limit": limit_info["limit"],
"remaining": limit_info["remaining"], "used": limit_info["used"],
"reset_seconds": limit_info["reset_seconds"], "remaining": limit_info["remaining"],
}, "reset_seconds": limit_info["reset_seconds"],
}, },
) },
else: )
logger.warning(f"⚠️ No identity_key for rate limiting") else:
logger.warning(f"⚠️ No identity_key for rate limiting")
except Exception as e:
logger.error(f"❌ Rate Limit Check Error: {e}") except Exception as e:
# Cho phép request tiếp tục nếu lỗi rate limit logger.error(f"❌ Rate Limit Check Error: {e}")
# Cho phép request tiếp tục nếu lỗi rate limit
return await call_next(request)
return await call_next(request)
# =============================================================================
# MIDDLEWARE MANAGER - Singleton to manage all middlewares # =============================================================================
# ============================================================================= # MIDDLEWARE MANAGER - Singleton to manage all middlewares
# =============================================================================
class MiddlewareManager:
""" class MiddlewareManager:
Middleware Manager - Singleton Pattern """
Quản lý và setup tất cả middlewares cho FastAPI app Middleware Manager - Singleton Pattern
Quản lý và setup tất cả middlewares cho FastAPI app
Usage:
from common.middleware import middleware_manager Usage:
from common.middleware import middleware_manager
app = FastAPI()
middleware_manager.setup(app, enable_auth=True, enable_rate_limit=True) app = FastAPI()
""" middleware_manager.setup(app, enable_auth=True, enable_rate_limit=True)
"""
_instance: MiddlewareManager | None = None
_initialized: bool = False _instance: MiddlewareManager | None = None
_initialized: bool = False
def __new__(cls) -> MiddlewareManager:
if cls._instance is None: def __new__(cls) -> MiddlewareManager:
cls._instance = super().__new__(cls) if cls._instance is None:
return cls._instance cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if MiddlewareManager._initialized: def __init__(self) -> None:
return if MiddlewareManager._initialized:
return
self._auth_enabled = False
self._rate_limit_enabled = False self._auth_enabled = False
self._rate_limit_enabled = False
MiddlewareManager._initialized = True
logger.info("✅ MiddlewareManager initialized") MiddlewareManager._initialized = True
logger.info("✅ MiddlewareManager initialized")
def setup(
self, def setup(
app: FastAPI, self,
*, app: FastAPI,
enable_auth: bool = True, *,
enable_rate_limit: bool = True, enable_auth: bool = True,
enable_cors: bool = True, enable_rate_limit: bool = True,
cors_origins: list[str] | None = None, enable_cors: bool = True,
) -> None: cors_origins: list[str] | None = None,
""" ) -> None:
Setup tất cả middlewares cho FastAPI app. """
Setup tất cả middlewares cho FastAPI app.
Args:
app: FastAPI application Args:
enable_auth: Bật Canifa authentication middleware app: FastAPI application
enable_rate_limit: Bật rate limiting enable_auth: Bật Canifa authentication middleware
enable_cors: Bật CORS middleware enable_rate_limit: Bật rate limiting
cors_origins: List origins cho CORS (default: ["*"]) enable_cors: Bật CORS middleware
cors_origins: List origins cho CORS (default: ["*"])
Note:
Thứ tự middleware quan trọng! Middleware thêm sau sẽ chạy TRƯỚC. Note:
Order: CORS → Auth → RateLimit → SlowAPI Thứ tự middleware quan trọng! Middleware thêm sau sẽ chạy TRƯỚC.
""" Order: CORS → Auth → RateLimit → SlowAPI
# 1. CORS Middleware (thêm cuối cùng để chạy đầu tiên) """
if enable_cors: # 1. CORS Middleware (thêm cuối cùng để chạy đầu tiên)
self._setup_cors(app, cors_origins or ["*"]) if enable_cors:
self._setup_cors(app, cors_origins or ["*"])
# 2. Auth Middleware
if enable_auth: # 2. Auth Middleware
self._setup_auth(app) if enable_auth:
self._setup_auth(app)
# 3. Rate Limit Middleware
if enable_rate_limit: # 3. Rate Limit Middleware
self._setup_rate_limit(app) if enable_rate_limit:
self._setup_rate_limit(app)
logger.info(
f"✅ Middlewares configured: " logger.info(
f"CORS={enable_cors}, Auth={enable_auth}, RateLimit={enable_rate_limit}" f"✅ Middlewares configured: "
) f"CORS={enable_cors}, Auth={enable_auth}, RateLimit={enable_rate_limit}"
)
def _setup_cors(self, app: FastAPI, origins: list[str]) -> None:
"""Setup CORS middleware.""" def _setup_cors(self, app: FastAPI, origins: list[str]) -> None:
from fastapi.middleware.cors import CORSMiddleware """Setup CORS middleware."""
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
CORSMiddleware, app.add_middleware(
allow_origins=origins, CORSMiddleware,
allow_credentials=True, allow_origins=origins,
allow_methods=["*"], allow_credentials=True,
allow_headers=["*"], allow_methods=["*"],
) allow_headers=["*"],
logger.info(f"✅ CORS middleware enabled (origins: {origins})") )
logger.info(f"✅ CORS middleware enabled (origins: {origins})")
def _setup_auth(self, app: FastAPI) -> None:
"""Setup Canifa auth middleware.""" def _setup_auth(self, app: FastAPI) -> None:
app.add_middleware(CanifaAuthMiddleware) """Setup Canifa auth middleware."""
self._auth_enabled = True app.add_middleware(CanifaAuthMiddleware)
logger.info("✅ Canifa Auth middleware enabled") self._auth_enabled = True
logger.info("✅ Canifa Auth middleware enabled")
def _setup_rate_limit(self, app: FastAPI) -> None:
"""Setup rate limiting.""" def _setup_rate_limit(self, app: FastAPI) -> None:
from common.rate_limit import rate_limit_service """Setup rate limiting."""
from common.rate_limit import rate_limit_service
rate_limit_service.setup(app)
self._rate_limit_enabled = True rate_limit_service.setup(app)
logger.info("✅ Rate Limit middleware enabled") self._rate_limit_enabled = True
logger.info("✅ Rate Limit middleware enabled")
@property
def is_auth_enabled(self) -> bool: @property
return self._auth_enabled def is_auth_enabled(self) -> bool:
return self._auth_enabled
@property
def is_rate_limit_enabled(self) -> bool: @property
return self._rate_limit_enabled def is_rate_limit_enabled(self) -> bool:
return self._rate_limit_enabled
# =============================================================================
# SINGLETON INSTANCE # =============================================================================
# ============================================================================= # SINGLETON INSTANCE
# =============================================================================
middleware_manager = MiddlewareManager()
middleware_manager = MiddlewareManager()
""" """
Rate Limiting Service - Singleton Pattern Rate Limiting Service - Singleton Pattern
Sử dụng SlowAPI với Redis backend (production) hoặc Memory (dev) Sử dụng SlowAPI với Redis backend (production) hoặc Memory (dev)
""" """
from __future__ import annotations from __future__ import annotations
import logging import logging
import os import os
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from fastapi import Request from fastapi import Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from slowapi import Limiter from slowapi import Limiter
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi import FastAPI from fastapi import FastAPI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RateLimitService: class RateLimitService:
""" """
Rate Limiting Service - Singleton Pattern Rate Limiting Service - Singleton Pattern
Usage: Usage:
# Trong server.py # Trong server.py
from common.rate_limit import RateLimitService from common.rate_limit import RateLimitService
rate_limiter = RateLimitService() rate_limiter = RateLimitService()
rate_limiter.setup(app) rate_limiter.setup(app)
# Trong route # Trong route
from common.rate_limit import RateLimitService from common.rate_limit import RateLimitService
@router.post("/chat") @router.post("/chat")
@RateLimitService().limiter.limit("10/minute") @RateLimitService().limiter.limit("10/minute")
async def chat(request: Request): async def chat(request: Request):
... ...
""" """
_instance: RateLimitService | None = None _instance: RateLimitService | None = None
_initialized: bool = False _initialized: bool = False
# ========================================================================= # =========================================================================
# SINGLETON PATTERN # SINGLETON PATTERN
# ========================================================================= # =========================================================================
def __new__(cls) -> RateLimitService: def __new__(cls) -> RateLimitService:
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def __init__(self) -> None: def __init__(self) -> None:
# Chỉ init một lần # Chỉ init một lần
if RateLimitService._initialized: if RateLimitService._initialized:
return return
# Configuration # Configuration
self.storage_uri = os.getenv("RATE_STORAGE_URI", "memory://") self.storage_uri = os.getenv("RATE_STORAGE_URI", "memory://")
self.default_limits = ["100/hour", "30/minute"] self.default_limits = ["100/hour", "30/minute"]
self.block_duration_minutes = int(os.getenv("RATE_LIMIT_BLOCK_MINUTES", "5")) self.block_duration_minutes = int(os.getenv("RATE_LIMIT_BLOCK_MINUTES", "5"))
# Paths không áp dụng rate limit # Paths không áp dụng rate limit
self.exempt_paths = { self.exempt_paths = {
"/", "/",
"/health", "/health",
"/docs", "/docs",
"/openapi.json", "/openapi.json",
"/redoc", "/redoc",
} }
self.exempt_prefixes = ["/static", "/mock"] self.exempt_prefixes = ["/static", "/mock", "/api/mock"]
# In-memory blocklist (có thể chuyển sang Redis) # In-memory blocklist (có thể chuyển sang Redis)
self._blocklist: dict[str, datetime] = {} self._blocklist: dict[str, datetime] = {}
# Create limiter instance # Create limiter instance
self.limiter = Limiter( self.limiter = Limiter(
key_func=self._get_client_identifier, key_func=self._get_client_identifier,
storage_uri=self.storage_uri, storage_uri=self.storage_uri,
default_limits=self.default_limits, default_limits=self.default_limits,
) )
RateLimitService._initialized = True RateLimitService._initialized = True
logger.info(f"✅ RateLimitService initialized (storage: {self.storage_uri})") logger.info(f"✅ RateLimitService initialized (storage: {self.storage_uri})")
# ========================================================================= # =========================================================================
# CLIENT IDENTIFIER # CLIENT IDENTIFIER
# ========================================================================= # =========================================================================
@staticmethod @staticmethod
def _get_client_identifier(request: Request) -> str: def _get_client_identifier(request: Request) -> str:
""" """
Lấy client identifier cho rate limiting. Lấy client identifier cho rate limiting.
Ưu tiên: user_id (authenticated) > device_id > IP address Ưu tiên: user_id (authenticated) > device_id > IP address
""" """
# 1. Nếu đã authenticated → dùng user_id # 1. Nếu đã authenticated → dùng user_id
if hasattr(request.state, "user_id") and request.state.user_id: if hasattr(request.state, "user_id") and request.state.user_id:
return f"user:{request.state.user_id}" return f"user:{request.state.user_id}"
# 2. Nếu có device_id trong header → dùng device_id # 2. Nếu có device_id trong header → dùng device_id
device_id = request.headers.get("device_id") device_id = request.headers.get("device_id")
if device_id: if device_id:
return f"device:{device_id}" return f"device:{device_id}"
# 3. Fallback → IP address # 3. Fallback → IP address
try: try:
return f"ip:{get_remote_address(request)}" return f"ip:{get_remote_address(request)}"
except Exception: except Exception:
if request.client: if request.client:
return f"ip:{request.client.host}" return f"ip:{request.client.host}"
return "unknown" return "unknown"
# ========================================================================= # =========================================================================
# BLOCKLIST MANAGEMENT # BLOCKLIST MANAGEMENT
# ========================================================================= # =========================================================================
def is_blocked(self, key: str) -> tuple[bool, int]: def is_blocked(self, key: str) -> tuple[bool, int]:
""" """
Check if client is blocked. Check if client is blocked.
Returns: (is_blocked, retry_after_seconds) Returns: (is_blocked, retry_after_seconds)
""" """
now = datetime.utcnow() now = datetime.utcnow()
blocked_until = self._blocklist.get(key) blocked_until = self._blocklist.get(key)
if blocked_until: if blocked_until:
if blocked_until > now: if blocked_until > now:
retry_after = int((blocked_until - now).total_seconds()) retry_after = int((blocked_until - now).total_seconds())
return True, retry_after return True, retry_after
else: else:
# Block expired # Block expired
self._blocklist.pop(key, None) self._blocklist.pop(key, None)
return False, 0 return False, 0
def block_client(self, key: str) -> int: def block_client(self, key: str) -> int:
""" """
Block client for configured duration. Block client for configured duration.
Returns: retry_after_seconds Returns: retry_after_seconds
""" """
self._blocklist[key] = datetime.utcnow() + timedelta(minutes=self.block_duration_minutes) self._blocklist[key] = datetime.utcnow() + timedelta(minutes=self.block_duration_minutes)
return self.block_duration_minutes * 60 return self.block_duration_minutes * 60
def unblock_client(self, key: str) -> None: def unblock_client(self, key: str) -> None:
"""Unblock client manually.""" """Unblock client manually."""
self._blocklist.pop(key, None) self._blocklist.pop(key, None)
# ========================================================================= # =========================================================================
# PATH CHECKING # PATH CHECKING
# ========================================================================= # =========================================================================
def is_exempt(self, path: str) -> bool: def is_exempt(self, path: str) -> bool:
"""Check if path is exempt from rate limiting.""" """Check if path is exempt from rate limiting."""
if path in self.exempt_paths: if path in self.exempt_paths:
return True return True
return any(path.startswith(prefix) for prefix in self.exempt_prefixes) return any(path.startswith(prefix) for prefix in self.exempt_prefixes)
# ========================================================================= # =========================================================================
# SETUP FOR FASTAPI APP # SETUP FOR FASTAPI APP
# ========================================================================= # =========================================================================
def setup(self, app: FastAPI) -> None: def setup(self, app: FastAPI) -> None:
""" """
Setup rate limiting cho FastAPI app. Setup rate limiting cho FastAPI app.
Gọi trong server.py sau khi tạo app. Gọi trong server.py sau khi tạo app.
""" """
# Attach limiter to app state # Attach limiter to app state
app.state.limiter = self.limiter app.state.limiter = self.limiter
app.state.rate_limit_service = self app.state.rate_limit_service = self
# Register middleware # Register middleware
self._register_block_middleware(app) self._register_block_middleware(app)
self._register_exception_handler(app) self._register_exception_handler(app)
# Add SlowAPI middleware (PHẢI thêm SAU custom middlewares) # Add SlowAPI middleware (PHẢI thêm SAU custom middlewares)
app.add_middleware(SlowAPIMiddleware) app.add_middleware(SlowAPIMiddleware)
logger.info("✅ Rate limiting middleware registered") logger.info("✅ Rate limiting middleware registered")
def _register_block_middleware(self, app: FastAPI) -> None: def _register_block_middleware(self, app: FastAPI) -> None:
"""Register middleware to check blocklist.""" """Register middleware to check blocklist."""
@app.middleware("http") @app.middleware("http")
async def rate_limit_block_middleware(request: Request, call_next): async def rate_limit_block_middleware(request: Request, call_next):
path = request.url.path path = request.url.path
# Skip exempt paths # Skip exempt paths
if self.is_exempt(path): if self.is_exempt(path):
return await call_next(request) return await call_next(request)
# Bypass header cho testing # Bypass header cho testing
if request.headers.get("X-Bypass-RateLimit") == "1": if request.headers.get("X-Bypass-RateLimit") == "1":
return await call_next(request) return await call_next(request)
# Check blocklist # Check blocklist
key = self._get_client_identifier(request) key = self._get_client_identifier(request)
is_blocked, retry_after = self.is_blocked(key) is_blocked, retry_after = self.is_blocked(key)
if is_blocked: if is_blocked:
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={ content={
"detail": "Quá số lượt cho phép. Vui lòng thử lại sau.", "detail": "Quá số lượt cho phép. Vui lòng thử lại sau.",
"retry_after_seconds": retry_after, "retry_after_seconds": retry_after,
}, },
headers={"Retry-After": str(retry_after)}, headers={"Retry-After": str(retry_after)},
) )
return await call_next(request) return await call_next(request)
def _register_exception_handler(self, app: FastAPI) -> None: def _register_exception_handler(self, app: FastAPI) -> None:
"""Register exception handler for rate limit exceeded.""" """Register exception handler for rate limit exceeded."""
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded): async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
key = self._get_client_identifier(request) key = self._get_client_identifier(request)
retry_after = self.block_client(key) retry_after = self.block_client(key)
logger.warning(f"⚠️ Rate limit exceeded for {key}, blocked for {self.block_duration_minutes} minutes") logger.warning(f"⚠️ Rate limit exceeded for {key}, blocked for {self.block_duration_minutes} minutes")
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={ content={
"detail": "Quá số lượt cho phép. Vui lòng thử lại sau.", "detail": "Quá số lượt cho phép. Vui lòng thử lại sau.",
"retry_after_seconds": retry_after, "retry_after_seconds": retry_after,
}, },
headers={"Retry-After": str(retry_after)}, headers={"Retry-After": str(retry_after)},
) )
# ============================================================================= # =============================================================================
# SINGLETON INSTANCE - Import trực tiếp để dùng # SINGLETON INSTANCE - Import trực tiếp để dùng
# ============================================================================= # =============================================================================
rate_limit_service = RateLimitService() rate_limit_service = RateLimitService()
import asyncio import asyncio
import os import os
import platform import platform
if platform.system() == "Windows": if platform.system() == "Windows":
print("🔧 Windows detected: Applying SelectorEventLoopPolicy globally...") print("🔧 Windows detected: Applying SelectorEventLoopPolicy globally...")
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import logging import logging
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from api.chatbot_route import router as chatbot_router from api.chatbot_route import router as chatbot_router
from api.conservation_route import router as conservation_router from api.conservation_route import router as conservation_router
from api.prompt_route import router as prompt_router from api.prompt_route import router as prompt_router
from common.cache import redis_cache from common.cache import redis_cache
from common.langfuse_client import get_langfuse_client from common.langfuse_client import get_langfuse_client
from common.middleware import middleware_manager from common.middleware import middleware_manager
from config import PORT from config import PORT
# Configure Logging # Configure Logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
handlers=[logging.StreamHandler()], handlers=[logging.StreamHandler()],
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
langfuse_client = get_langfuse_client() langfuse_client = get_langfuse_client()
if langfuse_client: if langfuse_client:
logger.info("✅ Langfuse client ready (lazy loading)") logger.info("✅ Langfuse client ready (lazy loading)")
else: else:
logger.warning("⚠️ Langfuse client not available (missing keys or disabled)") logger.warning("⚠️ Langfuse client not available (missing keys or disabled)")
app = FastAPI( app = FastAPI(
title="Contract AI Service", title="Contract AI Service",
description="API for Contract AI Service", description="API for Contract AI Service",
version="1.0.0", version="1.0.0",
) )
# ============================================================================= # =============================================================================
# STARTUP EVENT - Initialize Redis Cache # STARTUP EVENT - Initialize Redis Cache
# ============================================================================= # =============================================================================
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
"""Initialize Redis cache on startup.""" """Initialize Redis cache on startup."""
await redis_cache.initialize() await redis_cache.initialize()
logger.info("✅ Redis cache initialized for message limit") logger.info("✅ Redis cache initialized for message limit")
# ============================================================================= # =============================================================================
# MIDDLEWARE SETUP - Gom Auth + RateLimit + CORS vào một chỗ # MIDDLEWARE SETUP - Gom Auth + RateLimit + CORS vào một chỗ
# ============================================================================= # =============================================================================
middleware_manager.setup( middleware_manager.setup(
app, app,
enable_auth=True, # 👈 Bật lại Auth để test logic Guest/User enable_auth=True, # 👈 Bật lại Auth để test logic Guest/User
enable_rate_limit=True, # 👈 Bật lại SlowAPI theo yêu cầu enable_rate_limit=True, # 👈 Bật lại SlowAPI theo yêu cầu
enable_cors=True, # 👈 Bật CORS enable_cors=True, # 👈 Bật CORS
cors_origins=["*"], # 👈 Trong production nên limit origins cors_origins=["*"], # 👈 Trong production nên limit origins
) )
app.include_router(conservation_router) app.include_router(conservation_router)
app.include_router(chatbot_router) app.include_router(chatbot_router)
app.include_router(prompt_router) app.include_router(prompt_router)
# --- MOCK API FOR LOAD TESTING --- # --- MOCK API FOR LOAD TESTING ---
try: try:
from api.mock_api_route import router as mock_router from api.mock_api_route import router as mock_router
app.include_router(mock_router) app.include_router(mock_router)
print("✅ Mock API Router mounted at /mock") print("✅ Mock API Router mounted at /api/mock")
except ImportError: except ImportError:
print("⚠️ Mock Router not found, skipping...") print("⚠️ Mock Router not found, skipping...")
# ========================================== # ==========================================
# 🟢 ĐOẠN MOUNT STATIC HTML CỦA BRO ĐÂY 🟢 # 🟢 ĐOẠN MOUNT STATIC HTML CỦA BRO ĐÂY 🟢
# ========================================== # ==========================================
try: try:
static_dir = os.path.join(os.path.dirname(__file__), "static") static_dir = os.path.join(os.path.dirname(__file__), "static")
if not os.path.exists(static_dir): if not os.path.exists(static_dir):
os.makedirs(static_dir) os.makedirs(static_dir)
# Mount thư mục static để chạy file index.html # Mount thư mục static để chạy file index.html
app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static") app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
print(f"✅ Static files mounted at /static (Dir: {static_dir})") print(f"✅ Static files mounted at /static (Dir: {static_dir})")
except Exception as e: except Exception as e:
print(f"⚠️ Failed to mount static files: {e}") print(f"⚠️ Failed to mount static files: {e}")
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
@app.get("/") @app.get("/")
async def root(): async def root():
return RedirectResponse(url="/static/index.html") return RedirectResponse(url="/static/index.html")
if __name__ == "__main__": if __name__ == "__main__":
print("=" * 60) print("=" * 60)
print("🚀 Contract AI Service Starting...") print("🚀 Contract AI Service Starting...")
print("=" * 60) print("=" * 60)
print(f"📡 REST API: http://localhost:{PORT}") print(f"📡 REST API: http://localhost:{PORT}")
print(f"📡 Test Chatbot: http://localhost:{PORT}/static/index.html") print(f"📡 Test Chatbot: http://localhost:{PORT}/static/index.html")
print(f"📚 API Docs: http://localhost:{PORT}/docs") print(f"📚 API Docs: http://localhost:{PORT}/docs")
print("=" * 60) print("=" * 60)
ENABLE_RELOAD = False ENABLE_RELOAD = False
print(f"⚠️ Hot reload: {ENABLE_RELOAD}") print(f"⚠️ Hot reload: {ENABLE_RELOAD}")
reload_dirs = ["common", "api", "agent"] reload_dirs = ["common", "api", "agent"]
if ENABLE_RELOAD: if ENABLE_RELOAD:
os.environ["PYTHONUNBUFFERED"] = "1" os.environ["PYTHONUNBUFFERED"] = "1"
uvicorn.run( uvicorn.run(
"server:app", "server:app",
host="0.0.0.0", host="0.0.0.0",
port=PORT, port=PORT,
reload=ENABLE_RELOAD, reload=ENABLE_RELOAD,
reload_dirs=reload_dirs, reload_dirs=reload_dirs,
log_level="info", log_level="info",
) )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment