Commit c66892a7 authored by Vũ Hoàng Anh

refactor: simplify identity management, rate limiting, and chat history

parent d20b1e77
@@ -4,7 +4,6 @@ Langfuse will auto-trace via LangChain integration (no code changes needed).
 """
 import logging
-import time
 import uuid
 from fastapi import BackgroundTasks
@@ -14,14 +13,12 @@ from langchain_core.runnables import RunnableConfig
 from common.cache import redis_cache
 from common.conversation_manager import get_conversation_manager
 from common.langfuse_client import get_callback_handler
-from common.llm_factory import create_llm
 from config import DEFAULT_MODEL, REDIS_CACHE_TURN_ON
 from langfuse import propagate_attributes
 from .graph import build_graph
 from .helper import extract_product_ids, handle_post_chat_async, parse_ai_response
 from .models import AgentState, get_config
-from .tools.get_tools import get_all_tools

 logger = logging.getLogger(__name__)
@@ -75,15 +72,13 @@ async def chat_controller(
     config = get_config()
     config.model_name = model_name

-    llm = create_llm(model_name=model_name, streaming=False, json_mode=True)
-    # tools = get_all_tools()
-    graph = build_graph(config)
+    graph = build_graph(config)  # Singleton usage - the singleton now handles tools

     # Init ConversationManager (Singleton)
     memory = await get_conversation_manager()

-    # Load History
-    history_dicts = await memory.get_chat_history(effective_identity_key, limit=15)
+    # Load History (only text, no product_ids for AI context)
+    history_dicts = await memory.get_chat_history(effective_identity_key, limit=15, include_product_ids=False)
     messages = [
         HumanMessage(content=m["message"]) if m["is_human"] else AIMessage(content=m["message"])
         for m in history_dicts
@@ -114,26 +109,15 @@ async def chat_controller(
     )

     # Execute Graph
-    start_time = time.time()
     session_id = f"{user_id}-{run_id[:8]}"
     with propagate_attributes(user_id=user_id, session_id=session_id):
         result = await graph.ainvoke(initial_state, config=exec_config)
-    duration = time.time() - start_time

     # Parse Response
     all_product_ids = extract_product_ids(result.get("messages", []))
-    logger.info("🔍 [DEBUG] all_product_ids count: %s", len(all_product_ids))
-    if all_product_ids:
-        logger.info("🔍 [DEBUG] First product from tool: %s", all_product_ids[0])

     ai_raw_content = result.get("ai_response").content if result.get("ai_response") else ""
     ai_text_response, final_product_ids = parse_ai_response(ai_raw_content, all_product_ids)
-    logger.info("🔍 [DEBUG] final_product_ids count: %s, type: %s", len(final_product_ids), type(final_product_ids[0]) if final_product_ids else "empty")
-    if final_product_ids:
-        logger.info("🔍 [DEBUG] First final product: %s", final_product_ids[0])

     response_payload = {
         "ai_response": ai_text_response,
@@ -159,6 +143,6 @@ async def chat_controller(
         ai_response=response_payload,
     )

-    logger.info("chat_controller finished in %.2fs", duration)
+    logger.info("chat_controller finished")
     return {**response_payload, "cached": False}
@@ -16,7 +16,7 @@ from langchain_core.runnables import RunnableConfig
 from common.conversation_manager import ConversationManager, get_conversation_manager
 from common.langfuse_client import get_callback_handler
-from common.llm_factory import create_llm
+from agent.tools.data_retrieval_tool import SearchItem, data_retrieval_tool
 from config import DEFAULT_MODEL
 from .graph import build_graph
@@ -51,11 +51,8 @@ async def chat_controller(
     config = get_config()
     config.model_name = model_name

-    # Enable JSON mode to ensure structured output
-    llm = create_llm(model_name=model_name, streaming=False, json_mode=True)
     tools = get_all_tools()
-    graph = build_graph(config, llm=llm, tools=tools)
+    graph = build_graph(config, llm=None, tools=tools)

     # Init ConversationManager (Singleton)
     memory = await get_conversation_manager()
@@ -180,9 +177,8 @@ def _prepare_execution_context(query: str, user_id: str, history: list, images:
         "tags": "chatbot,production",
     }

-    # 🔥 CallbackHandler - will be wrapped in langfuse_trace_context to set user_id
-    # Per Langfuse docs: propagate_attributes() handles user_id propagation
+    # CallbackHandler for Langfuse (if enabled)
     langfuse_handler = get_callback_handler()

     exec_config = RunnableConfig(
         configurable={
@@ -214,12 +210,12 @@ async def _handle_post_chat_async(
 # ========================================
 async def mock_chat_controller(
     query: str,
     user_id: str,
     background_tasks: BackgroundTasks,
     images: list[str] | None = None,
 ) -> dict:
     """
     Mock Agent Controller with a FAKE LLM (does not call OpenAI):
     - Uses the entire graph flow from chat_controller
@@ -238,36 +234,24 @@ async def mock_chat_controller(
     ✅ No JSON parsing needed (the response is plain text)
     ✅ Faster (~1-3 ms simulated LLM instead of 1-3 s real LLM)
     """
     logger.info(f"🚀 [MOCK Chat Controller] Starting with query: {query} for user: {user_id}")
     start_time = time.time()

-    config = get_config()
-    # Do NOT call OpenAI - use REAL tools but fake the LLM response
-    tools = get_all_tools()
-    graph = build_graph(config, llm=None, tools=tools)  # llm=None to skip the LLM node
-
-    # Init ConversationManager (Singleton)
-    memory = await get_conversation_manager()
-
-    # LOAD HISTORY & Prepare State
-    history_dicts = await memory.get_chat_history(user_id, limit=20)
-    history = []
-    for h in reversed(history_dicts):
-        msg_cls = HumanMessage if h["is_human"] else AIMessage
-        history.append(msg_cls(content=h["message"]))
-
-    initial_state, exec_config = _prepare_execution_context(
-        query=query, user_id=user_id, history=history, images=images
-    )
-
-    try:
-        # Run the graph with REAL tools
-        result = await graph.ainvoke(initial_state, config=exec_config)
-
-        # Extract products from the tool messages (REAL tools)
-        all_product_ids = _extract_product_ids(result.get("messages", []))
+    # Init ConversationManager (Singleton)
+    memory = await get_conversation_manager()
+
+    try:
+        # Call the tool directly (bypassing the LLM) to avoid a bottleneck
+        search_item = SearchItem(
+            query=query or "sản phẩm",
+            magento_ref_code=None,
+            price_min=None,
+            price_max=None,
+            action="search",
+        )
+        result_json = await data_retrieval_tool.ainvoke({"searches": [search_item]})
+        result = json.loads(result_json)
+        all_product_ids = result.get("results", [{}])[0].get("products", [])

         # Generate FAKE LLM response (does not call OpenAI)
         logger.info("🤖 [FAKE LLM] Generating mock response...")
...
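
The direct tool call above assumes `data_retrieval_tool` returns a JSON string shaped like `{"results": [{"products": [...]}]}`. That shape is inferred from the parsing line in the diff, not confirmed by the tool's source; a minimal sketch of the parsing under that assumption:

    import json

    # Hypothetical payload, inferred from:
    #   result.get("results", [{}])[0].get("products", [])
    result_json = '{"results": [{"products": [101, 102, 103]}]}'

    result = json.loads(result_json)
    all_product_ids = result.get("results", [{}])[0].get("products", [])
    assert all_product_ids == [101, 102, 103]
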
This diff is collapsed.
@@ -10,21 +10,16 @@ import logging
 from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
 from fastapi.responses import JSONResponse
-from opentelemetry import trace

 from agent.controller import chat_controller
 from agent.models import QueryRequest
 from common.message_limit import message_limit_service
-from common.user_identity import get_user_identity
+from common.rate_limit import rate_limit_service
 from config import DEFAULT_MODEL

 logger = logging.getLogger(__name__)
-tracer = trace.get_tracer(__name__)
 router = APIRouter()

-from common.rate_limit import rate_limit_service

 @router.post("/api/agent/chat", summary="Fashion Q&A Chat (Non-streaming)")
 @rate_limit_service.limiter.limit("50/minute")
 async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks: BackgroundTasks):
@@ -33,55 +28,40 @@ async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks:
     Note: Rate limiting has already been checked in the middleware.
     """
-    # 1. Determine user identity
-    identity = get_user_identity(request)
-    user_id = identity.primary_id
+    # 1. Get user identity from the middleware (request.state)
+    # Logic: Login -> User ID | Guest -> Device ID
+    user_id = getattr(request.state, "user_id", None)
+    device_id = getattr(request.state, "device_id", "unknown")
+    is_authenticated = getattr(request.state, "is_authenticated", False)
+
+    # Single identity for this request (logs, history, rate limit, Langfuse)
+    identity_id = user_id if is_authenticated else device_id

     # Rate limit already checked in the middleware; read limit_info from request.state
     limit_info = getattr(request.state, 'limit_info', None)

-    logger.info(f"📥 [Incoming Query - NonStream] User: {user_id} | Query: {req.user_query}")
+    logger.info(f"📥 [Incoming Query - NonStream] User: {identity_id} | Query: {req.user_query}")

-    # Get the current span to add logs TO THE JAEGER UI
-    span = trace.get_current_span()
-    span.set_attribute("user.id", user_id)
-    span.set_attribute("chat.user_query", req.user_query)
-    span.add_event(
-        "📥 User query received", attributes={"user_id": user_id, "query": req.user_query, "timestamp": "incoming"}
-    )

     try:
         # Call the controller to handle the logic (non-streaming)
         result = await chat_controller(
             query=req.user_query,
-            user_id=user_id,
+            user_id=str(identity_id),  # Langfuse User ID
             background_tasks=background_tasks,
             model_name=DEFAULT_MODEL,
             images=req.images,
-            identity_key=identity.history_key,  # Guest: device_id, User: user_id
+            identity_key=str(identity_id),  # Key used to store history
         )

         # Log the response in detail
-        logger.info(f"📤 [Outgoing Response - NonStream] User: {user_id}")
+        logger.info(f"📤 [Outgoing Response - NonStream] User: {identity_id}")
         logger.info(f"💬 AI Response: {result['ai_response']}")
         logger.info(f"🛍️ Product IDs: {result.get('product_ids', [])}")

-        # Add to the span (shown in the Jaeger UI)
-        span.set_attribute("chat.ai_response", result["ai_response"][:200])  # Limit to 200 characters
-        span.set_attribute("chat.product_count", len(result.get("product_ids", [])))
-        span.add_event(
-            "💬 AI response generated",
-            attributes={
-                "ai_response_preview": result["ai_response"][:100],
-                "product_count": len(result.get("product_ids", [])),
-                "product_ids": str(result.get("product_ids", [])[:5]),  # First 5 IDs
-            },
-        )

         # Increment the message count AFTER the chat succeeds
         usage_info = await message_limit_service.increment(
-            identity_key=identity.rate_limit_key,
-            is_authenticated=identity.is_authenticated,
+            identity_key=identity_id,
+            is_authenticated=is_authenticated,
         )

         return {
...
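
The refactored route trusts `request.state.user_id`, `request.state.device_id`, and `request.state.is_authenticated`, which an auth middleware (enabled via `middleware_manager.setup(enable_auth=True, ...)` below, but not shown in this diff) must populate. A hypothetical sketch of what such a middleware has to provide; the class and `verify_token` helper are assumptions, not the project's actual middleware:

    from starlette.middleware.base import BaseHTTPMiddleware
    from starlette.requests import Request


    async def verify_token(token: str) -> str | None:
        # Placeholder: the real project resolves this via the Canifa customer API
        return None


    class IdentityMiddleware(BaseHTTPMiddleware):
        """Hypothetical sketch: populates the request.state fields the routes read."""

        async def dispatch(self, request: Request, call_next):
            token = request.headers.get("Authorization", "")
            user_id = await verify_token(token) if token else None

            request.state.user_id = user_id
            request.state.device_id = request.headers.get("device_id", "unknown")
            request.state.is_authenticated = user_id is not None
            return await call_next(request)
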
""" """
Chat History API Routes Chat History API Routes
- GET /api/history/{identity_key} - Lấy lịch sử chat (có product_ids) - GET /api/history/{identity_key} - Lấy lịch sử chat
- DELETE /api/history/{identity_key} - Xóa lịch sử chat - DELETE /api/history/{identity_key} - Xóa lịch sử chat
Note: identity_key có thể là device_id (guest) hoặc user_id (đã login) Note: identity_key có thể là device_id (guest) hoặc user_id (đã login)
...@@ -12,7 +12,6 @@ from typing import Any ...@@ -12,7 +12,6 @@ from typing import Any
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel from pydantic import BaseModel
from common.conversation_manager import get_conversation_manager from common.conversation_manager import get_conversation_manager
from common.user_identity import get_user_identity
router = APIRouter(tags=["Chat History"]) router = APIRouter(tags=["Chat History"])
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -40,17 +39,17 @@ async def get_chat_history(request: Request, identity_key: str, limit: int | Non ...@@ -40,17 +39,17 @@ async def get_chat_history(request: Request, identity_key: str, limit: int | Non
(identity_key trong URL chỉ là fallback) (identity_key trong URL chỉ là fallback)
""" """
try: try:
# Tự động resolve identity từ middleware # Resolve identity từ middleware (request.state)
identity = get_user_identity(request) user_id = getattr(request.state, "user_id", None)
device_id = getattr(request.state, "device_id", identity_key)
is_authenticated = getattr(request.state, "is_authenticated", False)
# Nếu đã login -> Dùng user_id # Log chi tiết để debug
if identity.is_authenticated: logger.info(f"GET History: auth={is_authenticated} | user_id={user_id} | device_id={device_id}")
resolved_key = identity.history_key
else: # Nếu đã login -> Dùng user_id, không thì dùng device_id
# Nếu chưa login (Guest) -> Dùng identity_key từ URL resolved_key = user_id if is_authenticated else device_id
resolved_key = identity_key logger.info(f"GET History: resolved_key={resolved_key}")
logger.info(f"GET History: URL key={identity_key} -> Resolved key={resolved_key}")
manager = await get_conversation_manager() manager = await get_conversation_manager()
history = await manager.get_chat_history(resolved_key, limit=limit, before_id=before_id) history = await manager.get_chat_history(resolved_key, limit=limit, before_id=before_id)
...@@ -100,10 +99,13 @@ async def archive_chat_history(request: Request): ...@@ -100,10 +99,13 @@ async def archive_chat_history(request: Request):
Giới hạn 5 lần/ngày. Giới hạn 5 lần/ngày.
""" """
try: try:
identity = get_user_identity(request) # Resolve identity từ middleware (request.state)
user_id = getattr(request.state, "user_id", None)
device_id = getattr(request.state, "device_id", "")
is_authenticated = getattr(request.state, "is_authenticated", False)
# Chỉ dành cho User đã đăng nhập # Chỉ dành cho User đã đăng nhập
if not identity.is_authenticated: if not is_authenticated:
return JSONResponse( return JSONResponse(
status_code=401, status_code=401,
content={ content={
...@@ -114,7 +116,7 @@ async def archive_chat_history(request: Request): ...@@ -114,7 +116,7 @@ async def archive_chat_history(request: Request):
} }
) )
identity_key = identity.history_key identity_key = user_id
# Check reset limit # Check reset limit
can_reset, usage, remaining = await reset_limit_service.check_limit(identity_key) can_reset, usage, remaining = await reset_limit_service.check_limit(identity_key)
......
@@ -7,6 +7,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException
 from pydantic import BaseModel

 from agent.tools.data_retrieval_tool import SearchItem, data_retrieval_tool
+from agent.mock_controller import mock_chat_controller

 logger = logging.getLogger(__name__)
 router = APIRouter()
@@ -31,6 +32,7 @@ class MockQueryRequest(BaseModel):
     user_query: str
     user_id: str | None = "test_user"
     session_id: str | None = None
+    images: list[str] | None = None


 class MockDBRequest(BaseModel):
@@ -62,10 +64,6 @@ MOCK_AI_RESPONSES = [

 # --- ENDPOINTS ---

-from agent.mock_controller import mock_chat_controller
-

 @router.post("/api/mock/agent/chat", summary="Mock Agent Chat (Real Tools + Fake LLM)")
 async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks):
     """
@@ -82,6 +80,7 @@ async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks):
         query=req.user_query,
         user_id=req.user_id or "test_user",
         background_tasks=background_tasks,
+        images=req.images,
     )

     return {
@@ -146,9 +145,9 @@ async def mock_db_search(req: MockDBRequest):
         raise HTTPException(status_code=500, detail=f"DB Search Error: {e!s}")


 @router.post("/api/mock/retrieverdb", summary="Real Embedding + Real DB Vector Search")
 @router.post("/api/mock/retriverdb", summary="Real Embedding + Real DB Vector Search (Legacy)")
 async def mock_retriever_db(req: MockRetrieverRequest):
     """
     A real API to test Retriever + DB Search (using the agent tool):
     - Takes the query from the user
...
@@ -45,6 +45,7 @@ async def get_system_prompt_content(request: Request):
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))


 @router.post("/api/agent/system-prompt")
 @rate_limit_service.limiter.limit("10/minute")
 async def update_system_prompt_content(request: Request, body: PromptUpdateRequest):
...
@@ -10,7 +10,27 @@ import httpx

 logger = logging.getLogger(__name__)

-CANIFA_CUSTOMER_API = "https://vsf2.canifa.com/v1/magento/customer"
+# CANIFA_CUSTOMER_API = "https://vsf2.canifa.com/v1/magento/customer"
+CANIFA_CUSTOMER_API = "https://canifa.com/v1/magento/customer"
+
+_http_client: httpx.AsyncClient | None = None
+
+
+def _get_http_client() -> httpx.AsyncClient:
+    global _http_client
+    if _http_client is None:
+        _http_client = httpx.AsyncClient(timeout=10.0)
+    return _http_client
+
+
+async def close_http_client() -> None:
+    global _http_client
+    if _http_client is not None:
+        await _http_client.aclose()
+        _http_client = None

 CANIFA_QUERY_BODY = [
     {
@@ -43,23 +63,23 @@ async def verify_canifa_token(token: str) -> dict[str, Any] | None:
     }

     try:
-        async with httpx.AsyncClient(timeout=10.0) as client:
-            response = await client.post(CANIFA_CUSTOMER_API, json=CANIFA_QUERY_BODY, headers=headers)
-
-            if response.status_code == 200:
-                data = response.json()
-                logger.debug(f"Canifa API Raw Response: {data}")
-
-                # Response format: {"data": {"customer": {...}}, "loading": false, ...}
-                if isinstance(data, dict):
-                    # Return the full payload; extract_user_id will handle it
-                    return data
-
-                # If Canifa returns a list (batch request)
-                return data
-
-            logger.warning(f"Canifa API Failed: {response.status_code} - {response.text}")
-            return None
+        client = _get_http_client()
+        response = await client.post(CANIFA_CUSTOMER_API, json=CANIFA_QUERY_BODY, headers=headers)
+
+        if response.status_code == 200:
+            data = response.json()
+            logger.debug(f"Canifa API Raw Response: {data}")
+
+            # Response format: {"data": {"customer": {...}}, "loading": false, ...}
+            if isinstance(data, dict):
+                # Return the full payload; extract_user_id will handle it
+                return data
+
+            # If Canifa returns a list (batch request)
+            return data
+
+        logger.warning(f"Canifa API Failed: {response.status_code} - {response.text}")
+        return None
     except Exception as e:
         logger.error(f"Error calling Canifa API: {e}")
...
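
The switch from a per-call `async with httpx.AsyncClient()` to a module-level client reuses one connection pool across requests, but the pool must now be closed explicitly on shutdown. A hypothetical wiring sketch, assuming the module above is importable as `common.canifa_auth` (the import path is an assumption):

    from fastapi import FastAPI

    from common.canifa_auth import close_http_client  # hypothetical module path

    app = FastAPI()

    @app.on_event("shutdown")
    async def shutdown_event():
        # Release the pooled connections held by the module-level client
        await close_http_client()
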
@@ -109,12 +109,15 @@ class ConversationManager:
             raise

     async def get_chat_history(
-        self, identity_key: str, limit: int | None = None, before_id: int | None = None
+        self, identity_key: str, limit: int | None = None, before_id: int | None = None,
+        include_product_ids: bool = True
     ) -> list[dict[str, Any]]:
         """
         Retrieve chat history for an identity (user_id or device_id) using cursor-based pagination.
-        AI messages are parsed from JSON strings to extract product_ids.

-        Uses cached graph for performance.
+        Args:
+            include_product_ids: True for API (frontend needs product cards),
+                False for AI context (only text needed)
         """
         max_retries = 3
         for attempt in range(max_retries):
@@ -166,15 +169,17 @@ class ConversationManager:
                     # User message - plain text
                     entry["message"] = message_content
                 else:
-                    # AI message - parse JSON to get ai_response + product_ids
+                    # AI message - parse JSON
                     try:
                         parsed = json.loads(message_content)
                         entry["message"] = parsed.get("ai_response", message_content)
-                        entry["product_ids"] = parsed.get("product_ids", [])
+                        if include_product_ids:
+                            entry["product_ids"] = parsed.get("product_ids", [])
                     except (json.JSONDecodeError, TypeError):
                         # Fallback if the content is not JSON (legacy data)
                         entry["message"] = message_content
-                        entry["product_ids"] = []
+                        if include_product_ids:
+                            entry["product_ids"] = []

                 history.append(entry)
...
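
The new flag splits the two consumers of history: the controller builds AI context from text only, while the history API keeps product payloads for the frontend. A usage sketch, inside an async context and assuming an initialized manager (`memory = await get_conversation_manager()`):

    # For the AI prompt: text only, smaller context
    ai_history = await memory.get_chat_history("device-123", limit=15, include_product_ids=False)

    # For the frontend history API: keep product_ids so product cards can render
    api_history = await memory.get_chat_history("device-123", limit=20)
    for entry in api_history:
        if not entry["is_human"]:
            print(entry["message"], entry.get("product_ids", []))
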
+import hashlib
+import json
 import logging

 from openai import AsyncOpenAI, OpenAI
@@ -91,7 +93,7 @@ async def create_embedding_async(text: str) -> list[float]:
     return []


 async def create_embeddings_async(texts: list[str]) -> list[list[float]]:
     """
     Batch async embedding generation with per-item Layer 2 Cache.
     """
@@ -99,18 +101,28 @@ async def create_embeddings_async(texts: list[str]) -> list[list[float]]:
     if not texts:
         return []

     results = [[] for _ in texts]
     missed_indices = []
     missed_texts = []

-    # 1. Check Cache for each text
-    for i, text in enumerate(texts):
-        cached = await redis_cache.get_embedding(text)
-        if cached:
-            results[i] = cached
-        else:
-            missed_indices.append(i)
-            missed_texts.append(text)
+    client = redis_cache.get_client()
+    if client:
+        keys = []
+        for text in texts:
+            text_hash = hashlib.md5(text.strip().lower().encode()).hexdigest()
+            keys.append(f"emb_cache:{text_hash}")
+
+        cached_values = await client.mget(keys)
+        for i, cached in enumerate(cached_values):
+            if cached:
+                results[i] = json.loads(cached)
+            else:
+                missed_indices.append(i)
+                missed_texts.append(texts[i])
+    else:
+        # Fallback: no redis client, treat all as miss
+        missed_indices = list(range(len(texts)))
+        missed_texts = texts

     # 2. Call OpenAI for missed texts
     if missed_texts:
...
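
The diff only shows the batched read side (one `mget` instead of a round trip per text); any write-back for the misses has to use the same key scheme or subsequent reads will never hit. A hedged sketch of that write-back; the helper name, pipeline usage, and one-day TTL are assumptions, only the `emb_cache:` prefix and md5-of-normalized-text come from the read path above:

    import hashlib
    import json

    def _embedding_cache_key(text: str) -> str:
        # Must match the read path: md5 of normalized text under "emb_cache:"
        return f"emb_cache:{hashlib.md5(text.strip().lower().encode()).hexdigest()}"

    async def cache_embeddings(client, texts: list[str], embeddings: list[list[float]], ttl: int = 86400) -> None:
        # Hypothetical write-back for cache misses; the TTL is an assumption
        pipe = client.pipeline()
        for text, emb in zip(texts, embeddings):
            pipe.setex(_embedding_cache_key(text), ttl, json.dumps(emb))
        await pipe.execute()
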
 import logging
 import uuid

 import httpx

 from config import CONV_SUPABASE_KEY, CONV_SUPABASE_URL

 logger = logging.getLogger(__name__)

+_http_client: httpx.AsyncClient | None = None
+
+
+def _get_http_client() -> httpx.AsyncClient:
+    global _http_client
+    if _http_client is None:
+        _http_client = httpx.AsyncClient()
+    return _http_client
+
+
+async def close_http_client() -> None:
+    global _http_client
+    if _http_client is not None:
+        await _http_client.aclose()
+        _http_client = None


 class ImageStorageService:
     """
@@ -51,16 +67,16 @@ class ImageStorageService:
         headers = {"Authorization": f"Bearer {self.key}", "apikey": self.key, "Content-Type": content_type}

-        async with httpx.AsyncClient() as client:
-            response = await client.post(upload_url, content=file_content, headers=headers)
-
-            if response.status_code == 200:
-                # Get the public URL (assumes the bucket is public)
-                public_url = f"{self.url}/storage/v1/object/public/{self.bucket_name}/{filename}"
-                logger.info(f"✅ Uploaded image successfully: {public_url}")
-                return public_url
-
-            logger.error(f"❌ Failed to upload image: {response.status_code} - {response.text}")
-            return None
+        client = _get_http_client()
+        response = await client.post(upload_url, content=file_content, headers=headers)
+
+        if response.status_code == 200:
+            # Get the public URL (assumes the bucket is public)
+            public_url = f"{self.url}/storage/v1/object/public/{self.bucket_name}/{filename}"
+            logger.info(f"✅ Uploaded image successfully: {public_url}")
+            return public_url
+
+        logger.error(f"❌ Failed to upload image: {response.status_code} - {response.text}")
+        return None
     except Exception as e:
         logger.error(f"Error uploading image to Supabase: {e}")
...
@@ -17,11 +17,6 @@ from config import RATE_LIMIT_GUEST, RATE_LIMIT_USER

 logger = logging.getLogger(__name__)

-# =============================================================================
-# CONFIGURATION (from config.py)
-# =============================================================================
-
-# Redis key prefix
 MESSAGE_COUNT_PREFIX = "msg_limit:"


 class MessageLimitService:
@@ -92,6 +87,7 @@ class MessageLimitService:
         today = self._get_today_key()
         return f"{MESSAGE_COUNT_PREFIX}{today}:{identity_key}"

     def _get_seconds_until_midnight(self) -> int:
         """
         Get seconds until next midnight (00:00).
@@ -104,6 +100,7 @@ class MessageLimitService:
         return int((midnight - now).total_seconds())

     def _reset_memory_if_new_day(self) -> None:
         """Reset in-memory storage when a new day starts."""
         today = self._get_today_key()
@@ -112,10 +109,10 @@ class MessageLimitService:
             self._memory_date = today
             logger.debug(f"🔄 Memory storage reset for new day: {today}")

     # =========================================================================
     # REDIS OPERATIONS
     # =========================================================================

     async def _get_counts_from_redis(self, identity_key: str) -> dict[str, int] | None:
         """
         Get all counts (guest, user) from the Redis hash.
@@ -143,6 +140,7 @@ class MessageLimitService:
             logger.warning(f"Redis get counts error: {e}")
             return None

     async def _increment_in_redis(self, identity_key: str, field: str) -> int | None:
         """
         Increment a specific field ('guest' or 'user') in the Redis hash.
@@ -171,10 +169,10 @@ class MessageLimitService:
             logger.warning(f"Redis increment error: {e}")
             return None

     # =========================================================================
     # PUBLIC METHODS
     # =========================================================================

     async def check_limit(
         self,
         identity_key: str,
...
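
The hunks above reference a per-day Redis hash with separate 'guest' and 'user' fields. A minimal sketch of that pattern; the field names and the `msg_limit:` prefix come from the diff, while the exact key layout and the pipeline wrapper are assumptions:

    from datetime import datetime

    async def increment_daily_count(client, identity_key: str, field: str, ttl: int) -> int:
        # Key layout assumed from MESSAGE_COUNT_PREFIX and _get_today_key() above
        today = datetime.now().strftime("%Y-%m-%d")
        key = f"msg_limit:{today}:{identity_key}"

        pipe = client.pipeline()
        pipe.hincrby(key, field, 1)  # field is 'guest' or 'user'
        pipe.expire(key, ttl)        # e.g. seconds until midnight
        counts = await pipe.execute()
        return counts[0]
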
This diff is collapsed.
This diff is collapsed.
"""
Reset Limit Service - Chỉ dành cho User đã login
Không giới hạn số lần reset (archive chat)
"""
import logging import logging
from datetime import datetime
from common.cache import redis_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ResetLimitService: class ResetLimitService:
def __init__(self, limit: int = 5): """
self.limit = limit Service quản lý việc reset (archive) chat.
self.expiration_seconds = 86400 # 1 day Chỉ dành cho user đã đăng nhập, không giới hạn số lần.
"""
async def check_limit(self, identity_key: str) -> tuple[bool, int, int]: async def check_limit(self, identity_key: str) -> tuple[bool, int, int]:
""" """
Check if user can reset chat. Luôn cho phép reset (không giới hạn).
Returns (can_reset, current_usage, remaining) Returns (can_reset, current_usage, remaining)
""" """
redis_client = redis_cache.get_client() # Không giới hạn - luôn cho phép
if not redis_client: return True, 0, 999
# Fallback if Redis is down: allow reset
return True, 0, self.limit
today = datetime.now().strftime("%Y-%m-%d")
key = f"reset_limit:{identity_key}:{today}"
try:
count = await redis_client.get(key)
if count is None:
return True, 0, self.limit
current_usage = int(count)
remaining = self.limit - current_usage
if current_usage >= self.limit:
return False, current_usage, 0
return True, current_usage, remaining
except Exception as e:
logger.error(f"Error checking reset limit: {e}")
return True, 0, self.limit
async def increment(self, identity_key: str): async def increment(self, identity_key: str):
redis_client = redis_cache.get_client() """
if not redis_client: Không cần track usage nữa vì không giới hạn.
return """
pass
today = datetime.now().strftime("%Y-%m-%d")
key = f"reset_limit:{identity_key}:{today}"
try:
pipe = redis_client.pipeline()
pipe.incr(key)
pipe.expire(key, self.expiration_seconds)
await pipe.execute()
except Exception as e:
logger.error(f"Error incrementing reset limit: {e}")
reset_limit_service = ResetLimitService(limit=5) reset_limit_service = ResetLimitService()
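
Call sites keep the same `(can_reset, current_usage, remaining)` tuple contract, so the archive route above needs no change. Inside an async context:

    can_reset, usage, remaining = await reset_limit_service.check_limit("user-42")
    assert (can_reset, usage, remaining) == (True, 0, 999)  # always allowed now
    await reset_limit_service.increment("user-42")  # no-op
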
@@ -3,8 +3,9 @@ StarRocks Database Connection Utility
 Based on chatbot-rsa pattern
 """

 import asyncio
 import logging
+import os
 from typing import Any

 import aiomysql
@@ -156,17 +157,19 @@ class StarRocksConnection:
         async with StarRocksConnection._pool_lock:
             if StarRocksConnection._shared_pool is None:
                 logger.info(f"🔌 Creating Async Pool to {self.host}:{self.port}...")
+                minsize = int(os.getenv("STARROCKS_POOL_MINSIZE", "2"))
+                maxsize = int(os.getenv("STARROCKS_POOL_MAXSIZE", "80"))
                 StarRocksConnection._shared_pool = await aiomysql.create_pool(
                     host=self.host,
                     port=self.port,
                     user=self.user,
                     password=self.password,
                     db=self.database,
                     charset="utf8mb4",
                     cursorclass=aiomysql.DictCursor,
-                    minsize=2,  # Lower minsize to reduce idle resource usage
-                    maxsize=80,
+                    minsize=minsize,  # Lower minsize to reduce idle resource usage
+                    maxsize=maxsize,
                     connect_timeout=10,
                     # --- IMPORTANT TWEAK HERE ---
                     pool_recycle=280,  # Recycle after 4.5 minutes (avoids the 5-minute Windows/firewall timeout)
                     # ----------------------------------
...
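
The pool sizes are now tunable per environment, with the previous hardcoded values (2 / 80) as defaults. For example, to shrink the pool in a dev environment, set the variables before the first pool is created (variable names taken from the diff):

    import os

    os.environ["STARROCKS_POOL_MINSIZE"] = "1"
    os.environ["STARROCKS_POOL_MAXSIZE"] = "10"
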
"""
User Identity Helper
Xác định user identity từ request
Design:
- Có user_id: Langfuse User ID = user_id, metadata = {device_id: "xxx", is_authenticated: true}
- Không user_id: Langfuse User ID = device_id, metadata = {device_id: "xxx", is_authenticated: false}
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from datetime import datetime
from fastapi import Request
logger = logging.getLogger(__name__)
@dataclass
class UserIdentity:
"""User identity với các thông tin cần thiết"""
# ID chính dùng cho Langfuse, history, rate limit
primary_id: str
# Device ID (luôn có)
device_id: str
# User ID từ token (chỉ có khi đã login)
user_id: str | None
# Đã login hay chưa
is_authenticated: bool
@property
def langfuse_user_id(self) -> str:
"""User ID cho Langfuse tracking"""
return self.primary_id
@property
def langfuse_session_id(self) -> str:
"""Session ID cho Langfuse (theo device + ngày)"""
today = datetime.now().strftime("%Y%m%d")
return f"{self.device_id}-{today}"
@property
def langfuse_metadata(self) -> dict:
"""Metadata cho Langfuse"""
return {
"device_id": self.device_id,
"is_authenticated": self.is_authenticated,
}
@property
def langfuse_tags(self) -> list[str]:
"""Tags cho Langfuse"""
tags = ["chatbot", "production"]
tags.append("customer" if self.is_authenticated else "guest")
return tags
@property
def history_key(self) -> str:
"""
Key để lưu/load chat history.
- Guest (chưa login): device_id
- User (đã login): user_id (customer_id từ Canifa)
"""
if self.is_authenticated and self.user_id:
return self.user_id
return self.device_id
@property
def rate_limit_key(self) -> str:
"""
Key cho rate limiting.
- Guest (chưa login): device_id → limit 10
- User (đã login): user_id → limit 100
"""
if self.is_authenticated and self.user_id:
return self.user_id
return self.device_id
def get_user_identity(request: Request) -> UserIdentity:
"""
Extract user identity từ request.
Logic:
- Có user_id (từ token) → primary_id = user_id
- Không có → primary_id = device_id
Args:
request: FastAPI Request object
Returns:
UserIdentity object
"""
# 1. Lấy device_id ưu tiên từ request.state (do middleware parse từ body), sau đó mới tới header
device_id = ""
if hasattr(request.state, "device_id") and request.state.device_id:
device_id = request.state.device_id
if not device_id:
device_id = request.headers.get("device_id", "")
if not device_id:
device_id = f"unknown_{request.client.host}" if request.client else "unknown"
# 2. Lấy user_id từ token (middleware đã parse)
user_id = None
is_authenticated = False
if hasattr(request.state, "user_id") and request.state.user_id:
user_id = request.state.user_id
is_authenticated = True
# 3. Primary ID - LUÔN LUÔN LÀ device_id
primary_id = device_id
identity = UserIdentity(
primary_id=primary_id,
device_id=device_id,
user_id=user_id,
is_authenticated=is_authenticated,
)
logger.debug(
f"UserIdentity: langfuse_user_id={identity.langfuse_user_id}, "
f"metadata={identity.langfuse_metadata}"
)
return identity
 import asyncio
+import logging
 import os
 import platform
-
-if platform.system() == "Windows":
-    print("🔧 Windows detected: Applying SelectorEventLoopPolicy globally...")
-    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-
-import logging

 import uvicorn
 from fastapi import FastAPI
+from fastapi.responses import RedirectResponse
 from fastapi.staticfiles import StaticFiles

 from api.chatbot_route import router as chatbot_router
 from api.conservation_route import router as conservation_router
 from api.prompt_route import router as prompt_router
+from api.mock_api_route import router as mock_router
 from common.cache import redis_cache
 from common.langfuse_client import get_langfuse_client
 from common.middleware import middleware_manager
 from config import PORT

-# Configure Logging
+if platform.system() == "Windows":
+    print("🔧 Windows detected: Applying SelectorEventLoopPolicy globally...")
+    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+# Configure Logging
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
@@ -29,11 +31,7 @@ logging.basicConfig(

 logger = logging.getLogger(__name__)

-langfuse_client = get_langfuse_client()
-if langfuse_client:
-    logger.info("✅ Langfuse client ready (lazy loading)")
-else:
-    logger.warning("⚠️ Langfuse client not available (missing keys or disabled)")

 app = FastAPI(
     title="Contract AI Service",
@@ -42,62 +40,41 @@ app = FastAPI(
 )

-# =============================================================================
-# STARTUP EVENT - Initialize Redis Cache
-# =============================================================================

 @app.on_event("startup")
 async def startup_event():
     """Initialize Redis cache on startup."""
     await redis_cache.initialize()
-    logger.info("✅ Redis cache initialized for message limit")
+    logger.info("✅ Redis cache initialized")


-# =============================================================================
-# MIDDLEWARE SETUP - Bundle Auth + RateLimit + CORS in one place
-# =============================================================================

 middleware_manager.setup(
     app,
-    enable_auth=True,  # 👈 Re-enable Auth to test Guest/User logic
-    enable_rate_limit=True,  # 👈 Re-enable SlowAPI as requested
-    enable_cors=True,  # 👈 Enable CORS
-    cors_origins=["*"],  # 👈 Origins should be restricted in production
+    enable_auth=True,
+    enable_rate_limit=True,
+    enable_cors=True,
+    cors_origins=["*"],
 )

-# api include
 app.include_router(conservation_router)
 app.include_router(chatbot_router)
 app.include_router(prompt_router)
-
-from api.mock_api_route import router as mock_router
 app.include_router(mock_router)
-print("✅ Mock API Router mounted at /api/mock")
-
-# --- MOCK API FOR LOAD TESTING ---
-try:
-    from api.mock_api_route import router as mock_router
-
-    app.include_router(mock_router, prefix="/api")
-    print("✅ Mock API Router mounted at /api/mock")
-except ImportError:
-    print("⚠️ Mock Router not found, skipping...")

-# ==========================================
-# 🟢 HERE'S THE STATIC HTML MOUNT, BRO 🟢
-# ==========================================
 try:
     static_dir = os.path.join(os.path.dirname(__file__), "static")
     if not os.path.exists(static_dir):
         os.makedirs(static_dir)

-    # Mount the static directory to serve index.html
     app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
     print(f"✅ Static files mounted at /static (Dir: {static_dir})")
 except Exception as e:
     print(f"⚠️ Failed to mount static files: {e}")

-from fastapi.responses import RedirectResponse

 @app.get("/")
 async def root():
     return RedirectResponse(url="/static/index.html")
...