Commit 87361fe6 authored by Vũ Hoàng Anh

Refactor: Implement singleton lazy loading pattern for DB & Langfuse clients

- StarRocksConnection: Use StarRocksConnectionManager singleton with lazy loading
- Langfuse: Implement LangfuseClientManager singleton with lazy loading
- EmbeddingService: Already using singleton pattern
- Remove context managers (Langfuse auto-traces LangChain)
- Fix imports across agent/tools and server
- Clean up unnecessary comments and fix code organization
parent 4be13d17
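For context, a minimal sketch of the lazy-loading singleton pattern this commit applies to both clients. The names (ExpensiveClient, ClientManager, get_client) are illustrative placeholders, not the project's actual classes:

import logging

logger = logging.getLogger(__name__)


class ExpensiveClient:
    """Stand-in for the real StarRocks / Langfuse client."""

    def __init__(self) -> None:
        logger.info("ExpensiveClient constructed")


class ClientManager:
    """Module-level singleton manager; defers construction until first use."""

    def __init__(self) -> None:
        self._client: ExpensiveClient | None = None

    def get_client(self) -> ExpensiveClient:
        # Lazy loading: build the client on the first call only, then reuse it.
        # (Not thread-safe, matching the simple pattern used in this commit.)
        if self._client is None:
            logger.info("Initializing client (first time)")
            self._client = ExpensiveClient()
        return self._client


# One shared manager instance; importing the module stays cheap.
_manager = ClientManager()
get_client = _manager.get_client

Callers then call get_client() wherever they previously constructed the object directly, so import time stays fast and the expensive setup happens only on the first request that actually needs it.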
@@ -12,7 +12,7 @@ from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from common.conversation_manager import ConversationManager, get_conversation_manager
-from common.langfuse_client import get_callback_handler, langfuse_trace_context
+from common.langfuse_client import get_callback_handler
from common.llm_factory import create_llm
from config import DEFAULT_MODEL
@@ -61,42 +61,27 @@ async def chat_controller(
    )
    try:
-        # 🔥 Wrap graph execution with langfuse_trace_context to set user_id for all observations
-        with langfuse_trace_context(user_id=user_id, session_id=user_id):
-            # OPTIMIZATION: run the graph
        result = await graph.ainvoke(initial_state, config=exec_config)
-        # OPTIMIZATION: extract IDs from the tool messages in a single pass
        all_product_ids = _extract_product_ids(result.get("messages", []))
-        # OPTIMIZATION: process the AI response
        ai_raw_content = result.get("ai_response").content if result.get("ai_response") else ""
        logger.info(f"💾 [RAW AI OUTPUT]:\n{ai_raw_content}")
-        # Parse the JSON to get the text response and product_ids from the AI
        ai_text_response = ai_raw_content
        try:
-            # Because json_mode=True, OpenAI returns raw JSON
            ai_json = json.loads(ai_raw_content)
-            # Extract the text response from the JSON
            ai_text_response = ai_json.get("ai_response", ai_raw_content)
-            # Merge product_ids from the AI JSON (if any) - do NOT use set() because dicts are unhashable
            explicit_ids = ai_json.get("product_ids", [])
            if explicit_ids and isinstance(explicit_ids, list):
-                # Merge and deduplicate by SKU
                seen_skus = {p["sku"] for p in all_product_ids if "sku" in p}
                for product in explicit_ids:
                    if isinstance(product, dict) and product.get("sku") not in seen_skus:
                        all_product_ids.append(product)
                        seen_skus.add(product.get("sku"))
        except (json.JSONDecodeError, Exception) as e:
-            # If the AI returns plain text (rare in JSON mode), ignore it
            logger.warning(f"Could not parse AI response as JSON: {e}")
-            pass
-        # BACKGROUND TASK: save the history quickly
        background_tasks.add_task(
            _handle_post_chat_async,
            memory=memory,
@@ -106,8 +91,8 @@ async def chat_controller(
        )
        return {
-            "ai_response": ai_text_response,  # ONLY the text, not the JSON
-            "product_ids": all_product_ids,  # Array of product objects
+            "ai_response": ai_text_response,
+            "product_ids": all_product_ids,
        }
    except Exception as e:
@@ -171,8 +156,6 @@ def _prepare_execution_context(query: str, user_id: str, history: list, images:
        "tags": "chatbot,production",
    }
-    # 🔥 CallbackHandler - will be wrapped in langfuse_trace_context to set user_id
-    # Per Langfuse docs: propagate_attributes() handles user_id propagation
    langfuse_handler = get_callback_handler()
    exec_config = RunnableConfig(
...
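With the langfuse_trace_context wrapper removed, tracing in this controller relies on the Langfuse CallbackHandler attached to the RunnableConfig (Langfuse auto-traces the LangChain run through it). A hedged sketch of that wiring; build_exec_config is a hypothetical helper, and the langfuse_user_id / langfuse_session_id metadata keys are an assumption based on the Langfuse LangChain integration, not code from this commit:

from langchain_core.runnables import RunnableConfig

from common.langfuse_client import get_callback_handler


def build_exec_config(user_id: str) -> RunnableConfig:
    # The handler comes from the lazy singleton; it is None when Langfuse keys
    # are missing, in which case callbacks are simply omitted.
    handler = get_callback_handler()
    callbacks = [handler] if handler else []
    return RunnableConfig(
        callbacks=callbacks,
        # Assumed metadata keys: the Langfuse LangChain handler can pick up
        # langfuse_user_id / langfuse_session_id from the run metadata.
        metadata={"langfuse_user_id": user_id, "langfuse_session_id": user_id},
    )

The resulting config is then passed to graph.ainvoke(initial_state, config=exec_config), exactly as in the diff above.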
@@ -4,7 +4,7 @@ from langchain_core.tools import tool
from pydantic import BaseModel, Field
from common.embedding_service import create_embedding_async
-from common.starrocks_connection import StarRocksConnection
+from common.starrocks_connection import get_db_connection
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ async def canifa_knowledge_search(query: str) -> str:
    LIMIT 4
    """
-    sr = StarRocksConnection()
+    sr = get_db_connection()
    results = await sr.execute_query_async(sql)
    if not results:
...
@@ -12,8 +12,8 @@ from decimal import Decimal
from langchain_core.tools import tool
from pydantic import BaseModel, Field
-from agent.tools.product_search_helpers import build_starrocks_query, save_preview_to_log
-from common.starrocks_connection import StarRocksConnection
+from agent.tools.product_search_helpers import build_starrocks_query
+from common.starrocks_connection import get_db_connection
# from langsmith import traceable
@@ -89,9 +89,9 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str:
    """
    logger.info("🔧 [DEBUG] data_retrieval_tool STARTED")
    try:
-        logger.info("🔧 [DEBUG] Creating StarRocksConnection instance")
-        db = StarRocksConnection()
-        logger.info("🔧 [DEBUG] StarRocksConnection created successfully")
+        logger.info("🔧 [DEBUG] Getting DB connection (singleton)")
+        db = get_db_connection()
+        logger.info("🔧 [DEBUG] DB connection retrieved successfully")
        # 0. Log the input parameters (as requested)
        logger.info(f"📥 [Tool Input] data_retrieval_tool received {len(searches)} items:")
@@ -128,7 +128,7 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str:
        return json.dumps({"status": "error", "message": str(e)})
-async def _execute_single_search(db: StarRocksConnection, item: SearchItem) -> list[dict]:
+async def _execute_single_search(db, item: SearchItem) -> list[dict]:
    """Execute a single search query (async)."""
    try:
        logger.info(f"🔧 [DEBUG] _execute_single_search STARTED for query: {item.query[:50] if item.query else 'None'}")
@@ -152,7 +152,7 @@ async def _execute_single_search(db: StarRocksConnection, item: SearchItem) -> list[dict]:
        # Write the DB preview log (actual results) via a background task
        search_label = item.magento_ref_code if item.magento_ref_code else item.query
-        asyncio.create_task(save_preview_to_log(search_label, products))
+        # asyncio.create_task(save_preview_to_log(search_label, products))
        return _format_product_results(products)
    except Exception as e:
...
@@ -6,112 +6,88 @@ With propagate_attributes for proper user_id tracking
import asyncio
import logging
-import os
from concurrent.futures import ThreadPoolExecutor
-from contextlib import contextmanager

-from langfuse import Langfuse, get_client, propagate_attributes
+from langfuse import Langfuse, get_client
from langfuse.langchain import CallbackHandler

from config import (
-    LANGFUSE_BASE_URL,
    LANGFUSE_PUBLIC_KEY,
    LANGFUSE_SECRET_KEY,
)

logger = logging.getLogger(__name__)

-# ⚡ Global state for async batch export
-_langfuse_client: Langfuse | None = None
-_export_executor: ThreadPoolExecutor | None = None
-_pending_traces: list = []
-_export_task: asyncio.Task | None = None
-_batch_lock = asyncio.Lock if hasattr(asyncio, "Lock") else None
+__all__ = ["async_flush_langfuse", "get_callback_handler", "get_langfuse_client"]

-def initialize_langfuse() -> bool:
+class LangfuseClientManager:
    """
-    1. Set environment variables
-    2. Initialize Langfuse client
-    3. Setup thread pool for async batch export
+    Singleton manager for Langfuse client.
+    Lazy loading - only initialize when first needed.
    """
-    global _langfuse_client, _export_executor

-    if not LANGFUSE_PUBLIC_KEY or not LANGFUSE_SECRET_KEY:
-        logger.warning("⚠️ LANGFUSE KEYS MISSING. Tracing disabled.")
-        return False
+    def __init__(self):
+        self._client: Langfuse | None = None
+        self._executor: ThreadPoolExecutor | None = None
+        self._initialized = False

-    # Set environment
-    os.environ["LANGFUSE_PUBLIC_KEY"] = LANGFUSE_PUBLIC_KEY
-    os.environ["LANGFUSE_SECRET_KEY"] = LANGFUSE_SECRET_KEY
-    os.environ["LANGFUSE_BASE_URL"] = LANGFUSE_BASE_URL or "https://cloud.langfuse.com"
-    os.environ["LANGFUSE_TIMEOUT"] = "10"  # 10s timeout, not blocking
-    # Disable default flush to prevent blocking
-    os.environ["LANGFUSE_FLUSHINTERVAL"] = "300"  # 5 min, very infrequent
+    def get_client(self) -> Langfuse | None:
+        """
+        Lazy loading - initialize Langfuse client on first call.
+        """
+        if self._initialized:
+            return self._client
+
+        logger.info("🔧 [LAZY LOADING] Initializing Langfuse client (first time)")
+
+        if not LANGFUSE_PUBLIC_KEY or not LANGFUSE_SECRET_KEY:
+            logger.warning("⚠️ LANGFUSE KEYS MISSING. Tracing disabled.")
+            self._initialized = True
+            return None

-    try:
-        _langfuse_client = get_client()
-        _export_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="langfuse_export")
-        if _langfuse_client.auth_check():
-            logger.info("✅ Langfuse Ready! (async batch export)")
-            return True
-        logger.error("❌ Langfuse auth failed")
-        return False
-    except Exception as e:
-        logger.error(f"❌ Langfuse init error: {e}")
-        return False
+        try:
+            self._client = get_client()
+            self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="langfuse_export")
+            if self._client.auth_check():
+                logger.info("✅ Langfuse Ready! (async batch export)")
+                self._initialized = True
+                return self._client
+            logger.error("❌ Langfuse auth failed")
+            self._initialized = True
+            return None
+        except Exception as e:
+            logger.error(f"❌ Langfuse init error: {e}")
+            self._initialized = True
+            return None

-async def async_flush_langfuse():
-    """
-    Async wrapper to flush Langfuse without blocking event loop.
-    Uses thread pool executor to run sync flush in background.
-    """
-    if not _langfuse_client or not _export_executor:
-        return
-    try:
-        loop = asyncio.get_event_loop()
-        # Run flush in thread pool (non-blocking)
-        await loop.run_in_executor(_export_executor, _langfuse_client.flush)
-        logger.debug("📤 Langfuse flushed (async)")
-    except Exception as e:
-        logger.warning(f"⚠️ Async flush failed: {e}")
+    async def async_flush(self):
+        """
+        Async wrapper to flush Langfuse without blocking event loop.
+        Uses thread pool executor to run sync flush in background.
+        """
+        client = self.get_client()
+        if not client or not self._executor:
+            return
+        try:
+            loop = asyncio.get_event_loop()
+            await loop.run_in_executor(self._executor, client.flush)
+            logger.debug("📤 Langfuse flushed (async)")
+        except Exception as e:
+            logger.warning(f"⚠️ Async flush failed: {e}")

-def get_callback_handler(
-    trace_id: str | None = None,
-    user_id: str | None = None,
-    session_id: str | None = None,
-    tags: list[str] | None = None,
-    **trace_kwargs,
-) -> CallbackHandler | None:
-    """
-    Get CallbackHandler with unique trace context.
-    Args:
-        trace_id: Optional unique trace ID
-        user_id: User ID for grouping traces by user (NOT set here - use propagate_attributes instead)
-        session_id: Session ID for grouping traces by session/conversation
-        tags: List of tags for filtering traces
-        **trace_kwargs: Additional trace attributes
-    Returns:
-        CallbackHandler instance + propagate_attributes context manager
-    Note:
-        Per Langfuse docs: use propagate_attributes(user_id=...) context manager
-        to properly set user_id across all observations in the trace.
-        This makes user_id appear as a filterable field in Langfuse UI.
-    """
-    try:
-        if not _langfuse_client:
-            logger.warning("⚠️ Langfuse client not initialized")
-            return None
-        handler = CallbackHandler()
-        logger.debug("✅ Langfuse CallbackHandler created")
-        return handler
+    def get_callback_handler(self) -> CallbackHandler | None:
+        """Get CallbackHandler instance."""
+        client = self.get_client()
+        if not client:
+            logger.warning("⚠️ Langfuse client not available")
+            return None
+        try:
+            handler = CallbackHandler()
+            logger.debug("✅ Langfuse CallbackHandler created")
+            return handler
@@ -120,22 +96,12 @@ def get_callback_handler(
            return None

-@contextmanager
-def langfuse_trace_context(user_id: str | None = None, session_id: str | None = None, tags: list[str] | None = None):
-    """
-    Context manager to propagate user_id, session_id, tags to all observations.
-    Usage:
-        with langfuse_trace_context(user_id="user_123", session_id="session_456"):
-            # All observations created here will have these attributes
-            await invoke_chain()
-    """
-    attrs = {}
-    if user_id:
-        attrs["user_id"] = user_id
-    if session_id:
-        attrs["session_id"] = session_id
-    # Tags are set via metadata, not propagate_attributes
-    with propagate_attributes(**attrs):
-        yield
+# --- Singleton ---
+_manager = LangfuseClientManager()
+get_langfuse_client = _manager.get_client
+async_flush_langfuse = _manager.async_flush
+
+def get_callback_handler() -> CallbackHandler | None:
+    """Get CallbackHandler instance (wrapper for manager)."""
+    return _manager.get_callback_handler()
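A short usage sketch of the new module-level API exposed by this file; handle_request is a hypothetical call site, and the flush placement is illustrative rather than the project's actual lifecycle hook:

from common.langfuse_client import (
    async_flush_langfuse,
    get_callback_handler,
    get_langfuse_client,
)


async def handle_request() -> None:
    # get_langfuse_client() lazily initializes (and memoizes) the client; it may return None.
    if get_langfuse_client() is None:
        return
    handler = get_callback_handler()  # also None when keys are missing or auth fails
    callbacks = [handler] if handler else []
    # ... pass callbacks into RunnableConfig(callbacks=callbacks) as the controller does ...
    await async_flush_langfuse()      # non-blocking flush through the manager's thread pool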
@@ -3,8 +3,8 @@ StarRocks Database Connection Utility
Based on chatbot-rsa pattern
"""
-import logging
import asyncio
+import logging
from typing import Any
import aiomysql
@@ -21,6 +21,29 @@ from config import (
logger = logging.getLogger(__name__)
+__all__ = ["StarRocksConnection", "get_db_connection"]
+
+class StarRocksConnectionManager:
+    """
+    Singleton class that manages the StarRocks connection.
+    """
+
+    def __init__(self):
+        self._connection: StarRocksConnection | None = None
+
+    def get_connection(self) -> "StarRocksConnection":
+        """Lazy loading connection"""
+        if self._connection is None:
+            logger.info("🔧 [LAZY LOADING] Creating StarRocksConnection instance (first time)")
+            self._connection = StarRocksConnection()
+        return self._connection
+
+# --- Singleton ---
+_manager = StarRocksConnectionManager()
+get_db_connection = _manager.get_connection
class StarRocksConnection:
    # Shared connection (Singleton-like behavior) for all instances
@@ -41,6 +64,7 @@ class StarRocksConnection:
        self.port = port or STARROCKS_PORT
        # self.conn references the shared connection
        self.conn = None
+        logger.info(f"✅ StarRocksConnection initialized: {self.host}:{self.port}")
    def connect(self):
        """
@@ -155,15 +179,14 @@ class StarRocksConnection:
            except Exception as e:
                last_error = e
-                logger.warning(f"⚠️ StarRocks DB Error (Attempt {attempt+1}/{max_retries}): {e}")
+                logger.warning(f"⚠️ StarRocks DB Error (Attempt {attempt + 1}/{max_retries}): {e}")
                if "Memory of process exceed limit" in str(e):
                    # If StarRocks hits its memory limit, wait a moment and retry
                    await asyncio.sleep(0.5 * (attempt + 1))
                    continue
-                elif "Disconnected" in str(e) or "Lost connection" in str(e):
+                if "Disconnected" in str(e) or "Lost connection" in str(e):
                    # If the connection was lost, the pool may be stale; retry immediately
                    continue
-                else:
                # Other errors (syntax, ...) are raised immediately
                raise
...
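For callers, the singleton keeps usage identical to constructing StarRocksConnection() directly; a small sketch (fetch_sample and the query are placeholders, not project code):

from common.starrocks_connection import get_db_connection


async def fetch_sample():
    db = get_db_connection()  # returns the same StarRocksConnection instance on every call
    # execute_query_async retries transient OOM / lost-connection errors, as shown in the diff above
    return await db.execute_query_async("SELECT 1 AS ok")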
@@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
from api.chatbot_route import router as chatbot_router
from api.conservation_route import router as conservation_router
-from common.langfuse_client import initialize_langfuse
+from common.langfuse_client import get_langfuse_client
from config import PORT
# Configure Logging
@@ -27,13 +27,11 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
-# ==========================================
-# 🔥 LANGFUSE INITIALIZATION
-# ==========================================
-if initialize_langfuse():
-    logger.info("✅ Langfuse initialized successfully")
+langfuse_client = get_langfuse_client()
+if langfuse_client:
+    logger.info("✅ Langfuse client ready (lazy loading)")
else:
-    logger.warning("⚠️ Langfuse initialization failed or keys missing")
+    logger.warning("⚠️ Langfuse client not available (missing keys or disabled)")
app = FastAPI(
    title="Contract AI Service",
...