Commit 87361fe6 authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

Refactor: Implement singleton lazy loading pattern for DB & Langfuse clients

- StarRocksConnection: Use StarRocksConnectionManager singleton with lazy loading
- Langfuse: Implement LangfuseClientManager singleton with lazy loading
- EmbeddingService: Already using singleton pattern
- Remove context managers (Langfuse auto-traces LangChain)
- Fix imports across agent/tools and server
- Clean up unnecessary comments and fix code organization
parent 4be13d17
......@@ -12,7 +12,7 @@ from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from common.conversation_manager import ConversationManager, get_conversation_manager
from common.langfuse_client import get_callback_handler, langfuse_trace_context
from common.langfuse_client import get_callback_handler
from common.llm_factory import create_llm
from config import DEFAULT_MODEL
......@@ -61,54 +61,39 @@ async def chat_controller(
)
try:
# 🔥 Wrap graph execution với langfuse_trace_context để set user_id cho tất cả observations
with langfuse_trace_context(user_id=user_id, session_id=user_id):
# TỐI ƯU: Chạy Graph
result = await graph.ainvoke(initial_state, config=exec_config)
result = await graph.ainvoke(initial_state, config=exec_config)
all_product_ids = _extract_product_ids(result.get("messages", []))
# TỐI ƯU: Extract IDs từ Tool Messages một lần duy nhất
all_product_ids = _extract_product_ids(result.get("messages", []))
ai_raw_content = result.get("ai_response").content if result.get("ai_response") else ""
logger.info(f"💾 [RAW AI OUTPUT]:\n{ai_raw_content}")
# TỐI ƯU: Xử lý AI Response
ai_raw_content = result.get("ai_response").content if result.get("ai_response") else ""
logger.info(f"💾 [RAW AI OUTPUT]:\n{ai_raw_content}")
# Parse JSON để lấy text response và product_ids từ AI
ai_text_response = ai_raw_content
try:
# Vì json_mode=True, OpenAI sẽ nhả raw JSON
ai_json = json.loads(ai_raw_content)
# Extract text response từ JSON
ai_text_response = ai_json.get("ai_response", ai_raw_content)
# Merge product_ids từ AI JSON (nếu có) - KHÔNG dùng set() vì dict unhashable
explicit_ids = ai_json.get("product_ids", [])
if explicit_ids and isinstance(explicit_ids, list):
# Merge và deduplicate by SKU
seen_skus = {p["sku"] for p in all_product_ids if "sku" in p}
for product in explicit_ids:
if isinstance(product, dict) and product.get("sku") not in seen_skus:
all_product_ids.append(product)
seen_skus.add(product.get("sku"))
except (json.JSONDecodeError, Exception) as e:
# Nếu AI trả về text thường (hiếm khi xảy ra trong JSON mode) thì ignore
logger.warning(f"Could not parse AI response as JSON: {e}")
pass
# BACKGROUND TASK: Lưu history nhanh gọn
background_tasks.add_task(
_handle_post_chat_async,
memory=memory,
user_id=user_id,
human_query=query,
ai_msg=AIMessage(content=ai_text_response),
)
return {
"ai_response": ai_text_response, # CHỈ text, không phải JSON
"product_ids": all_product_ids, # Array of product objects
}
ai_text_response = ai_raw_content
try:
ai_json = json.loads(ai_raw_content)
ai_text_response = ai_json.get("ai_response", ai_raw_content)
explicit_ids = ai_json.get("product_ids", [])
if explicit_ids and isinstance(explicit_ids, list):
seen_skus = {p["sku"] for p in all_product_ids if "sku" in p}
for product in explicit_ids:
if isinstance(product, dict) and product.get("sku") not in seen_skus:
all_product_ids.append(product)
seen_skus.add(product.get("sku"))
except (json.JSONDecodeError, Exception) as e:
logger.warning(f"Could not parse AI response as JSON: {e}")
background_tasks.add_task(
_handle_post_chat_async,
memory=memory,
user_id=user_id,
human_query=query,
ai_msg=AIMessage(content=ai_text_response),
)
return {
"ai_response": ai_text_response,
"product_ids": all_product_ids,
}
except Exception as e:
logger.error(f"💥 Chat error for user {user_id}: {e}", exc_info=True)
......@@ -171,8 +156,6 @@ def _prepare_execution_context(query: str, user_id: str, history: list, images:
"tags": "chatbot,production",
}
# 🔥 CallbackHandler - sẽ được wrap trong langfuse_trace_context để set user_id
# Per Langfuse docs: propagate_attributes() handles user_id propagation
langfuse_handler = get_callback_handler()
exec_config = RunnableConfig(
......
......@@ -4,7 +4,7 @@ from langchain_core.tools import tool
from pydantic import BaseModel, Field
from common.embedding_service import create_embedding_async
from common.starrocks_connection import StarRocksConnection
from common.starrocks_connection import get_db_connection
logger = logging.getLogger(__name__)
......@@ -56,7 +56,7 @@ async def canifa_knowledge_search(query: str) -> str:
LIMIT 4
"""
sr = StarRocksConnection()
sr = get_db_connection()
results = await sr.execute_query_async(sql)
if not results:
......
......@@ -12,8 +12,8 @@ from decimal import Decimal
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from agent.tools.product_search_helpers import build_starrocks_query, save_preview_to_log
from common.starrocks_connection import StarRocksConnection
from agent.tools.product_search_helpers import build_starrocks_query
from common.starrocks_connection import get_db_connection
# from langsmith import traceable
......@@ -89,9 +89,9 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str:
"""
logger.info("🔧 [DEBUG] data_retrieval_tool STARTED")
try:
logger.info("🔧 [DEBUG] Creating StarRocksConnection instance")
db = StarRocksConnection()
logger.info("🔧 [DEBUG] StarRocksConnection created successfully")
logger.info("🔧 [DEBUG] Getting DB connection (singleton)")
db = get_db_connection()
logger.info("🔧 [DEBUG] DB connection retrieved successfully")
# 0. Log input parameters (Đúng ý bro)
logger.info(f"📥 [Tool Input] data_retrieval_tool received {len(searches)} items:")
......@@ -128,7 +128,7 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str:
return json.dumps({"status": "error", "message": str(e)})
async def _execute_single_search(db: StarRocksConnection, item: SearchItem) -> list[dict]:
async def _execute_single_search(db, item: SearchItem) -> list[dict]:
"""Thực thi một search query đơn lẻ (Async)."""
try:
logger.info(f"🔧 [DEBUG] _execute_single_search STARTED for query: {item.query[:50] if item.query else 'None'}")
......@@ -149,10 +149,10 @@ async def _execute_single_search(db: StarRocksConnection, item: SearchItem) -> l
logger.info(f"🔧 [DEBUG] Query executed, got {len(products)} products")
logger.info(f"⏱️ [TIMER] DB Query Execution Time: {db_time:.2f}ms")
logger.info(f"⏱️ [TIMER] Total Time (Build + DB): {query_build_time + db_time:.2f}ms")
# Ghi log DB Preview (Kết quả thực tế) vào Background Task
search_label = item.magento_ref_code if item.magento_ref_code else item.query
asyncio.create_task(save_preview_to_log(search_label, products))
# asyncio.create_task(save_preview_to_log(search_label, products))
return _format_product_results(products)
except Exception as e:
......
......@@ -6,136 +6,102 @@ With propagate_attributes for proper user_id tracking
import asyncio
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from langfuse import Langfuse, get_client, propagate_attributes
from langfuse import Langfuse, get_client
from langfuse.langchain import CallbackHandler
from config import (
LANGFUSE_BASE_URL,
LANGFUSE_PUBLIC_KEY,
LANGFUSE_SECRET_KEY,
)
logger = logging.getLogger(__name__)
# ⚡ Global state for async batch export
_langfuse_client: Langfuse | None = None
_export_executor: ThreadPoolExecutor | None = None
_pending_traces: list = []
_export_task: asyncio.Task | None = None
_batch_lock = asyncio.Lock if hasattr(asyncio, "Lock") else None
__all__ = ["async_flush_langfuse", "get_callback_handler", "get_langfuse_client"]
def initialize_langfuse() -> bool:
class LangfuseClientManager:
    """
    Singleton manager for the Langfuse client.

    Lazy loading: nothing is created at import time; the client and its
    export thread pool are built on the first `get_client()` call.
    A single module-level instance (``_manager`` below) is shared by the
    whole process.
    """

    def __init__(self) -> None:
        # Cached Langfuse client (None until first successful init,
        # or permanently None when keys are missing / init failed).
        self._client: Langfuse | None = None
        # Single-thread pool used to run blocking flush() off the event loop.
        self._executor: ThreadPoolExecutor | None = None
        # True once initialization has been attempted, success or not —
        # prevents retrying (and re-logging) on every call.
        self._initialized = False

    def get_client(self) -> Langfuse | None:
        """
        Return the shared Langfuse client, initializing it on first call.

        Returns:
            The Langfuse client, or None when credentials are missing,
            initialization raised, or the auth check failed.
        """
        if self._initialized:
            return self._client

        logger.info("🔧 [LAZY LOADING] Initializing Langfuse client (first time)")

        if not LANGFUSE_PUBLIC_KEY or not LANGFUSE_SECRET_KEY:
            logger.warning("⚠️ LANGFUSE KEYS MISSING. Tracing disabled.")
            self._initialized = True
            return None

        try:
            self._client = get_client()
            self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="langfuse_export")
        except Exception as e:
            logger.error(f"❌ Langfuse init error: {e}")
            self._initialized = True
            return None

        if self._client.auth_check():
            logger.info("✅ Langfuse Ready! (async batch export)")
            self._initialized = True
            return self._client

        # Auth failed: remember the attempt and report no client.
        logger.error("❌ Langfuse auth failed")
        self._initialized = True
        return None

    async def async_flush(self) -> None:
        """
        Flush Langfuse without blocking the event loop.

        Runs the synchronous `client.flush()` in the dedicated thread pool.
        No-op when the client or executor is unavailable.
        """
        client = self.get_client()
        if not client or not self._executor:
            return
        try:
            # get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() is deprecated here since Python 3.10.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(self._executor, client.flush)
            logger.debug("📤 Langfuse flushed (async)")
        except Exception as e:
            logger.warning(f"⚠️ Async flush failed: {e}")

    def get_callback_handler(self) -> CallbackHandler | None:
        """
        Build a LangChain CallbackHandler, or return None when the
        Langfuse client is unavailable or handler creation fails.
        """
        client = self.get_client()
        if not client:
            logger.warning("⚠️ Langfuse client not available")
            return None
        try:
            handler = CallbackHandler()
            logger.debug("✅ Langfuse CallbackHandler created")
            return handler
        except Exception as e:
            logger.warning(f"⚠️ CallbackHandler error: {e}")
            return None


# --- Singleton ---
_manager = LangfuseClientManager()
get_langfuse_client = _manager.get_client
async_flush_langfuse = _manager.async_flush


def get_callback_handler() -> CallbackHandler | None:
    """Get CallbackHandler instance (wrapper for manager)."""
    return _manager.get_callback_handler()
......@@ -3,8 +3,8 @@ StarRocks Database Connection Utility
Based on chatbot-rsa pattern
"""
import logging
import asyncio
import logging
from typing import Any
import aiomysql
......@@ -21,6 +21,29 @@ from config import (
logger = logging.getLogger(__name__)
__all__ = ["StarRocksConnection", "get_db_connection"]
class StarRocksConnectionManager:
    """
    Singleton manager for the StarRocks connection.

    A single module-level instance (``_manager`` below) hands out one
    shared :class:`StarRocksConnection`, created lazily on first use.
    """
    def __init__(self):
        # Cached connection; stays None until the first get_connection() call.
        self._connection: StarRocksConnection | None = None
    def get_connection(self) -> "StarRocksConnection":
        """Lazy loading connection"""
        if self._connection is None:
            logger.info("🔧 [LAZY LOADING] Creating StarRocksConnection instance (first time)")
            self._connection = StarRocksConnection()
        return self._connection


# --- Singleton ---
# Module-level instance; callers use get_db_connection() everywhere so the
# whole process shares one StarRocksConnection.
_manager = StarRocksConnectionManager()
get_db_connection = _manager.get_connection
class StarRocksConnection:
# Shared connection (Singleton-like behavior) for all instances
......@@ -41,6 +64,7 @@ class StarRocksConnection:
self.port = port or STARROCKS_PORT
# self.conn references the shared connection
self.conn = None
logger.info(f"✅ StarRocksConnection initialized: {self.host}:{self.port}")
def connect(self):
"""
......@@ -129,8 +153,8 @@ class StarRocksConnection:
db=self.database,
charset="utf8mb4",
cursorclass=aiomysql.DictCursor,
minsize=10, # Sẵn sàng 10 kết nối ngay lập tức (Cực nhanh cho Prod)
maxsize=50, # Tăng nhẹ lên 50 (Cân bằng giữa throughput và memory)
minsize=10, # Sẵn sàng 10 kết nối ngay lập tức (Cực nhanh cho Prod)
maxsize=50, # Tăng nhẹ lên 50 (Cân bằng giữa throughput và memory)
connect_timeout=10,
)
return StarRocksConnection._shared_pool
......@@ -141,32 +165,31 @@ class StarRocksConnection:
"""
max_retries = 3
last_error = None
for attempt in range(max_retries):
try:
pool = await self.get_pool()
# logger.info(f"🚀 Executing Async Query (Attempt {attempt+1}).")
async with pool.acquire() as conn, conn.cursor() as cursor:
await cursor.execute(query, params)
results = await cursor.fetchall()
# logger.info(f"📊 Async Query successful, returned {len(results)} rows")
return [dict(row) for row in results]
except Exception as e:
last_error = e
logger.warning(f"⚠️ StarRocks DB Error (Attempt {attempt+1}/{max_retries}): {e}")
logger.warning(f"⚠️ StarRocks DB Error (Attempt {attempt + 1}/{max_retries}): {e}")
if "Memory of process exceed limit" in str(e):
# Nếu StarRocks OOM, đợi một chút rồi thử lại
await asyncio.sleep(0.5 * (attempt + 1))
continue
elif "Disconnected" in str(e) or "Lost connection" in str(e):
if "Disconnected" in str(e) or "Lost connection" in str(e):
# Nếu mất kết nối, có thể pool bị stale, thử lại ngay
continue
else:
# Các lỗi khác (cú pháp,...) thì raise luôn
raise
# Các lỗi khác (cú pháp,...) thì raise luôn
raise
logger.error(f"❌ Failed after {max_retries} attempts: {last_error}")
raise last_error
......
......@@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
from api.chatbot_route import router as chatbot_router
from api.conservation_route import router as conservation_router
from common.langfuse_client import initialize_langfuse
from common.langfuse_client import get_langfuse_client
from config import PORT
# Configure Logging
......@@ -27,13 +27,11 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
# ==========================================
# 🔥 LANGFUSE INITIALIZATION
# ==========================================
if initialize_langfuse():
logger.info("✅ Langfuse initialized successfully")
langfuse_client = get_langfuse_client()
if langfuse_client:
logger.info("✅ Langfuse client ready (lazy loading)")
else:
logger.warning("⚠️ Langfuse initialization failed or keys missing")
logger.warning("⚠️ Langfuse client not available (missing keys or disabled)")
app = FastAPI(
title="Contract AI Service",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment