Commit 87361fe6 authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

Refactor: Implement singleton lazy loading pattern for DB & Langfuse clients

- StarRocksConnection: Use StarRocksConnectionManager singleton with lazy loading
- Langfuse: Implement LangfuseClientManager singleton with lazy loading
- EmbeddingService: Already using singleton pattern
- Remove context managers (Langfuse auto-traces LangChain)
- Fix imports across agent/tools and server
- Clean up unnecessary comments and fix code organization
parent 4be13d17
...@@ -12,7 +12,7 @@ from langchain_core.messages import AIMessage, HumanMessage, ToolMessage ...@@ -12,7 +12,7 @@ from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from common.conversation_manager import ConversationManager, get_conversation_manager from common.conversation_manager import ConversationManager, get_conversation_manager
from common.langfuse_client import get_callback_handler, langfuse_trace_context from common.langfuse_client import get_callback_handler
from common.llm_factory import create_llm from common.llm_factory import create_llm
from config import DEFAULT_MODEL from config import DEFAULT_MODEL
...@@ -61,54 +61,39 @@ async def chat_controller( ...@@ -61,54 +61,39 @@ async def chat_controller(
) )
try: try:
# 🔥 Wrap graph execution với langfuse_trace_context để set user_id cho tất cả observations result = await graph.ainvoke(initial_state, config=exec_config)
with langfuse_trace_context(user_id=user_id, session_id=user_id): all_product_ids = _extract_product_ids(result.get("messages", []))
# TỐI ƯU: Chạy Graph
result = await graph.ainvoke(initial_state, config=exec_config)
# TỐI ƯU: Extract IDs từ Tool Messages một lần duy nhất ai_raw_content = result.get("ai_response").content if result.get("ai_response") else ""
all_product_ids = _extract_product_ids(result.get("messages", [])) logger.info(f"💾 [RAW AI OUTPUT]:\n{ai_raw_content}")
# TỐI ƯU: Xử lý AI Response ai_text_response = ai_raw_content
ai_raw_content = result.get("ai_response").content if result.get("ai_response") else "" try:
logger.info(f"💾 [RAW AI OUTPUT]:\n{ai_raw_content}") ai_json = json.loads(ai_raw_content)
ai_text_response = ai_json.get("ai_response", ai_raw_content)
# Parse JSON để lấy text response và product_ids từ AI
ai_text_response = ai_raw_content explicit_ids = ai_json.get("product_ids", [])
try: if explicit_ids and isinstance(explicit_ids, list):
# Vì json_mode=True, OpenAI sẽ nhả raw JSON seen_skus = {p["sku"] for p in all_product_ids if "sku" in p}
ai_json = json.loads(ai_raw_content) for product in explicit_ids:
if isinstance(product, dict) and product.get("sku") not in seen_skus:
# Extract text response từ JSON all_product_ids.append(product)
ai_text_response = ai_json.get("ai_response", ai_raw_content) seen_skus.add(product.get("sku"))
except (json.JSONDecodeError, Exception) as e:
# Merge product_ids từ AI JSON (nếu có) - KHÔNG dùng set() vì dict unhashable logger.warning(f"Could not parse AI response as JSON: {e}")
explicit_ids = ai_json.get("product_ids", [])
if explicit_ids and isinstance(explicit_ids, list): background_tasks.add_task(
# Merge và deduplicate by SKU _handle_post_chat_async,
seen_skus = {p["sku"] for p in all_product_ids if "sku" in p} memory=memory,
for product in explicit_ids: user_id=user_id,
if isinstance(product, dict) and product.get("sku") not in seen_skus: human_query=query,
all_product_ids.append(product) ai_msg=AIMessage(content=ai_text_response),
seen_skus.add(product.get("sku")) )
except (json.JSONDecodeError, Exception) as e:
# Nếu AI trả về text thường (hiếm khi xảy ra trong JSON mode) thì ignore return {
logger.warning(f"Could not parse AI response as JSON: {e}") "ai_response": ai_text_response,
pass "product_ids": all_product_ids,
}
# BACKGROUND TASK: Lưu history nhanh gọn
background_tasks.add_task(
_handle_post_chat_async,
memory=memory,
user_id=user_id,
human_query=query,
ai_msg=AIMessage(content=ai_text_response),
)
return {
"ai_response": ai_text_response, # CHỈ text, không phải JSON
"product_ids": all_product_ids, # Array of product objects
}
except Exception as e: except Exception as e:
logger.error(f"💥 Chat error for user {user_id}: {e}", exc_info=True) logger.error(f"💥 Chat error for user {user_id}: {e}", exc_info=True)
...@@ -171,8 +156,6 @@ def _prepare_execution_context(query: str, user_id: str, history: list, images: ...@@ -171,8 +156,6 @@ def _prepare_execution_context(query: str, user_id: str, history: list, images:
"tags": "chatbot,production", "tags": "chatbot,production",
} }
# 🔥 CallbackHandler - sẽ được wrap trong langfuse_trace_context để set user_id
# Per Langfuse docs: propagate_attributes() handles user_id propagation
langfuse_handler = get_callback_handler() langfuse_handler = get_callback_handler()
exec_config = RunnableConfig( exec_config = RunnableConfig(
......
...@@ -4,7 +4,7 @@ from langchain_core.tools import tool ...@@ -4,7 +4,7 @@ from langchain_core.tools import tool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from common.embedding_service import create_embedding_async from common.embedding_service import create_embedding_async
from common.starrocks_connection import StarRocksConnection from common.starrocks_connection import get_db_connection
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -56,7 +56,7 @@ async def canifa_knowledge_search(query: str) -> str: ...@@ -56,7 +56,7 @@ async def canifa_knowledge_search(query: str) -> str:
LIMIT 4 LIMIT 4
""" """
sr = StarRocksConnection() sr = get_db_connection()
results = await sr.execute_query_async(sql) results = await sr.execute_query_async(sql)
if not results: if not results:
......
...@@ -12,8 +12,8 @@ from decimal import Decimal ...@@ -12,8 +12,8 @@ from decimal import Decimal
from langchain_core.tools import tool from langchain_core.tools import tool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from agent.tools.product_search_helpers import build_starrocks_query, save_preview_to_log from agent.tools.product_search_helpers import build_starrocks_query
from common.starrocks_connection import StarRocksConnection from common.starrocks_connection import get_db_connection
# from langsmith import traceable # from langsmith import traceable
...@@ -89,9 +89,9 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str: ...@@ -89,9 +89,9 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str:
""" """
logger.info("🔧 [DEBUG] data_retrieval_tool STARTED") logger.info("🔧 [DEBUG] data_retrieval_tool STARTED")
try: try:
logger.info("🔧 [DEBUG] Creating StarRocksConnection instance") logger.info("🔧 [DEBUG] Getting DB connection (singleton)")
db = StarRocksConnection() db = get_db_connection()
logger.info("🔧 [DEBUG] StarRocksConnection created successfully") logger.info("🔧 [DEBUG] DB connection retrieved successfully")
# 0. Log input parameters (Đúng ý bro) # 0. Log input parameters (Đúng ý bro)
logger.info(f"📥 [Tool Input] data_retrieval_tool received {len(searches)} items:") logger.info(f"📥 [Tool Input] data_retrieval_tool received {len(searches)} items:")
...@@ -128,7 +128,7 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str: ...@@ -128,7 +128,7 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str:
return json.dumps({"status": "error", "message": str(e)}) return json.dumps({"status": "error", "message": str(e)})
async def _execute_single_search(db: StarRocksConnection, item: SearchItem) -> list[dict]: async def _execute_single_search(db, item: SearchItem) -> list[dict]:
"""Thực thi một search query đơn lẻ (Async).""" """Thực thi một search query đơn lẻ (Async)."""
try: try:
logger.info(f"🔧 [DEBUG] _execute_single_search STARTED for query: {item.query[:50] if item.query else 'None'}") logger.info(f"🔧 [DEBUG] _execute_single_search STARTED for query: {item.query[:50] if item.query else 'None'}")
...@@ -149,10 +149,10 @@ async def _execute_single_search(db: StarRocksConnection, item: SearchItem) -> l ...@@ -149,10 +149,10 @@ async def _execute_single_search(db: StarRocksConnection, item: SearchItem) -> l
logger.info(f"🔧 [DEBUG] Query executed, got {len(products)} products") logger.info(f"🔧 [DEBUG] Query executed, got {len(products)} products")
logger.info(f"⏱️ [TIMER] DB Query Execution Time: {db_time:.2f}ms") logger.info(f"⏱️ [TIMER] DB Query Execution Time: {db_time:.2f}ms")
logger.info(f"⏱️ [TIMER] Total Time (Build + DB): {query_build_time + db_time:.2f}ms") logger.info(f"⏱️ [TIMER] Total Time (Build + DB): {query_build_time + db_time:.2f}ms")
# Ghi log DB Preview (Kết quả thực tế) vào Background Task # Ghi log DB Preview (Kết quả thực tế) vào Background Task
search_label = item.magento_ref_code if item.magento_ref_code else item.query search_label = item.magento_ref_code if item.magento_ref_code else item.query
asyncio.create_task(save_preview_to_log(search_label, products)) # asyncio.create_task(save_preview_to_log(search_label, products))
return _format_product_results(products) return _format_product_results(products)
except Exception as e: except Exception as e:
......
...@@ -6,136 +6,102 @@ With propagate_attributes for proper user_id tracking ...@@ -6,136 +6,102 @@ With propagate_attributes for proper user_id tracking
import asyncio import asyncio
import logging import logging
import os
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from langfuse import Langfuse, get_client, propagate_attributes from langfuse import Langfuse, get_client
from langfuse.langchain import CallbackHandler from langfuse.langchain import CallbackHandler
from config import ( from config import (
LANGFUSE_BASE_URL,
LANGFUSE_PUBLIC_KEY, LANGFUSE_PUBLIC_KEY,
LANGFUSE_SECRET_KEY, LANGFUSE_SECRET_KEY,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ⚡ Global state for async batch export __all__ = ["async_flush_langfuse", "get_callback_handler", "get_langfuse_client"]
_langfuse_client: Langfuse | None = None
_export_executor: ThreadPoolExecutor | None = None
_pending_traces: list = []
_export_task: asyncio.Task | None = None
_batch_lock = asyncio.Lock if hasattr(asyncio, "Lock") else None
def initialize_langfuse() -> bool: class LangfuseClientManager:
""" """
1. Set environment variables Singleton manager for Langfuse client.
2. Initialize Langfuse client Lazy loading - only initialize when first needed.
3. Setup thread pool for async batch export
""" """
global _langfuse_client, _export_executor
if not LANGFUSE_PUBLIC_KEY or not LANGFUSE_SECRET_KEY: def __init__(self):
logger.warning("⚠️ LANGFUSE KEYS MISSING. Tracing disabled.") self._client: Langfuse | None = None
return False self._executor: ThreadPoolExecutor | None = None
self._initialized = False
# Set environment def get_client(self) -> Langfuse | None:
os.environ["LANGFUSE_PUBLIC_KEY"] = LANGFUSE_PUBLIC_KEY """
os.environ["LANGFUSE_SECRET_KEY"] = LANGFUSE_SECRET_KEY Lazy loading - initialize Langfuse client on first call.
os.environ["LANGFUSE_BASE_URL"] = LANGFUSE_BASE_URL or "https://cloud.langfuse.com" """
os.environ["LANGFUSE_TIMEOUT"] = "10" # 10s timeout, not blocking if self._initialized:
return self._client
# Disable default flush to prevent blocking logger.info("🔧 [LAZY LOADING] Initializing Langfuse client (first time)")
os.environ["LANGFUSE_FLUSHINTERVAL"] = "300" # 5 min, very infrequent
try: if not LANGFUSE_PUBLIC_KEY or not LANGFUSE_SECRET_KEY:
_langfuse_client = get_client() logger.warning("⚠️ LANGFUSE KEYS MISSING. Tracing disabled.")
_export_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="langfuse_export") self._initialized = True
return None
if _langfuse_client.auth_check(): try:
logger.info("✅ Langfuse Ready! (async batch export)") self._client = get_client()
return True self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="langfuse_export")
logger.error("❌ Langfuse auth failed")
return False
except Exception as e: if self._client.auth_check():
logger.error(f"❌ Langfuse init error: {e}") logger.info("✅ Langfuse Ready! (async batch export)")
return False self._initialized = True
return self._client
logger.error("❌ Langfuse auth failed")
self._initialized = True
return None
async def async_flush_langfuse(): except Exception as e:
""" logger.error(f"❌ Langfuse init error: {e}")
Async wrapper to flush Langfuse without blocking event loop. self._initialized = True
Uses thread pool executor to run sync flush in background.
"""
if not _langfuse_client or not _export_executor:
return
try:
loop = asyncio.get_event_loop()
# Run flush in thread pool (non-blocking)
await loop.run_in_executor(_export_executor, _langfuse_client.flush)
logger.debug("📤 Langfuse flushed (async)")
except Exception as e:
logger.warning(f"⚠️ Async flush failed: {e}")
def get_callback_handler(
trace_id: str | None = None,
user_id: str | None = None,
session_id: str | None = None,
tags: list[str] | None = None,
**trace_kwargs,
) -> CallbackHandler | None:
"""
Get CallbackHandler with unique trace context.
Args:
trace_id: Optional unique trace ID
user_id: User ID for grouping traces by user (NOT set here - use propagate_attributes instead)
session_id: Session ID for grouping traces by session/conversation
tags: List of tags for filtering traces
**trace_kwargs: Additional trace attributes
Returns:
CallbackHandler instance + propagate_attributes context manager
Note:
Per Langfuse docs: use propagate_attributes(user_id=...) context manager
to properly set user_id across all observations in the trace.
This makes user_id appear as a filterable field in Langfuse UI.
"""
try:
if not _langfuse_client:
logger.warning("⚠️ Langfuse client not initialized")
return None return None
handler = CallbackHandler() async def async_flush(self):
logger.debug("✅ Langfuse CallbackHandler created") """
return handler Async wrapper to flush Langfuse without blocking event loop.
except Exception as e: Uses thread pool executor to run sync flush in background.
logger.warning(f"⚠️ CallbackHandler error: {e}") """
return None client = self.get_client()
if not client or not self._executor:
return
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(self._executor, client.flush)
logger.debug("📤 Langfuse flushed (async)")
except Exception as e:
logger.warning(f"⚠️ Async flush failed: {e}")
def get_callback_handler(self) -> CallbackHandler | None:
"""Get CallbackHandler instance."""
client = self.get_client()
if not client:
logger.warning("⚠️ Langfuse client not available")
return None
try:
handler = CallbackHandler()
logger.debug("✅ Langfuse CallbackHandler created")
return handler
except Exception as e:
logger.warning(f"⚠️ CallbackHandler error: {e}")
return None
@contextmanager
def langfuse_trace_context(user_id: str | None = None, session_id: str | None = None, tags: list[str] | None = None):
"""
Context manager to propagate user_id, session_id, tags to all observations.
Usage: # --- Singleton ---
with langfuse_trace_context(user_id="user_123", session_id="session_456"): _manager = LangfuseClientManager()
# All observations created here will have these attributes get_langfuse_client = _manager.get_client
await invoke_chain() async_flush_langfuse = _manager.async_flush
"""
attrs = {}
if user_id: def get_callback_handler() -> CallbackHandler | None:
attrs["user_id"] = user_id """Get CallbackHandler instance (wrapper for manager)."""
if session_id: return _manager.get_callback_handler()
attrs["session_id"] = session_id
# Tags are set via metadata, not propagate_attributes
with propagate_attributes(**attrs):
yield
...@@ -3,8 +3,8 @@ StarRocks Database Connection Utility ...@@ -3,8 +3,8 @@ StarRocks Database Connection Utility
Based on chatbot-rsa pattern Based on chatbot-rsa pattern
""" """
import logging
import asyncio import asyncio
import logging
from typing import Any from typing import Any
import aiomysql import aiomysql
...@@ -21,6 +21,29 @@ from config import ( ...@@ -21,6 +21,29 @@ from config import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = ["StarRocksConnection", "get_db_connection"]
class StarRocksConnectionManager:
"""
Singleton Class quản lý StarRocks Connection.
"""
def __init__(self):
self._connection: StarRocksConnection | None = None
def get_connection(self) -> "StarRocksConnection":
"""Lazy loading connection"""
if self._connection is None:
logger.info("🔧 [LAZY LOADING] Creating StarRocksConnection instance (first time)")
self._connection = StarRocksConnection()
return self._connection
# --- Singleton ---
_manager = StarRocksConnectionManager()
get_db_connection = _manager.get_connection
class StarRocksConnection: class StarRocksConnection:
# Shared connection (Singleton-like behavior) for all instances # Shared connection (Singleton-like behavior) for all instances
...@@ -41,6 +64,7 @@ class StarRocksConnection: ...@@ -41,6 +64,7 @@ class StarRocksConnection:
self.port = port or STARROCKS_PORT self.port = port or STARROCKS_PORT
# self.conn references the shared connection # self.conn references the shared connection
self.conn = None self.conn = None
logger.info(f"✅ StarRocksConnection initialized: {self.host}:{self.port}")
def connect(self): def connect(self):
""" """
...@@ -129,8 +153,8 @@ class StarRocksConnection: ...@@ -129,8 +153,8 @@ class StarRocksConnection:
db=self.database, db=self.database,
charset="utf8mb4", charset="utf8mb4",
cursorclass=aiomysql.DictCursor, cursorclass=aiomysql.DictCursor,
minsize=10, # Sẵn sàng 10 kết nối ngay lập tức (Cực nhanh cho Prod) minsize=10, # Sẵn sàng 10 kết nối ngay lập tức (Cực nhanh cho Prod)
maxsize=50, # Tăng nhẹ lên 50 (Cân bằng giữa throughput và memory) maxsize=50, # Tăng nhẹ lên 50 (Cân bằng giữa throughput và memory)
connect_timeout=10, connect_timeout=10,
) )
return StarRocksConnection._shared_pool return StarRocksConnection._shared_pool
...@@ -141,32 +165,31 @@ class StarRocksConnection: ...@@ -141,32 +165,31 @@ class StarRocksConnection:
""" """
max_retries = 3 max_retries = 3
last_error = None last_error = None
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
pool = await self.get_pool() pool = await self.get_pool()
# logger.info(f"🚀 Executing Async Query (Attempt {attempt+1}).") # logger.info(f"🚀 Executing Async Query (Attempt {attempt+1}).")
async with pool.acquire() as conn, conn.cursor() as cursor: async with pool.acquire() as conn, conn.cursor() as cursor:
await cursor.execute(query, params) await cursor.execute(query, params)
results = await cursor.fetchall() results = await cursor.fetchall()
# logger.info(f"📊 Async Query successful, returned {len(results)} rows") # logger.info(f"📊 Async Query successful, returned {len(results)} rows")
return [dict(row) for row in results] return [dict(row) for row in results]
except Exception as e: except Exception as e:
last_error = e last_error = e
logger.warning(f"⚠️ StarRocks DB Error (Attempt {attempt+1}/{max_retries}): {e}") logger.warning(f"⚠️ StarRocks DB Error (Attempt {attempt + 1}/{max_retries}): {e}")
if "Memory of process exceed limit" in str(e): if "Memory of process exceed limit" in str(e):
# Nếu StarRocks OOM, đợi một chút rồi thử lại # Nếu StarRocks OOM, đợi một chút rồi thử lại
await asyncio.sleep(0.5 * (attempt + 1)) await asyncio.sleep(0.5 * (attempt + 1))
continue continue
elif "Disconnected" in str(e) or "Lost connection" in str(e): if "Disconnected" in str(e) or "Lost connection" in str(e):
# Nếu mất kết nối, có thể pool bị stale, thử lại ngay # Nếu mất kết nối, có thể pool bị stale, thử lại ngay
continue continue
else: # Các lỗi khác (cú pháp,...) thì raise luôn
# Các lỗi khác (cú pháp,...) thì raise luôn raise
raise
logger.error(f"❌ Failed after {max_retries} attempts: {last_error}") logger.error(f"❌ Failed after {max_retries} attempts: {last_error}")
raise last_error raise last_error
......
...@@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles ...@@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
from api.chatbot_route import router as chatbot_router from api.chatbot_route import router as chatbot_router
from api.conservation_route import router as conservation_router from api.conservation_route import router as conservation_router
from common.langfuse_client import initialize_langfuse from common.langfuse_client import get_langfuse_client
from config import PORT from config import PORT
# Configure Logging # Configure Logging
...@@ -27,13 +27,11 @@ logging.basicConfig( ...@@ -27,13 +27,11 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ========================================== langfuse_client = get_langfuse_client()
# 🔥 LANGFUSE INITIALIZATION if langfuse_client:
# ========================================== logger.info("✅ Langfuse client ready (lazy loading)")
if initialize_langfuse():
logger.info("✅ Langfuse initialized successfully")
else: else:
logger.warning("⚠️ Langfuse initialization failed or keys missing") logger.warning("⚠️ Langfuse client not available (missing keys or disabled)")
app = FastAPI( app = FastAPI(
title="Contract AI Service", title="Contract AI Service",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment