Commit c66892a7 authored by Vũ Hoàng Anh

refactor: simplify identity management, rate limiting, and chat history

parent d20b1e77
@@ -4,7 +4,6 @@ Langfuse will auto-trace via LangChain integration (no code changes needed).
"""
import logging
import time
import uuid
from fastapi import BackgroundTasks
@@ -14,14 +13,12 @@ from langchain_core.runnables import RunnableConfig
from common.cache import redis_cache
from common.conversation_manager import get_conversation_manager
from common.langfuse_client import get_callback_handler
from common.llm_factory import create_llm
from config import DEFAULT_MODEL, REDIS_CACHE_TURN_ON
from langfuse import propagate_attributes
from .graph import build_graph
from .helper import extract_product_ids, handle_post_chat_async, parse_ai_response
from .models import AgentState, get_config
from .tools.get_tools import get_all_tools
logger = logging.getLogger(__name__)
@@ -75,15 +72,13 @@ async def chat_controller(
config = get_config()
config.model_name = model_name
llm = create_llm(model_name=model_name, streaming=False, json_mode=True)
# tools = get_all_tools() # Singleton now handles tools
graph = build_graph(config) # Singleton usage
graph = build_graph(config)
# Init ConversationManager (Singleton)
memory = await get_conversation_manager()
# Load History
history_dicts = await memory.get_chat_history(effective_identity_key, limit=15)
# Load History (only text, no product_ids for AI context)
history_dicts = await memory.get_chat_history(effective_identity_key, limit=15, include_product_ids=False)
messages = [
HumanMessage(content=m["message"]) if m["is_human"] else AIMessage(content=m["message"])
for m in history_dicts
@@ -114,26 +109,15 @@
)
# Execute Graph
start_time = time.time()
session_id = f"{user_id}-{run_id[:8]}"
with propagate_attributes(user_id=user_id, session_id=session_id):
result = await graph.ainvoke(initial_state, config=exec_config)
duration = time.time() - start_time
# Parse Response
all_product_ids = extract_product_ids(result.get("messages", []))
logger.info("🔍 [DEBUG] all_product_ids count: %s", len(all_product_ids))
if all_product_ids:
logger.info("🔍 [DEBUG] First product from tool: %s", all_product_ids[0])
ai_raw_content = result.get("ai_response").content if result.get("ai_response") else ""
ai_text_response, final_product_ids = parse_ai_response(ai_raw_content, all_product_ids)
logger.info("🔍 [DEBUG] final_product_ids count: %s, type: %s", len(final_product_ids), type(final_product_ids[0]) if final_product_ids else "empty")
if final_product_ids:
logger.info("🔍 [DEBUG] First final product: %s", final_product_ids[0])
response_payload = {
"ai_response": ai_text_response,
@@ -159,6 +143,6 @@ async def chat_controller(
ai_response=response_payload,
)
logger.info("chat_controller finished in %.2fs", duration)
logger.info("chat_controller finished")
return {**response_payload, "cached": False}
@@ -16,7 +16,7 @@ from langchain_core.runnables import RunnableConfig
from common.conversation_manager import ConversationManager, get_conversation_manager
from common.langfuse_client import get_callback_handler
from common.llm_factory import create_llm
from agent.tools.data_retrieval_tool import SearchItem, data_retrieval_tool
from config import DEFAULT_MODEL
from .graph import build_graph
@@ -51,11 +51,8 @@ async def chat_controller(
config = get_config()
config.model_name = model_name
# Enable JSON mode to ensure structured output
llm = create_llm(model_name=model_name, streaming=False, json_mode=True)
tools = get_all_tools()
graph = build_graph(config, llm=llm, tools=tools)
tools = get_all_tools()
graph = build_graph(config, llm=None, tools=tools)
# Init ConversationManager (Singleton)
memory = await get_conversation_manager()
@@ -180,9 +177,8 @@ def _prepare_execution_context(query: str, user_id: str, history: list, images:
"tags": "chatbot,production",
}
# 🔥 CallbackHandler - will be wrapped in langfuse_trace_context to set user_id
# Per Langfuse docs: propagate_attributes() handles user_id propagation
langfuse_handler = get_callback_handler()
# CallbackHandler for Langfuse (if enabled)
langfuse_handler = get_callback_handler()
exec_config = RunnableConfig(
configurable={
@@ -214,12 +210,12 @@ async def _handle_post_chat_async(
# ========================================
async def mock_chat_controller(
query: str,
user_id: str,
background_tasks: BackgroundTasks,
images: list[str] | None = None,
) -> dict:
"""
Mock Agent Controller với FAKE LLM (không gọi OpenAI):
- Sử dụng toàn bộ graph flow từ chat_controller
......@@ -238,36 +234,24 @@ async def mock_chat_controller(
✅ Không cần JSON parsing (response là plain text)
✅ Nhanh hơn (~1-3ms giả lập LLM thay vì 1-3s real LLM)
"""
logger.info(f"🚀 [MOCK Chat Controller] Starting with query: {query} for user: {user_id}")
start_time = time.time()
config = get_config()
# Do NOT call OpenAI - use REAL tools but fake the LLM response
tools = get_all_tools()
graph = build_graph(config, llm=None, tools=tools)  # llm=None to skip the LLM node
# Init ConversationManager (Singleton)
memory = await get_conversation_manager()
# LOAD HISTORY & Prepare State
history_dicts = await memory.get_chat_history(user_id, limit=20)
history = []
for h in reversed(history_dicts):
msg_cls = HumanMessage if h["is_human"] else AIMessage
history.append(msg_cls(content=h["message"]))
initial_state, exec_config = _prepare_execution_context(
query=query, user_id=user_id, history=history, images=images
)
try:
# Run the graph with REAL tools
result = await graph.ainvoke(initial_state, config=exec_config)
logger.info(f"🚀 [MOCK Chat Controller] Starting with query: {query} for user: {user_id}")
start_time = time.time()
# Extract products from tool messages (REAL tools)
all_product_ids = _extract_product_ids(result.get("messages", []))
# Init ConversationManager (Singleton)
memory = await get_conversation_manager()
try:
# Call the tool directly (bypassing the LLM) to avoid the bottleneck
search_item = SearchItem(
query=query or "sản phẩm",
magento_ref_code=None,
price_min=None,
price_max=None,
action="search",
)
result_json = await data_retrieval_tool.ainvoke({"searches": [search_item]})
result = json.loads(result_json)
all_product_ids = result.get("results", [{}])[0].get("products", [])
# Generate FAKE LLM response (no OpenAI call)
logger.info("🤖 [FAKE LLM] Generating mock response...")
......
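The "FAKE LLM" described in the docstring above can be as small as a coroutine that sleeps a few milliseconds and returns a canned plain-text reply. A minimal sketch, assuming nothing beyond the stdlib (the function name and reply text are illustrative, not part of this commit):

```python
import asyncio
import random

# Hypothetical stand-in for the LLM node: no OpenAI call, just ~1-3 ms of
# simulated latency and a canned plain-text reply (so no JSON parsing is needed).
async def fake_llm_response(query: str) -> str:
    await asyncio.sleep(random.uniform(0.001, 0.003))
    return f"[MOCK] Suggested products for: {query}"
```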
@@ -10,21 +10,16 @@ import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
from fastapi.responses import JSONResponse
from opentelemetry import trace
from agent.controller import chat_controller
from agent.models import QueryRequest
from common.message_limit import message_limit_service
from common.user_identity import get_user_identity
from common.rate_limit import rate_limit_service
from config import DEFAULT_MODEL
logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)
router = APIRouter()
from common.rate_limit import rate_limit_service
@router.post("/api/agent/chat", summary="Fashion Q&A Chat (Non-streaming)")
@rate_limit_service.limiter.limit("50/minute")
async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks: BackgroundTasks):
@@ -33,55 +28,40 @@ async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks:
Note: Rate limiting has already been checked in middleware.
"""
# 1. Determine user identity
identity = get_user_identity(request)
user_id = identity.primary_id
# 1. Get user identity from middleware (request.state)
# Logic: logged in -> user ID | guest -> device ID
user_id = getattr(request.state, "user_id", None)
device_id = getattr(request.state, "device_id", "unknown")
is_authenticated = getattr(request.state, "is_authenticated", False)
# Unique identifier for this request (logging, history, rate limiting, Langfuse)
identity_id = user_id if is_authenticated else device_id
# Rate limit already checked in middleware; read limit_info from request.state
limit_info = getattr(request.state, 'limit_info', None)
logger.info(f"📥 [Incoming Query - NonStream] User: {user_id} | Query: {req.user_query}")
# Get the current span to add events to the Jaeger UI
span = trace.get_current_span()
span.set_attribute("user.id", user_id)
span.set_attribute("chat.user_query", req.user_query)
span.add_event(
"📥 User query received", attributes={"user_id": user_id, "query": req.user_query, "timestamp": "incoming"}
)
logger.info(f"📥 [Incoming Query - NonStream] User: {identity_id} | Query: {req.user_query}")
try:
# Call the controller to handle the logic (non-streaming)
result = await chat_controller(
query=req.user_query,
user_id=user_id,
user_id=str(identity_id), # Langfuse User ID
background_tasks=background_tasks,
model_name=DEFAULT_MODEL,
images=req.images,
identity_key=identity.history_key, # Guest: device_id, User: user_id
identity_key=str(identity_id),  # Key used to store history
)
# Log the response in detail
logger.info(f"📤 [Outgoing Response - NonStream] User: {user_id}")
logger.info(f"📤 [Outgoing Response - NonStream] User: {identity_id}")
logger.info(f"💬 AI Response: {result['ai_response']}")
logger.info(f"🛍️ Product IDs: {result.get('product_ids', [])}")
# Add to span (hiển thị trong Jaeger UI)
span.set_attribute("chat.ai_response", result["ai_response"][:200]) # Giới hạn 200 ký tự
span.set_attribute("chat.product_count", len(result.get("product_ids", [])))
span.add_event(
"💬 AI response generated",
attributes={
"ai_response_preview": result["ai_response"][:100],
"product_count": len(result.get("product_ids", [])),
"product_ids": str(result.get("product_ids", [])[:5]), # First 5 IDs
},
)
# Increment message count AFTER the chat succeeds
usage_info = await message_limit_service.increment(
identity_key=identity.rate_limit_key,
is_authenticated=identity.is_authenticated,
identity_key=identity_id,
is_authenticated=is_authenticated,
)
return {
......
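For context, `@rate_limit_service.limiter.limit("50/minute")` follows the standard slowapi decorator pattern; the project's `rate_limit_service` wraps its own `Limiter`. A minimal self-contained sketch of that wiring, assuming a per-IP key function (the route body is illustrative):

```python
from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

app = FastAPI()
limiter = Limiter(key_func=get_remote_address)  # key requests by client IP
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@app.post("/api/agent/chat")
@limiter.limit("50/minute")  # same limit string as the route above
async def chat(request: Request) -> dict:  # slowapi requires the Request parameter
    return {"ok": True}
```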
"""
Chat History API Routes
- GET /api/history/{identity_key} - Fetch chat history (with product_ids)
- GET /api/history/{identity_key} - Fetch chat history
- DELETE /api/history/{identity_key} - Delete chat history
Note: identity_key can be a device_id (guest) or a user_id (logged in)
@@ -12,7 +12,6 @@ from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from common.conversation_manager import get_conversation_manager
from common.user_identity import get_user_identity
router = APIRouter(tags=["Chat History"])
logger = logging.getLogger(__name__)
@@ -40,17 +39,17 @@ async def get_chat_history(request: Request, identity_key: str, limit: int | Non
(the identity_key in the URL is only a fallback)
"""
try:
# Automatically resolve identity from middleware
identity = get_user_identity(request)
# Resolve identity from middleware (request.state)
user_id = getattr(request.state, "user_id", None)
device_id = getattr(request.state, "device_id", identity_key)
is_authenticated = getattr(request.state, "is_authenticated", False)
# If logged in -> use user_id
if identity.is_authenticated:
resolved_key = identity.history_key
else:
# If not logged in (guest) -> use the identity_key from the URL
resolved_key = identity_key
logger.info(f"GET History: URL key={identity_key} -> Resolved key={resolved_key}")
# Detailed log for debugging
logger.info(f"GET History: auth={is_authenticated} | user_id={user_id} | device_id={device_id}")
# If logged in -> use user_id, otherwise use device_id
resolved_key = user_id if is_authenticated else device_id
logger.info(f"GET History: resolved_key={resolved_key}")
manager = await get_conversation_manager()
history = await manager.get_chat_history(resolved_key, limit=limit, before_id=before_id)
@@ -100,10 +99,13 @@ async def archive_chat_history(request: Request):
Limited to 5 times per day.
"""
try:
identity = get_user_identity(request)
# Resolve identity from middleware (request.state)
user_id = getattr(request.state, "user_id", None)
device_id = getattr(request.state, "device_id", "")
is_authenticated = getattr(request.state, "is_authenticated", False)
# Only for logged-in users
if not identity.is_authenticated:
if not is_authenticated:
return JSONResponse(
status_code=401,
content={
@@ -114,7 +116,7 @@
}
)
identity_key = identity.history_key
identity_key = user_id
# Check reset limit
can_reset, usage, remaining = await reset_limit_service.check_limit(identity_key)
......
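The same three `getattr(request.state, ...)` reads now recur in every handler above. If that repetition grows, a small helper can keep the middleware contract in one place — a sketch, not part of this commit:

```python
from fastapi import Request

def resolve_identity_key(request: Request, fallback: str = "unknown") -> str:
    """Logged in -> user_id, guest -> device_id, mirroring the handlers above."""
    user_id = getattr(request.state, "user_id", None)
    device_id = getattr(request.state, "device_id", fallback)
    is_authenticated = getattr(request.state, "is_authenticated", False)
    return str(user_id) if is_authenticated and user_id else str(device_id)
```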
@@ -7,6 +7,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel
from agent.tools.data_retrieval_tool import SearchItem, data_retrieval_tool
from agent.mock_controller import mock_chat_controller
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -31,6 +32,7 @@ class MockQueryRequest(BaseModel):
user_query: str
user_id: str | None = "test_user"
session_id: str | None = None
images: list[str] | None = None
class MockDBRequest(BaseModel):
@@ -62,10 +64,6 @@ MOCK_AI_RESPONSES = [
# --- ENDPOINTS ---
from agent.mock_controller import mock_chat_controller
@router.post("/api/mock/agent/chat", summary="Mock Agent Chat (Real Tools + Fake LLM)")
async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks):
"""
@@ -82,6 +80,7 @@ async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks):
query=req.user_query,
user_id=req.user_id or "test_user",
background_tasks=background_tasks,
images=req.images,
)
return {
@@ -146,9 +145,9 @@ async def mock_db_search(req: MockDBRequest):
raise HTTPException(status_code=500, detail=f"DB Search Error: {e!s}")
@router.post("/api/mock/retrieverdb", summary="Real Embedding + Real DB Vector Search")
@router.post("/api/mock/retriverdb", summary="Real Embedding + Real DB Vector Search (Legacy)")
async def mock_retriever_db(req: MockRetrieverRequest):
@router.post("/api/mock/retrieverdb", summary="Real Embedding + Real DB Vector Search")
@router.post("/api/mock/retriverdb", summary="Real Embedding + Real DB Vector Search (Legacy)")
async def mock_retriever_db(req: MockRetrieverRequest):
"""
Real API to test Retriever + DB Search (using the agent tool):
- Takes the query from the user
......
@@ -45,6 +45,7 @@ async def get_system_prompt_content(request: Request):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/api/agent/system-prompt")
@rate_limit_service.limiter.limit("10/minute")
async def update_system_prompt_content(request: Request, body: PromptUpdateRequest):
......
@@ -10,7 +10,27 @@ import httpx
logger = logging.getLogger(__name__)
CANIFA_CUSTOMER_API = "https://vsf2.canifa.com/v1/magento/customer"
# CANIFA_CUSTOMER_API = "https://vsf2.canifa.com/v1/magento/customer"
CANIFA_CUSTOMER_API = "https://canifa.com/v1/magento/customer"
_http_client: httpx.AsyncClient | None = None
def _get_http_client() -> httpx.AsyncClient:
global _http_client
if _http_client is None:
_http_client = httpx.AsyncClient(timeout=10.0)
return _http_client
async def close_http_client() -> None:
global _http_client
if _http_client is not None:
await _http_client.aclose()
_http_client = None
CANIFA_QUERY_BODY = [
{
@@ -43,23 +63,23 @@ async def verify_canifa_token(token: str) -> dict[str, Any] | None:
}
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.post(CANIFA_CUSTOMER_API, json=CANIFA_QUERY_BODY, headers=headers)
if response.status_code == 200:
data = response.json()
logger.debug(f"Canifa API Raw Response: {data}")
client = _get_http_client()
response = await client.post(CANIFA_CUSTOMER_API, json=CANIFA_QUERY_BODY, headers=headers)
# Response format: {"data": {"customer": {...}}, "loading": false, ...}
if isinstance(data, dict):
# Return the full payload for extract_user_id to process
return data
if response.status_code == 200:
data = response.json()
logger.debug(f"Canifa API Raw Response: {data}")
# If Canifa returns a list (batch request)
# Response format: {"data": {"customer": {...}}, "loading": false, ...}
if isinstance(data, dict):
# Return the full payload for extract_user_id to process
return data
logger.warning(f"Canifa API Failed: {response.status_code} - {response.text}")
return None
# If Canifa returns a list (batch request)
return data
logger.warning(f"Canifa API Failed: {response.status_code} - {response.text}")
return None
except Exception as e:
logger.error(f"Error calling Canifa API: {e}")
......
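Both this module and the image-storage service below now hold a module-level `httpx.AsyncClient` and expose a `close_http_client()`; those should run once at process shutdown. A sketch of wiring that into the app, assuming the module paths (they are not shown in this diff):

```python
from fastapi import FastAPI

# Import paths are assumptions; the close_http_client() helpers come from this commit.
from common.canifa_auth import close_http_client as close_canifa_client
from common.image_storage import close_http_client as close_image_client

app = FastAPI()

@app.on_event("shutdown")
async def shutdown_event() -> None:
    # Release the shared httpx connection pools gracefully.
    await close_canifa_client()
    await close_image_client()
```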
@@ -109,12 +109,15 @@ class ConversationManager:
raise
async def get_chat_history(
self, identity_key: str, limit: int | None = None, before_id: int | None = None
self, identity_key: str, limit: int | None = None, before_id: int | None = None,
include_product_ids: bool = True
) -> list[dict[str, Any]]:
"""
Retrieve chat history for an identity (user_id or device_id) using cursor-based pagination.
AI messages are parsed from a JSON string to extract product_ids.
Uses cached graph for performance.
Args:
include_product_ids: True for API (frontend needs product cards),
False for AI context (only text needed)
"""
max_retries = 3
for attempt in range(max_retries):
@@ -166,15 +169,17 @@
# User message - plain text
entry["message"] = message_content
else:
# AI message - parse JSON to get ai_response + product_ids
# AI message - parse JSON
try:
parsed = json.loads(message_content)
entry["message"] = parsed.get("ai_response", message_content)
entry["product_ids"] = parsed.get("product_ids", [])
if include_product_ids:
entry["product_ids"] = parsed.get("product_ids", [])
except (json.JSONDecodeError, TypeError):
# Fallback if the content is not JSON (legacy data)
entry["message"] = message_content
entry["product_ids"] = []
if include_product_ids:
entry["product_ids"] = []
history.append(entry)
......
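For reference, an AI turn is persisted as a JSON string with exactly the two keys read above. A runnable sketch of that round trip, mirroring the parsing logic (the payload values are illustrative):

```python
import json

# Example of a stored AI message: {"ai_response": ..., "product_ids": [...]}
stored = json.dumps({"ai_response": "Here is a jacket suggestion.", "product_ids": [101, 202]})

def to_entry(message_content: str, include_product_ids: bool = True) -> dict:
    try:
        parsed = json.loads(message_content)
        entry = {"message": parsed.get("ai_response", message_content)}
        if include_product_ids:
            entry["product_ids"] = parsed.get("product_ids", [])
    except (json.JSONDecodeError, TypeError):
        entry = {"message": message_content}  # legacy plain-text data
        if include_product_ids:
            entry["product_ids"] = []
    return entry

print(to_entry(stored, include_product_ids=False))  # {'message': 'Here is a jacket suggestion.'}
```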
import logging
import hashlib
import json
import logging
from openai import AsyncOpenAI, OpenAI
@@ -91,7 +93,7 @@ async def create_embedding_async(text: str) -> list[float]:
return []
async def create_embeddings_async(texts: list[str]) -> list[list[float]]:
"""
Batch async embedding generation with per-item Layer 2 Cache.
"""
@@ -99,18 +101,28 @@ async def create_embeddings_async(texts: list[str]) -> list[list[float]]:
if not texts:
return []
results = [[] for _ in texts]
missed_indices = []
missed_texts = []
# 1. Check Cache for each text
for i, text in enumerate(texts):
cached = await redis_cache.get_embedding(text)
if cached:
results[i] = cached
else:
missed_indices.append(i)
missed_texts.append(text)
results = [[] for _ in texts]
missed_indices = []
missed_texts = []
client = redis_cache.get_client()
if client:
keys = []
for text in texts:
text_hash = hashlib.md5(text.strip().lower().encode()).hexdigest()
keys.append(f"emb_cache:{text_hash}")
cached_values = await client.mget(keys)
for i, cached in enumerate(cached_values):
if cached:
results[i] = json.loads(cached)
else:
missed_indices.append(i)
missed_texts.append(texts[i])
else:
# Fallback: no redis client, treat all as miss
missed_indices = list(range(len(texts)))
missed_texts = texts
# 2. Call OpenAI for missed texts
if missed_texts:
......
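The read path above batches cache lookups with MGET; the complementary write path for the misses can batch as well. A sketch using a pipeline with per-key TTLs, assuming the same `emb_cache:` key scheme and a redis.asyncio client (the TTL value is illustrative):

```python
import hashlib
import json

async def cache_embeddings(client, texts: list[str], embeddings: list[list[float]], ttl: int = 86400) -> None:
    # Mirror the read-side key scheme: emb_cache:<md5 of normalized text>.
    pipe = client.pipeline()
    for text, emb in zip(texts, embeddings):
        text_hash = hashlib.md5(text.strip().lower().encode()).hexdigest()
        pipe.set(f"emb_cache:{text_hash}", json.dumps(emb), ex=ttl)
    await pipe.execute()
```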
import logging
import uuid
import httpx
from config import CONV_SUPABASE_KEY, CONV_SUPABASE_URL
logger = logging.getLogger(__name__)
_http_client: httpx.AsyncClient | None = None
def _get_http_client() -> httpx.AsyncClient:
global _http_client
if _http_client is None:
_http_client = httpx.AsyncClient()
return _http_client
async def close_http_client() -> None:
global _http_client
if _http_client is not None:
await _http_client.aclose()
_http_client = None
class ImageStorageService:
"""
@@ -51,16 +67,16 @@ class ImageStorageService:
headers = {"Authorization": f"Bearer {self.key}", "apikey": self.key, "Content-Type": content_type}
async with httpx.AsyncClient() as client:
response = await client.post(upload_url, content=file_content, headers=headers)
if response.status_code == 200:
# Get the public URL (assumes the bucket is public)
public_url = f"{self.url}/storage/v1/object/public/{self.bucket_name}/{filename}"
logger.info(f"✅ Uploaded image successfully: {public_url}")
return public_url
logger.error(f"❌ Failed to upload image: {response.status_code} - {response.text}")
return None
client = _get_http_client()
response = await client.post(upload_url, content=file_content, headers=headers)
if response.status_code == 200:
# Get the public URL (assumes the bucket is public)
public_url = f"{self.url}/storage/v1/object/public/{self.bucket_name}/{filename}"
logger.info(f"✅ Uploaded image successfully: {public_url}")
return public_url
logger.error(f"❌ Failed to upload image: {response.status_code} - {response.text}")
return None
except Exception as e:
logger.error(f"Error uploading image to Supabase: {e}")
......
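The success branch above derives the public URL purely from the bucket layout. Isolated, the convention is (a sketch assuming a public Supabase storage bucket):

```python
def public_object_url(base_url: str, bucket_name: str, filename: str) -> str:
    # Supabase public-bucket URL convention used in the upload code above.
    return f"{base_url}/storage/v1/object/public/{bucket_name}/{filename}"

# e.g. public_object_url("https://xyz.supabase.co", "chat-images", "user1/photo.jpg")
```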
@@ -17,11 +17,6 @@ from config import RATE_LIMIT_GUEST, RATE_LIMIT_USER
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION (from config.py)
# =============================================================================
# Redis key prefix
MESSAGE_COUNT_PREFIX = "msg_limit:"
class MessageLimitService:
@@ -92,6 +87,7 @@ class MessageLimitService:
today = self._get_today_key()
return f"{MESSAGE_COUNT_PREFIX}{today}:{identity_key}"
def _get_seconds_until_midnight(self) -> int:
"""
Get seconds until next midnight (00:00).
@@ -104,6 +100,7 @@
return int((midnight - now).total_seconds())
def _reset_memory_if_new_day(self) -> None:
"""Reset in-memory storage nếu qua ngày mới."""
today = self._get_today_key()
@@ -112,10 +109,10 @@
self._memory_date = today
logger.debug(f"🔄 Memory storage reset for new day: {today}")
# =========================================================================
# REDIS OPERATIONS
# =========================================================================
async def _get_counts_from_redis(self, identity_key: str) -> dict[str, int] | None:
"""
Get all counts (guest, user) from the Redis hash.
@@ -143,6 +140,7 @@
logger.warning(f"Redis get counts error: {e}")
return None
async def _increment_in_redis(self, identity_key: str, field: str) -> int | None:
"""
Increment a specific field ('guest' or 'user') in the Redis hash.
@@ -171,10 +169,10 @@
logger.warning(f"Redis increment error: {e}")
return None
# =========================================================================
# PUBLIC METHODS
# =========================================================================
# =========================================================================
async def check_limit(
self,
identity_key: str,
......
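The Redis operations above keep the guest and user counts in one hash per identity per day, expiring at midnight. A minimal sketch of that increment, assuming the `msg_limit:` prefix from this module, a redis.asyncio client, and a YYYY-MM-DD day key (`field` is 'guest' or 'user'):

```python
from datetime import datetime, timedelta

async def increment_daily_count(client, identity_key: str, field: str) -> int:
    now = datetime.now()
    key = f"msg_limit:{now.strftime('%Y-%m-%d')}:{identity_key}"
    midnight = (now + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
    pipe = client.pipeline()
    pipe.hincrby(key, field, 1)  # per-identity hash: {"guest": n, "user": m}
    pipe.expire(key, int((midnight - now).total_seconds()))  # auto-reset at midnight
    results = await pipe.execute()
    return int(results[0])
```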
"""
Reset Limit Service - only for logged-in users
No limit on the number of resets (chat archive)
"""
import logging
from datetime import datetime
from common.cache import redis_cache
logger = logging.getLogger(__name__)
class ResetLimitService:
def __init__(self, limit: int = 5):
self.limit = limit
self.expiration_seconds = 86400 # 1 day
"""
Service that manages chat resets (archiving).
Only for logged-in users; no limit on the number of resets.
"""
async def check_limit(self, identity_key: str) -> tuple[bool, int, int]:
"""
Check if user can reset chat.
Always allows the reset (no limit).
Returns (can_reset, current_usage, remaining)
"""
redis_client = redis_cache.get_client()
if not redis_client:
# Fallback if Redis is down: allow reset
return True, 0, self.limit
today = datetime.now().strftime("%Y-%m-%d")
key = f"reset_limit:{identity_key}:{today}"
try:
count = await redis_client.get(key)
if count is None:
return True, 0, self.limit
current_usage = int(count)
remaining = self.limit - current_usage
if current_usage >= self.limit:
return False, current_usage, 0
return True, current_usage, remaining
except Exception as e:
logger.error(f"Error checking reset limit: {e}")
return True, 0, self.limit
# No limit - always allowed
return True, 0, 999
async def increment(self, identity_key: str):
redis_client = redis_cache.get_client()
if not redis_client:
return
"""
No need to track usage anymore since there is no limit.
"""
pass
today = datetime.now().strftime("%Y-%m-%d")
key = f"reset_limit:{identity_key}:{today}"
try:
pipe = redis_client.pipeline()
pipe.incr(key)
pipe.expire(key, self.expiration_seconds)
await pipe.execute()
except Exception as e:
logger.error(f"Error incrementing reset limit: {e}")
reset_limit_service = ResetLimitService(limit=5)
reset_limit_service = ResetLimitService()
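Since `check_limit` keeps its `(can_reset, usage, remaining)` return shape, call sites like the archive route earlier in this commit stay unchanged. A usage sketch (the import path is an assumption):

```python
from common.reset_limit import reset_limit_service  # import path assumed

async def can_archive(identity_key: str) -> bool:
    # With this commit the service always returns (True, 0, 999),
    # so the check is effectively a no-op kept for API stability.
    can_reset, usage, remaining = await reset_limit_service.check_limit(identity_key)
    return can_reset
```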
@@ -3,8 +3,9 @@ StarRocks Database Connection Utility
Based on chatbot-rsa pattern
"""
import asyncio
import logging
import os
from typing import Any
import aiomysql
@@ -156,17 +157,19 @@ class StarRocksConnection:
async with StarRocksConnection._pool_lock:
if StarRocksConnection._shared_pool is None:
logger.info(f"🔌 Creating Async Pool to {self.host}:{self.port}...")
StarRocksConnection._shared_pool = await aiomysql.create_pool(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
db=self.database,
charset="utf8mb4",
cursorclass=aiomysql.DictCursor,
minsize=2,  # Lower minsize to save idle resources
maxsize=80,
connect_timeout=10,
minsize = int(os.getenv("STARROCKS_POOL_MINSIZE", "2"))
maxsize = int(os.getenv("STARROCKS_POOL_MAXSIZE", "80"))
StarRocksConnection._shared_pool = await aiomysql.create_pool(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
db=self.database,
charset="utf8mb4",
cursorclass=aiomysql.DictCursor,
minsize=minsize,  # Lower minsize to save idle resources
maxsize=maxsize,
connect_timeout=10,
# --- IMPORTANT TWEAK HERE ---
pool_recycle=280,  # Recycle after ~4.5 minutes (avoids the 5-minute Windows/firewall timeout)
# ----------------------------------
......
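Once the shared pool exists, queries borrow a connection and, thanks to `DictCursor` above, get rows back as dicts. A minimal aiomysql usage sketch (the table name is illustrative):

```python
async def fetch_products(pool, limit: int = 5) -> list[dict]:
    # Borrow a pooled connection; DictCursor returns each row as a dict.
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("SELECT * FROM products LIMIT %s", (limit,))
            return await cur.fetchall()
```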
"""
User Identity Helper
Determines the user identity from the request
Design:
- With user_id: Langfuse User ID = user_id, metadata = {device_id: "xxx", is_authenticated: true}
- Without user_id: Langfuse User ID = device_id, metadata = {device_id: "xxx", is_authenticated: false}
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from datetime import datetime
from fastapi import Request
logger = logging.getLogger(__name__)
@dataclass
class UserIdentity:
"""User identity với các thông tin cần thiết"""
# Primary ID used for Langfuse, history, and rate limiting
primary_id: str
# Device ID (always present)
device_id: str
# User ID from the token (only present when logged in)
user_id: str | None
# Whether the user is logged in
is_authenticated: bool
@property
def langfuse_user_id(self) -> str:
"""User ID cho Langfuse tracking"""
return self.primary_id
@property
def langfuse_session_id(self) -> str:
"""Session ID cho Langfuse (theo device + ngày)"""
today = datetime.now().strftime("%Y%m%d")
return f"{self.device_id}-{today}"
@property
def langfuse_metadata(self) -> dict:
"""Metadata cho Langfuse"""
return {
"device_id": self.device_id,
"is_authenticated": self.is_authenticated,
}
@property
def langfuse_tags(self) -> list[str]:
"""Tags cho Langfuse"""
tags = ["chatbot", "production"]
tags.append("customer" if self.is_authenticated else "guest")
return tags
@property
def history_key(self) -> str:
"""
Key used to save/load chat history.
- Guest (not logged in): device_id
- User (logged in): user_id (customer_id from Canifa)
"""
if self.is_authenticated and self.user_id:
return self.user_id
return self.device_id
@property
def rate_limit_key(self) -> str:
"""
Key for rate limiting.
- Guest (not logged in): device_id → limit 10
- User (logged in): user_id → limit 100
"""
if self.is_authenticated and self.user_id:
return self.user_id
return self.device_id
def get_user_identity(request: Request) -> UserIdentity:
"""
Extract the user identity from the request.
Logic:
- user_id present (from token) → primary_id = user_id
- otherwise → primary_id = device_id
Args:
request: FastAPI Request object
Returns:
UserIdentity object
"""
# 1. Get device_id, preferring request.state (parsed from the body by middleware), then the header
device_id = ""
if hasattr(request.state, "device_id") and request.state.device_id:
device_id = request.state.device_id
if not device_id:
device_id = request.headers.get("device_id", "")
if not device_id:
device_id = f"unknown_{request.client.host}" if request.client else "unknown"
# 2. Get user_id from the token (already parsed by middleware)
user_id = None
is_authenticated = False
if hasattr(request.state, "user_id") and request.state.user_id:
user_id = request.state.user_id
is_authenticated = True
# 3. Primary ID - ALWAYS the device_id
primary_id = device_id
identity = UserIdentity(
primary_id=primary_id,
device_id=device_id,
user_id=user_id,
is_authenticated=is_authenticated,
)
logger.debug(
f"UserIdentity: langfuse_user_id={identity.langfuse_user_id}, "
f"metadata={identity.langfuse_metadata}"
)
return identity
import asyncio
import logging
import os
import platform
if platform.system() == "Windows":
print("🔧 Windows detected: Applying SelectorEventLoopPolicy globally...")
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import logging
import uvicorn
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from api.chatbot_route import router as chatbot_router
from api.conservation_route import router as conservation_router
from api.prompt_route import router as prompt_router
from api.mock_api_route import router as mock_router
from common.cache import redis_cache
from common.langfuse_client import get_langfuse_client
from common.middleware import middleware_manager
from config import PORT
# Configure LoggingP
if platform.system() == "Windows":
print("🔧 Windows detected: Applying SelectorEventLoopPolicy globally...")
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# Configure Logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
@@ -29,11 +31,7 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
langfuse_client = get_langfuse_client()
if langfuse_client:
logger.info("✅ Langfuse client ready (lazy loading)")
else:
logger.warning("⚠️ Langfuse client not available (missing keys or disabled)")
app = FastAPI(
title="Contract AI Service",
@@ -42,62 +40,41 @@
)
# =============================================================================
# STARTUP EVENT - Initialize Redis Cache
# =============================================================================
@app.on_event("startup")
async def startup_event():
"""Initialize Redis cache on startup."""
await redis_cache.initialize()
logger.info("✅ Redis cache initialized for message limit")
logger.info("✅ Redis cache initialized")
# =============================================================================
# MIDDLEWARE SETUP - combine Auth + RateLimit + CORS in one place
# =============================================================================
middleware_manager.setup(
app,
enable_auth=True,  # 👈 Re-enable Auth to test the guest/user logic
enable_rate_limit=True,  # 👈 Re-enable SlowAPI as requested
enable_cors=True,  # 👈 Enable CORS
cors_origins=["*"],  # 👈 Origins should be restricted in production
enable_auth=True,
enable_rate_limit=True,
enable_cors=True,
cors_origins=["*"],
)
# API routers
app.include_router(conservation_router)
app.include_router(chatbot_router)
app.include_router(prompt_router)
from api.mock_api_route import router as mock_router
app.include_router(mock_router)
print("✅ Mock API Router mounted at /api/mock")
# --- MOCK API FOR LOAD TESTING ---
try:
from api.mock_api_route import router as mock_router
app.include_router(mock_router, prefix="/api")
print("✅ Mock API Router mounted at /api/mock")
except ImportError:
print("⚠️ Mock Router not found, skipping...")
# ==========================================
# 🟢 STATIC HTML MOUNT 🟢
# ==========================================
try:
static_dir = os.path.join(os.path.dirname(__file__), "static")
if not os.path.exists(static_dir):
os.makedirs(static_dir)
# Mount the static directory so index.html can be served
app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
print(f"✅ Static files mounted at /static (Dir: {static_dir})")
except Exception as e:
print(f"⚠️ Failed to mount static files: {e}")
from fastapi.responses import RedirectResponse
@app.get("/")
async def root():
return RedirectResponse(url="/static/index.html")
......