Commit 95f6323f authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

fix: stabilize backend, connection timeout, suppress noise, and extract extra styling info

parent d116f990
...@@ -75,7 +75,10 @@ async def chat_controller( ...@@ -75,7 +75,10 @@ async def chat_controller(
session_id = f"{identity_key}-{str(uuid.uuid4())[:8]}" session_id = f"{identity_key}-{str(uuid.uuid4())[:8]}"
tags = ["stylist_pro", "user:authenticated" if is_authenticated else "user:anonymous"] tags = ["stylist_pro", "user:authenticated" if is_authenticated else "user:anonymous"]
langfuse_handler = get_callback_handler() langfuse_handler = get_callback_handler(
user_id=identity_key,
tags=tags,
)
exec_config = { exec_config = {
"configurable": { "configurable": {
"user_id": identity_key, "user_id": identity_key,
...@@ -88,7 +91,8 @@ async def chat_controller( ...@@ -88,7 +91,8 @@ async def chat_controller(
"metadata": { "metadata": {
"trace_id": trace_id, "trace_id": trace_id,
"session_id": session_id, "session_id": session_id,
} },
"run_name": "CANIFAGraph",
} }
# 5. Graph Invocation (Non-Streaming High-Performance) # 5. Graph Invocation (Non-Streaming High-Performance)
......
...@@ -21,15 +21,6 @@ from config import USE_LOCAL_SQLITE ...@@ -21,15 +21,6 @@ from config import USE_LOCAL_SQLITE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Constants for Product Enrichment (Hardcoded for clarity and decoupling)
TABLE_NAME = "shared_source.magento_product_dimension_with_text_embedding"
SELECT_COLUMNS = """
magento_ref_code, internal_ref_code, product_name,
sale_price, original_price, product_image_url_thumbnail,
product_web_url, size_scale, gender_by_product, product_line_vn,
product_color_code, product_color_name
"""
def route_after_classifier(state: StylistProState) -> str: def route_after_classifier(state: StylistProState) -> str:
"""Điều hướng dựa trên early_exit.""" """Điều hướng dựa trên early_exit."""
if state.get("early_exit"): if state.get("early_exit"):
...@@ -129,13 +120,33 @@ class CANIFAGraph: ...@@ -129,13 +120,33 @@ class CANIFAGraph:
ai_response = _extract_text(msg.content) ai_response = _extract_text(msg.content)
break break
# --- PRODUCT ENRICHMENT LAYER --- # --- PRODUCT ENRICHMENT LAYER (Refactored for Clean Base/Variant Mapping) ---
final_products = [] final_products = []
ai_product_ids = result.get("product_ids", []) ai_product_ids = [pid.upper().strip() for pid in result.get("product_ids", []) if pid]
tool_result_raw = result.get("tool_result") tool_result_raw = result.get("tool_result")
product_dict = {} product_dict = {}
base_sku_mapping = {} # Map base_sku -> list of variant skus
def _add_to_dicts(p: dict):
sku = str(p.get("magento_ref_code") or p.get("sku") or p.get("sku_code") or "").upper().strip()
if not sku:
return
# Ensure color fields exist
if "color_code" not in p: p["color_code"] = p.get("product_color_code", "")
if "color_name" not in p: p["color_name"] = p.get("product_color_name", "")
if "sku" not in p: p["sku"] = sku
product_dict[sku] = p
# Also map the base SKU (before the dash)
base_sku = sku.split("-")[0]
if base_sku not in base_sku_mapping:
base_sku_mapping[base_sku] = []
if sku not in base_sku_mapping[base_sku]:
base_sku_mapping[base_sku].append(sku)
# 1. Map products already in tool result (if any) # 1. Map products already in tool result (if any)
if tool_result_raw: if tool_result_raw:
try: try:
...@@ -143,102 +154,32 @@ class CANIFAGraph: ...@@ -143,102 +154,32 @@ class CANIFAGraph:
if parsed.get("status") == "success": if parsed.get("status") == "success":
all_products = parsed.get("products", []) or parsed.get("results", []) all_products = parsed.get("products", []) or parsed.get("results", [])
for p in all_products: for p in all_products:
sku = str(p.get("magento_ref_code") or p.get("sku") or p.get("sku_code") or "").upper().strip() _add_to_dicts(p)
if sku:
if "sku" not in p: p["sku"] = sku
if "color_code" not in p: p["color_code"] = p.get("product_color_code", "")
if "color_name" not in p: p["color_name"] = p.get("product_color_name", "")
product_dict[sku] = p
# Also map outfit recommendations # Also map outfit recommendations
recs = p.get("outfit_recommendations", []) recs = p.get("outfit_recommendations", [])
if isinstance(recs, list): if isinstance(recs, list):
for r in recs: for r in recs:
r_sku = str(r.get("match_product_code") or r.get("sku") or "").upper().strip() _add_to_dicts(r)
if r_sku:
if "sku" not in r: r["sku"] = r_sku
product_dict[r_sku] = r
except Exception: pass except Exception: pass
# 2. Enrich missing or incomplete product data via StarRocks (or SQLite fallback) # 2. (Removed) We no longer do explicit DB queries here.
wanted_ids = [pid.upper().strip() for pid in ai_product_ids if pid] # The tool_result should act as the single source of truth for the AI's context.
missing_ids = [pid for pid in wanted_ids if pid not in product_dict or not product_dict[pid].get("image")] # Any SKUs that are not in product_dict are hallucinations and will be skipped.
if missing_ids:
skus = list(dict.fromkeys(missing_ids))
base_skus = list(dict.fromkeys([s.split("-")[0] for s in skus]))
keys = list(dict.fromkeys(skus + base_skus))
try:
if USE_LOCAL_SQLITE:
raise Exception("StarRocks skipped due to USE_LOCAL_SQLITE=True")
db = get_db_connection()
if not db:
raise Exception("StarRocks connection is None")
placeholders = ",".join(["%s"] * len(keys))
sql = f"""
SELECT {SELECT_COLUMNS}
FROM {TABLE_NAME}
WHERE UPPER(magento_ref_code) IN ({placeholders})
OR UPPER(internal_ref_code) IN ({placeholders})
LIMIT 200
"""
rows = await db.execute_query_async(sql, params=tuple(keys + keys))
for r in rows or []:
card = {
"sku": (r.get("magento_ref_code") or r.get("internal_ref_code") or "").strip(),
"name": r.get("product_name", ""),
"price": int(r.get("sale_price") or 0),
"original_price": int(r.get("original_price") or 0),
"image": r.get("product_image_url_thumbnail", ""),
"url": r.get("product_web_url", ""),
"sizes": r.get("size_scale", ""),
"gender": r.get("gender_by_product", ""),
"product_line": r.get("product_line_vn", ""),
"color_code": r.get("product_color_code", ""),
"color_name": r.get("product_color_name", ""),
}
product_dict[card["sku"].upper()] = card
except Exception as e:
logger.warning(f"⚠️ StarRocks Enrichment failed: {e}. Falling back to SQLite...")
from common.sqlite_db import sqlite_db
placeholders = ",".join(["?"] * len(keys))
sql = f"""
SELECT {SELECT_COLUMNS}
FROM sr__test_db__magento_product_dimension_with_text_embedding
WHERE UPPER(magento_ref_code) IN ({placeholders})
OR UPPER(internal_ref_code) IN ({placeholders})
LIMIT 200
"""
try:
rows = await sqlite_db.fetch_all(sql, params=tuple(keys + keys))
for r in rows or []:
card = {
"sku": (r.get("magento_ref_code") or r.get("internal_ref_code") or "").strip(),
"name": r.get("product_name", ""),
"price": int(r.get("sale_price") or 0),
"original_price": int(r.get("original_price") or 0),
"image": r.get("product_image_url_thumbnail", ""),
"url": r.get("product_web_url", ""),
"sizes": r.get("size_scale", ""),
"gender": r.get("gender_by_product", ""),
"product_line": r.get("product_line_vn", ""),
"color_code": r.get("product_color_code", ""),
"color_name": r.get("product_color_name", ""),
}
product_dict[card["sku"].upper()] = card
except Exception as sqlite_e:
logger.error(f"❌ SQLite Enrichment fallback failed: {sqlite_e}")
# 3. Assemble final products in the order requested by AI # 3. Assemble final products in the order requested by AI
# If AI requested a base SKU, we return the first available variant for it
seen = set() seen = set()
for pid in wanted_ids: for pid in ai_product_ids:
if pid in product_dict and pid not in seen: target_sku = None
final_products.append(product_dict[pid]) if pid in product_dict:
seen.add(pid) target_sku = pid
elif pid in base_sku_mapping and base_sku_mapping[pid]:
target_sku = base_sku_mapping[pid][0]
if target_sku and target_sku not in seen:
final_products.append(product_dict[target_sku])
seen.add(target_sku)
return { return {
"response": ai_response, "response": ai_response,
......
...@@ -119,36 +119,7 @@ async def classifier_node(state: StylistProState, config: RunnableConfig): ...@@ -119,36 +119,7 @@ async def classifier_node(state: StylistProState, config: RunnableConfig):
elif output.tool_args: elif output.tool_args:
tool_args = output.tool_args tool_args = output.tool_args
# EXECUTE TOOL INLINE # Tool execution is deferred to tools_node.py to prevent double execution and ensure proper tracing.
if tool_name and tool_args is not None:
from agent.tools.get_tools import get_all_tools
all_tools = get_all_tools()
# Fallback handle name 'lead_search_tool' -> 'data_retrieval_tool'
search_name = "data_retrieval_tool" if tool_name == "lead_search_tool" else tool_name
target_tool = next((t for t in all_tools if t.name == search_name), None)
if target_tool:
try:
logger.info(f"🚀 Executing inline tool: {search_name}")
# Some tools might be sync or async, checking for ainvoke support
if hasattr(target_tool, "ainvoke"):
tool_res_str = await target_tool.ainvoke(tool_args, config=config)
else:
tool_res_str = target_tool.invoke(tool_args)
tool_result = tool_res_str
diagnostics.append({
"step": "inline_tool",
"label": f"🛠️ Executed {search_name}",
"content": f"Result length: {len(str(tool_res_str))}",
"elapsed_ms": 0
})
except Exception as e:
logger.error(f"❌ Inline tool error: {e}")
tool_result = json.dumps({"status": "error", "message": str(e), "products": []})
else:
logger.warning(f"⚠️ Tool {search_name} not found in get_all_tools()")
return { return {
"tool_name_used": tool_name, "tool_name_used": tool_name,
......
...@@ -38,16 +38,11 @@ class ClassifierOutput(BaseModel): ...@@ -38,16 +38,11 @@ class ClassifierOutput(BaseModel):
product_ids: List[str] = Field(default_factory=list) product_ids: List[str] = Field(default_factory=list)
class StylistProInsight(BaseModel): class StylistProInsight(BaseModel):
USER: Optional[str] = "Chưa rõ" USER: Optional[Any] = "Chưa rõ"
TARGET: Optional[str] = "Chưa rõ" TARGET: Optional[Any] = "Chưa rõ"
GOAL: Optional[str] = "Chưa rõ" STAGE: Optional[Any] = "BROWSE"
CONSTRAINTS: Optional[str] = None LATEST_PRODUCT_INTEREST: Optional[Any] = ""
STAGE: Optional[str] = "BROWSE" SUMMARY_HISTORY: Optional[Any] = ""
STAGE_NUM: int = 1
TONE: Optional[str] = "Friendly"
LATEST_PRODUCT_INTEREST: Optional[str] = ""
SUMMARY_HISTORY: Optional[str] = ""
BEHAVIORAL_HINTS: List[str] = Field(default_factory=list)
class StylistOutput(BaseModel): class StylistOutput(BaseModel):
ai_response: str ai_response: str
......
...@@ -16,8 +16,11 @@ async def stylist_node(state: StylistProState, config: RunnableConfig): ...@@ -16,8 +16,11 @@ async def stylist_node(state: StylistProState, config: RunnableConfig):
""" """
Stylist Node: Sinh câu trả lời tư vấn và cập nhật Insight khách hàng. Stylist Node: Sinh câu trả lời tư vấn và cập nhật Insight khách hàng.
""" """
llm = create_llm(model_name=config.get("configurable", {}).get("model_name"), streaming=False) llm = create_llm(
model_name=config.get("configurable", {}).get("model_name"),
streaming=False,
max_tokens=1500 # Tăng nhẹ lên 1500 (trước là 500) vì AI hay viết dài mô tả sản phẩm
)
insight_dict = json.loads(state.get("user_insight") or "{}") insight_dict = json.loads(state.get("user_insight") or "{}")
injection = format_stylist_pro_injection(insight_dict) injection = format_stylist_pro_injection(insight_dict)
......
...@@ -45,7 +45,7 @@ async def tools_node(state: StylistProState, config: RunnableConfig): ...@@ -45,7 +45,7 @@ async def tools_node(state: StylistProState, config: RunnableConfig):
logger.info(f"🛠️ [ToolsNode] Executing '{tool_name}' with args: {tool_args}") logger.info(f"🛠️ [ToolsNode] Executing '{tool_name}' with args: {tool_args}")
# Execute Tool # Execute Tool
tool_result = await tool_fn.ainvoke(tool_args) tool_result = await tool_fn.ainvoke(tool_args, config=config)
tool_elapsed = (time.time() - t_start) * 1000 tool_elapsed = (time.time() - t_start) * 1000
diagnostics.append({ diagnostics.append({
......
...@@ -105,10 +105,12 @@ Bạn là Chuyên gia Thời trang (Stylist Pro) của CANIFA. Bạn tư vấn d ...@@ -105,10 +105,12 @@ Bạn là Chuyên gia Thời trang (Stylist Pro) của CANIFA. Bạn tư vấn d
</styling_philosophy> </styling_philosophy>
<user_memory_update> <user_memory_update>
Cập nhật 12 trường Insight: Cập nhật 5 trường Insight:
- STAGE: BROWSE (tìm kiếm) -> COMPARE (hỏi sâu tính năng/giá) -> DECIDE (chốt đơn/hỏi size). - USER: Ai đang chat (VD: mẹ, vợ, nam thanh niên).
- LATEST_PRODUCT_INTEREST: Ghi rõ Tên + SKU sản phẩm vừa tư vấn chính. - TARGET: Đối tượng mua đồ cho (VD: bé trai, chồng, bản thân).
- SUMMARY_HISTORY: Tóm tắt các bước đã tư vấn. - STAGE: BROWSE (tìm kiếm) -> COMPARE (hỏi sâu) -> DECIDE (chốt đơn/hỏi size).
- LATEST_PRODUCT_INTEREST: Tên + SKU sản phẩm vừa tư vấn chính.
- SUMMARY_HISTORY: Tóm tắt 1 câu khách đang cần gì.
</user_memory_update> </user_memory_update>
<formatting_rules> <formatting_rules>
...@@ -134,7 +136,7 @@ Trả về DUY NHẤT JSON: ...@@ -134,7 +136,7 @@ Trả về DUY NHẤT JSON:
{{ {{
"ai_response": "nội dung tư vấn chuyên nghiệp", "ai_response": "nội dung tư vấn chuyên nghiệp",
"product_ids": ["SKU1", "SKU2", ...], "product_ids": ["SKU1", "SKU2", ...],
"user_insight": {{ ...updated 12 fields... }} "user_insight": {{ ...updated 5 fields... }}
}} }}
</output_format> </output_format>
...@@ -147,12 +149,9 @@ def format_stylist_pro_injection(insight: dict) -> str: ...@@ -147,12 +149,9 @@ def format_stylist_pro_injection(insight: dict) -> str:
"""Format InsightJSON dict thành đoạn XML context cho prompt.""" """Format InsightJSON dict thành đoạn XML context cho prompt."""
template = """ template = """
<user_memory> <user_memory>
- User Type: {user} - User: {user}
- Target: {target} - Target: {target}
- Goal: {goal} - Stage: {stage}
- Constraints: {constraints}
- Stage: {stage} (Level {stage_num})
- Tone: {tone}
- Latest Interest: {latest} - Latest Interest: {latest}
- History Summary: {summary} - History Summary: {summary}
</user_memory> </user_memory>
...@@ -160,11 +159,7 @@ def format_stylist_pro_injection(insight: dict) -> str: ...@@ -160,11 +159,7 @@ def format_stylist_pro_injection(insight: dict) -> str:
return template.format( return template.format(
user=insight.get("USER", "Chưa rõ"), user=insight.get("USER", "Chưa rõ"),
target=insight.get("TARGET", "Chưa rõ"), target=insight.get("TARGET", "Chưa rõ"),
goal=insight.get("GOAL", "Chưa rõ"),
constraints=insight.get("CONSTRAINTS") or "Không",
stage=insight.get("STAGE", "BROWSE"), stage=insight.get("STAGE", "BROWSE"),
stage_num=insight.get("STAGE_NUM", 1),
tone=insight.get("TONE", "Friendly"),
latest=insight.get("LATEST_PRODUCT_INTEREST") or "Chưa có", latest=insight.get("LATEST_PRODUCT_INTEREST") or "Chưa có",
summary=insight.get("SUMMARY_HISTORY") or "Chưa có" summary=insight.get("SUMMARY_HISTORY") or "Chưa có"
) )
...@@ -332,6 +332,8 @@ async def data_retrieval_tool( ...@@ -332,6 +332,8 @@ async def data_retrieval_tool(
else: else:
use_sqlite = (db_source == "sqlite") use_sqlite = (db_source == "sqlite")
logger.info("Gía trị truyển xuống searches: %s", searches)
shared_user_insight = user_insight or configurable.get("user_insight") shared_user_insight = user_insight or configurable.get("user_insight")
per_search_results: List[Dict[str, Any]] = [] per_search_results: List[Dict[str, Any]] = []
......
...@@ -404,7 +404,7 @@ class SearchEngine: ...@@ -404,7 +404,7 @@ class SearchEngine:
for p in products: for p in products:
raw = p.get("outfit_recommendations") raw = p.get("outfit_recommendations")
# Fallback to product DB tags if no inferred occasion # Fallback to product DB tags if no inferred occasion
prod_occ = occasion_context prod_occ = occasion_context
if not prod_occ: if not prod_occ:
...@@ -421,7 +421,7 @@ class SearchEngine: ...@@ -421,7 +421,7 @@ class SearchEngine:
t_list = prod_tags_raw t_list = prod_tags_raw
else: else:
t_list = [] t_list = []
found_occ = [t for t in t_list if any(o in str(t).lower() for o in ["công sở", "đi làm", "đi chơi", "dạo phố", "mặc nhà", "mặc ngủ", "thể thao", "đi tiệc"])] found_occ = [t for t in t_list if any(o in str(t).lower() for o in ["công sở", "đi làm", "đi chơi", "dạo phố", "mặc nhà", "mặc ngủ", "thể thao", "đi tiệc"])]
if found_occ: if found_occ:
prod_occ = f" Rất hợp để {found_occ[0].lower()}." prod_occ = f" Rất hợp để {found_occ[0].lower()}."
...@@ -434,13 +434,13 @@ class SearchEngine: ...@@ -434,13 +434,13 @@ class SearchEngine:
parsed = json.loads(raw) parsed = json.loads(raw)
# Giới hạn max 5 outfit để tránh hàng trăm SP # Giới hạn max 5 outfit để tránh hàng trăm SP
parsed = parsed[:5] parsed = parsed[:5]
# Thêm context dịp mặc vào reason # Thêm context dịp mặc vào reason
if prod_occ: if prod_occ:
for outfit in parsed: for outfit in parsed:
if "reason" in outfit and prod_occ not in outfit["reason"]: if "reason" in outfit and prod_occ not in outfit["reason"]:
outfit["reason"] += prod_occ outfit["reason"] += prod_occ
p["outfit_recommendations"] = parsed p["outfit_recommendations"] = parsed
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
p["outfit_recommendations"] = [] p["outfit_recommendations"] = []
...@@ -463,10 +463,10 @@ class SearchEngine: ...@@ -463,10 +463,10 @@ class SearchEngine:
if "cross_sell" in data: metadata["cross_sell"] = data["cross_sell"] if "cross_sell" in data: metadata["cross_sell"] = data["cross_sell"]
if "tinh_nang_vai" in data: metadata["tinh_nang_vai"] = data["tinh_nang_vai"] if "tinh_nang_vai" in data: metadata["tinh_nang_vai"] = data["tinh_nang_vai"]
if "tags" in data: metadata["tags"] = data["tags"] if "tags" in data: metadata["tags"] = data["tags"]
if "mo_ta_chinh" in data and data["mo_ta_chinh"]: if "mo_ta_chinh" in data and data["mo_ta_chinh"]:
p["description_text"] = data["mo_ta_chinh"] p["description_text"] = data["mo_ta_chinh"]
p["styling_metadata"] = metadata p["styling_metadata"] = metadata
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
pass pass
......
...@@ -43,7 +43,7 @@ class ConversationManager: ...@@ -43,7 +43,7 @@ class ConversationManager:
"""Create the chat history table if it doesn't exist""" """Create the chat history table if it doesn't exist"""
try: try:
pool = await self._get_pool() pool = await self._get_pool()
async with pool.connection(timeout=2.0) as conn: async with pool.connection(timeout=10.0) as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
# Set timezone to Vietnam for this session # Set timezone to Vietnam for this session
await cursor.execute("SET timezone = 'Asia/Ho_Chi_Minh'") await cursor.execute("SET timezone = 'Asia/Ho_Chi_Minh'")
...@@ -84,7 +84,7 @@ class ConversationManager: ...@@ -84,7 +84,7 @@ class ConversationManager:
vietnam_tz = timezone(timedelta(hours=7)) vietnam_tz = timezone(timedelta(hours=7))
timestamp = datetime.now(vietnam_tz) timestamp = datetime.now(vietnam_tz)
# Transaction block: atomic insert # Transaction block: atomic insert
async with pool.connection(timeout=2.0) as conn: async with pool.connection(timeout=10.0) as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
# Set timezone to Vietnam for this session # Set timezone to Vietnam for this session
await cursor.execute("SET timezone = 'Asia/Ho_Chi_Minh'") await cursor.execute("SET timezone = 'Asia/Ho_Chi_Minh'")
...@@ -178,7 +178,7 @@ class ConversationManager: ...@@ -178,7 +178,7 @@ class ConversationManager:
final_query = sql.SQL(" ").join(query_parts) final_query = sql.SQL(" ").join(query_parts)
pool = await self._get_pool() pool = await self._get_pool()
async with pool.connection(timeout=2.0) as conn, conn.cursor() as cursor: async with pool.connection(timeout=10.0) as conn, conn.cursor() as cursor:
# Set timezone to Vietnam for this session # Set timezone to Vietnam for this session
await cursor.execute("SET timezone = 'Asia/Ho_Chi_Minh'") await cursor.execute("SET timezone = 'Asia/Ho_Chi_Minh'")
...@@ -235,7 +235,7 @@ class ConversationManager: ...@@ -235,7 +235,7 @@ class ConversationManager:
end_of_day = now.replace(hour=23, minute=59, second=59, microsecond=999999) end_of_day = now.replace(hour=23, minute=59, second=59, microsecond=999999)
pool = await self._get_pool() pool = await self._get_pool()
async with pool.connection(timeout=2.0) as conn: async with pool.connection(timeout=10.0) as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
query = sql.SQL(""" query = sql.SQL("""
UPDATE {table} UPDATE {table}
...@@ -258,7 +258,7 @@ class ConversationManager: ...@@ -258,7 +258,7 @@ class ConversationManager:
"""Clear all chat history for an identity""" """Clear all chat history for an identity"""
try: try:
pool = await self._get_pool() pool = await self._get_pool()
async with pool.connection(timeout=2.0) as conn: async with pool.connection(timeout=10.0) as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
query = sql.SQL("DELETE FROM {table} WHERE identity_key = %s").format( query = sql.SQL("DELETE FROM {table} WHERE identity_key = %s").format(
table=sql.Identifier(self.table_name) table=sql.Identifier(self.table_name)
......
...@@ -80,7 +80,7 @@ class LangfuseClientManager: ...@@ -80,7 +80,7 @@ class LangfuseClientManager:
except Exception as e: except Exception as e:
logger.warning(f"⚠️ Async flush failed: {e}") logger.warning(f"⚠️ Async flush failed: {e}")
def get_callback_handler(self) -> CallbackHandler | None: def get_callback_handler(self, **kwargs) -> CallbackHandler | None:
"""Get CallbackHandler instance.""" """Get CallbackHandler instance."""
client = self.get_client() client = self.get_client()
if not client: if not client:
...@@ -88,7 +88,7 @@ class LangfuseClientManager: ...@@ -88,7 +88,7 @@ class LangfuseClientManager:
return None return None
try: try:
handler = CallbackHandler() handler = CallbackHandler(**kwargs)
logger.debug("✅ Langfuse CallbackHandler created") logger.debug("✅ Langfuse CallbackHandler created")
return handler return handler
except Exception as e: except Exception as e:
...@@ -102,6 +102,6 @@ get_langfuse_client = _manager.get_client ...@@ -102,6 +102,6 @@ get_langfuse_client = _manager.get_client
async_flush_langfuse = _manager.async_flush async_flush_langfuse = _manager.async_flush
def get_callback_handler() -> CallbackHandler | None: def get_callback_handler(**kwargs) -> CallbackHandler | None:
"""Get CallbackHandler instance (wrapper for manager).""" """Get CallbackHandler instance (wrapper for manager)."""
return _manager.get_callback_handler() return _manager.get_callback_handler(**kwargs)
...@@ -34,6 +34,7 @@ class LLMFactory: ...@@ -34,6 +34,7 @@ class LLMFactory:
streaming: bool = True, streaming: bool = True,
json_mode: bool = False, json_mode: bool = False,
api_key: str | None = None, api_key: str | None = None,
max_tokens: int | None = None,
) -> BaseChatModel: ) -> BaseChatModel:
""" """
Get or create an LLM instance from cache. Get or create an LLM instance from cache.
...@@ -48,14 +49,14 @@ class LLMFactory: ...@@ -48,14 +49,14 @@ class LLMFactory:
Configured LLM instance Configured LLM instance
""" """
clean_model = model_name.split("/")[-1] if "/" in model_name else model_name clean_model = model_name.split("/")[-1] if "/" in model_name else model_name
cache_key = (clean_model, streaming, json_mode, api_key) cache_key = (clean_model, streaming, json_mode, api_key, max_tokens)
if cache_key in self._cache: if cache_key in self._cache:
logger.debug(f"♻️ Using cached model: {clean_model}") logger.debug(f"♻️ Using cached model: {clean_model} (tokens: {max_tokens or 'default'})")
return self._cache[cache_key] return self._cache[cache_key]
logger.info(f"Creating new LLM instance: {model_name}") logger.info(f"Creating new LLM instance: {model_name} (tokens: {max_tokens or 'default'})")
return self._create_instance(model_name, streaming, json_mode, api_key) return self._create_instance(model_name, streaming, json_mode, api_key, max_tokens)
def _create_instance( def _create_instance(
self, self,
...@@ -63,12 +64,13 @@ class LLMFactory: ...@@ -63,12 +64,13 @@ class LLMFactory:
streaming: bool = False, streaming: bool = False,
json_mode: bool = False, json_mode: bool = False,
api_key: str | None = None, api_key: str | None = None,
max_tokens: int | None = None,
) -> BaseChatModel: ) -> BaseChatModel:
"""Create and cache a new OpenAI LLM instance.""" """Create and cache a new OpenAI LLM instance."""
try: try:
llm = self._create_openai(model_name, streaming, json_mode, api_key) llm = self._create_openai(model_name, streaming, json_mode, api_key, max_tokens)
cache_key = (model_name, streaming, json_mode, api_key) cache_key = (model_name, streaming, json_mode, api_key, max_tokens)
self._cache[cache_key] = llm self._cache[cache_key] = llm
return llm return llm
...@@ -76,7 +78,7 @@ class LLMFactory: ...@@ -76,7 +78,7 @@ class LLMFactory:
logger.error(f"❌ Failed to create model {model_name}: {e}") logger.error(f"❌ Failed to create model {model_name}: {e}")
raise raise
def _create_openai(self, model_name: str, streaming: bool, json_mode: bool, api_key: str | None) -> BaseChatModel: def _create_openai(self, model_name: str, streaming: bool, json_mode: bool, api_key: str | None, max_tokens: int | None = None) -> BaseChatModel:
"""Create OpenAI-compatible model instance (OpenAI or Groq).""" """Create OpenAI-compatible model instance (OpenAI or Groq)."""
# --- Auto-detect provider --- # --- Auto-detect provider ---
...@@ -105,7 +107,7 @@ class LLMFactory: ...@@ -105,7 +107,7 @@ class LLMFactory:
"streaming": streaming, "streaming": streaming,
"api_key": key, "api_key": key,
"temperature": 0, "temperature": 0,
"max_tokens": 1500, "max_tokens": max_tokens or 800, # Default an toàn, chống cháy túi
} }
if base_url: if base_url:
...@@ -159,9 +161,10 @@ def create_llm( ...@@ -159,9 +161,10 @@ def create_llm(
streaming: bool = True, streaming: bool = True,
json_mode: bool = False, json_mode: bool = False,
api_key: str | None = None, api_key: str | None = None,
max_tokens: int | None = None,
) -> BaseChatModel: ) -> BaseChatModel:
"""Create or get cached LLM instance.""" """Create or get cached LLM instance."""
return _factory.get_model(model_name, streaming=streaming, json_mode=json_mode, api_key=api_key) return _factory.get_model(model_name, streaming=streaming, json_mode=json_mode, api_key=api_key, max_tokens=max_tokens)
def init_llm_factory(skip_warmup: bool = True) -> None: def init_llm_factory(skip_warmup: bool = True) -> None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment