Commit 28274420 authored by Hoanganhvu123

feat: Migrate from LangGraph to Agno framework

parent f057ad1e
""" """
Fashion Q&A Agent Package Fashion Q&A Agent Package - Agno Framework
""" """
from .graph import build_graph # Only export what's needed for Agno
from .models import AgentConfig, AgentState, get_config from .agno_agent import get_agno_agent
from .agno_controller import chat_controller
from .models import QueryRequest
__all__ = [ __all__ = [
"AgentConfig", "get_agno_agent",
"AgentState", "chat_controller",
"build_graph", "QueryRequest",
"get_config",
] ]
"""
CANIFA Agent với Agno Framework
Thay thế LangGraph bằng Agno
"""
import logging
from typing import TYPE_CHECKING, Any, cast
# Type checking imports (only used for type hints)
if TYPE_CHECKING:
from agno.agent import Agent as AgentType
from agno.db.base import BaseDb as BaseDbType
from agno.models.openai import OpenAIChat as OpenAIChatType
else:
AgentType = Any # type: ignore
BaseDbType = Any # type: ignore
OpenAIChatType = Any # type: ignore
# Runtime imports with fallback
try:
from agno.agent import Agent
from agno.db.base import BaseDb
from agno.models.openai import OpenAIChat
except ImportError:
# Fallback nếu chưa install agno
Agent = None
BaseDb = Any # type: ignore
OpenAIChat = None
from common.conversation_manager import get_conversation_manager
from config import DEFAULT_MODEL, OPENAI_API_KEY
from .prompt import get_system_prompt
from .tools.agno_tools import get_agno_tools
logger = logging.getLogger(__name__)
def create_agno_model(model_name: str = DEFAULT_MODEL, json_mode: bool = False):
    """Build the Agno chat model from values in config.py.

    Args:
        model_name: OpenAI model identifier to instantiate.
        json_mode: Accepted for API symmetry but currently unused here —
            Agno is expected to handle JSON output at the agent level.

    Raises:
        ImportError: When the optional ``agno`` package is not installed.
    """
    if OpenAIChat is None:
        raise ImportError("Agno not installed. Run: pip install agno")
    # Only the model id and API key are forwarded; json_mode is not.
    return OpenAIChat(id=model_name, api_key=OPENAI_API_KEY)
async def create_agno_agent(
    model_name: str = DEFAULT_MODEL,
    json_mode: bool = False,
) -> AgentType:  # type: ignore
    """Assemble an Agno Agent wired to the ConversationManager (memory on).

    Args:
        model_name: Model name from config.py.
        json_mode: Enable JSON output.

    Returns:
        Configured Agno Agent.

    Raises:
        ImportError: When the optional ``agno`` package is not installed.
    """
    model = create_agno_model(model_name, json_mode)
    tools = get_agno_tools()
    system_prompt = get_system_prompt()

    # ConversationManager duck-types the BaseDb interface; the cast only
    # satisfies the type checker — runtime works because the required
    # methods exist on ConversationManager.
    db = await get_conversation_manager()
    if Agent is None:
        raise ImportError("Agno not installed. Run: pip install agno")
    db_cast = cast(BaseDbType, db)  # type: ignore[assignment]

    agent = Agent(
        name="CANIFA Agent",
        model=model,
        db=db_cast,  # ConversationManager standing in for BaseDb
        tools=tools,
        instructions=system_prompt,  # Agno uses `instructions`, not `system_prompt`
        add_history_to_context=True,  # enable chat history
        num_history_runs=20,  # load the 20 most recent messages
        markdown=True,
    )
    logger.info(f"✅ Agno Agent created with model: {model_name} (WITH MEMORY)")
    return agent
# Singleton instance
_agno_agent_instance: AgentType | None = None # type: ignore
async def get_agno_agent(
    model_name: str = DEFAULT_MODEL,
    json_mode: bool = False,
) -> AgentType:  # type: ignore
    """Return the process-wide Agno Agent, creating it on first use.

    NOTE(review): arguments are only honoured on the very first call;
    subsequent calls return the cached agent regardless of `model_name`.
    """
    global _agno_agent_instance
    if _agno_agent_instance is not None:
        return _agno_agent_instance
    _agno_agent_instance = await create_agno_agent(
        model_name=model_name, json_mode=json_mode
    )
    return _agno_agent_instance
def reset_agno_agent():
    """Drop the cached agent singleton so tests start from a clean slate."""
    global _agno_agent_instance
    _agno_agent_instance = None
"""
CANIFA Agent Controller với Agno Framework
"""
import json
import logging
from typing import Any
from fastapi import BackgroundTasks
from common.langfuse_client import langfuse_trace_context
from config import DEFAULT_MODEL
from .agno_agent import get_agno_agent
logger = logging.getLogger(__name__)
async def chat_controller(
    query: str,
    user_id: str,
    background_tasks: BackgroundTasks,
    model_name: str = DEFAULT_MODEL,
    images: list[str] | None = None,
) -> dict:
    """
    Controller backed by the Agno Agent (automatic memory).
    History load/save is delegated to Agno via the ConversationManager.
    """
    logger.info(f"▶️ Agno chat_controller | User: {user_id} | Model: {model_name}")
    try:
        agent = await get_agno_agent(model_name=model_name, json_mode=True)
        with langfuse_trace_context(user_id=user_id, session_id=user_id):
            # Agno loads history before the run and persists it after.
            result = agent.run(query, session_id=user_id)

            # Prefer the structured `content` attribute; fall back to str(result).
            content = getattr(result, "content", None)
            ai_content = str(content) if content else str(result)
            logger.info(f"💾 AI Response: {ai_content[:200]}...")

            # Split the model output into display text + product objects.
            ai_text, product_ids = _parse_agno_response(result, ai_content)
            return {
                "ai_response": ai_text,
                "product_ids": product_ids,
            }
    except Exception as e:
        logger.error(f"💥 Agno chat error for user {user_id}: {e}", exc_info=True)
        raise
def _parse_agno_response(result: Any, ai_content: str) -> tuple[str, list[dict]]:
"""
Parse Agno response và extract AI text + product IDs.
Returns: (ai_text_response, product_ids)
"""
ai_text = ai_content
product_ids = []
# Try parse JSON response
try:
ai_json = json.loads(ai_content)
ai_text = ai_json.get("ai_response", ai_content)
product_ids = ai_json.get("product_ids", []) or []
except (json.JSONDecodeError, Exception) as e:
logger.debug(f"Response is not JSON, using raw text: {e}")
# Extract products từ tool results
if hasattr(result, "messages"):
tool_products = _extract_products_from_messages(result.messages)
# Merge và deduplicate
seen_skus = {p.get("sku") for p in product_ids if isinstance(p, dict) and "sku" in p}
for product in tool_products:
if isinstance(product, dict) and product.get("sku") not in seen_skus:
product_ids.append(product)
seen_skus.add(product.get("sku"))
return ai_text, product_ids
def _extract_products_from_messages(messages: list) -> list[dict]:
    """Extract product dicts from Agno tool messages.

    Scans each message whose `content` is a JSON string, keeps only
    successful tool results, and supports both the multi-search payload
    (`results[].products`) and the single-search payload (`products`).
    Duplicate SKUs across messages are skipped via the shared `seen_skus`.
    """
    products: list[dict] = []
    seen_skus: set[str] = set()
    for msg in messages:
        if not (hasattr(msg, "content") and isinstance(msg.content, str)):
            continue
        try:
            tool_result = json.loads(msg.content)
            # FIX: json.loads may return a list/str/number; the original
            # called .get() on it and the resulting AttributeError escaped
            # the except clause below. Skip non-dict payloads explicitly.
            if not isinstance(tool_result, dict):
                continue
            if tool_result.get("status") != "success":
                continue
            if "results" in tool_result:
                # Multi-search format: one entry per parallel query.
                for result_item in tool_result["results"]:
                    if isinstance(result_item, dict):
                        products.extend(_parse_products(result_item.get("products", []), seen_skus))
            elif "products" in tool_result:
                # Single-search format.
                products.extend(_parse_products(tool_result["products"], seen_skus))
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            logger.debug(f"Skip invalid tool message: {e}")
            continue
    return products
def _parse_products(products: list[dict], seen_skus: set[str]) -> list[dict]:
"""Parse và format products, skip duplicates."""
parsed = []
for product in products:
if not isinstance(product, dict):
continue
sku = product.get("internal_ref_code")
if not sku or sku in seen_skus:
continue
seen_skus.add(sku)
parsed.append({
"sku": sku,
"name": product.get("magento_product_name", ""),
"price": product.get("price_vnd", 0),
"sale_price": product.get("sale_price_vnd"),
"url": product.get("magento_url_key", ""),
"thumbnail_image_url": product.get("thumbnail_image_url", ""),
})
return parsed
"""
Fashion Q&A Agent Controller
Langfuse will auto-trace via LangChain integration (no code changes needed).
"""
import json
import logging
import uuid
from fastapi import BackgroundTasks
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from common.conversation_manager import ConversationManager, get_conversation_manager
from common.langfuse_client import get_callback_handler, langfuse_trace_context
from common.llm_factory import create_llm
from config import DEFAULT_MODEL
from .graph import build_graph
from .models import AgentState, get_config
from .tools.get_tools import get_all_tools
logger = logging.getLogger(__name__)
async def chat_controller(
    query: str,
    user_id: str,
    background_tasks: BackgroundTasks,
    model_name: str = DEFAULT_MODEL,
    images: list[str] | None = None,
) -> dict:
    """
    Controller main logic for non-streaming chat requests.
    Langfuse automatically traces all LangChain operations.

    Args:
        query: The user's message.
        user_id: Conversation owner; also used as the Langfuse session id.
        background_tasks: FastAPI task queue used to persist history after
            the response has been sent.
        model_name: LLM to use for this request.
        images: Optional transient image payloads forwarded to the graph.

    Returns:
        Dict with `ai_response` (plain text) and `product_ids`
        (list of product objects).
    """
    logger.info(f"▶️ Starting chat_controller with model: {model_name} for user: {user_id}")
    config = get_config()
    config.model_name = model_name
    # JSON mode guarantees structured output from the model.
    llm = create_llm(model_name=model_name, streaming=False, json_mode=True)
    tools = get_all_tools()
    graph = build_graph(config, llm=llm, tools=tools)

    # ConversationManager is a singleton; load the 20 most recent turns,
    # oldest first (DB returns newest first, hence reversed()).
    memory = await get_conversation_manager()
    history_dicts = await memory.get_chat_history(user_id, limit=20)
    history = []
    for h in reversed(history_dicts):
        msg_cls = HumanMessage if h["is_human"] else AIMessage
        history.append(msg_cls(content=h["message"]))

    initial_state, exec_config = _prepare_execution_context(
        query=query, user_id=user_id, history=history, images=images
    )
    try:
        # Wrap graph execution so every observation carries the user_id.
        with langfuse_trace_context(user_id=user_id, session_id=user_id):
            result = await graph.ainvoke(initial_state, config=exec_config)

            # Collect product objects from tool messages once.
            all_product_ids = _extract_product_ids(result.get("messages", []))

            ai_message = result.get("ai_response")
            ai_raw_content = ai_message.content if ai_message else ""
            logger.info(f"💾 [RAW AI OUTPUT]:\n{ai_raw_content}")

            # With json_mode=True the model emits raw JSON; fall back to the
            # raw text when parsing fails (rare in JSON mode).
            ai_text_response = ai_raw_content
            try:
                ai_json = json.loads(ai_raw_content)
                # FIX: the original caught `(json.JSONDecodeError, Exception)`
                # — a redundant tuple — and let a non-dict JSON document reach
                # `.get()`, relying on the broad except to swallow the
                # AttributeError. Guard the shape explicitly instead.
                if isinstance(ai_json, dict):
                    ai_text_response = ai_json.get("ai_response", ai_raw_content)
                    # Merge explicit product_ids from the AI JSON, deduping by
                    # SKU (dicts are unhashable, so no set() of products).
                    explicit_ids = ai_json.get("product_ids", [])
                    if explicit_ids and isinstance(explicit_ids, list):
                        seen_skus = {p["sku"] for p in all_product_ids if "sku" in p}
                        for product in explicit_ids:
                            if isinstance(product, dict) and product.get("sku") not in seen_skus:
                                all_product_ids.append(product)
                                seen_skus.add(product.get("sku"))
            except Exception as e:
                # Plain-text answers just keep the raw content.
                logger.warning(f"Could not parse AI response as JSON: {e}")

            # Persist the turn after the response has been sent.
            background_tasks.add_task(
                _handle_post_chat_async,
                memory=memory,
                user_id=user_id,
                human_query=query,
                ai_msg=AIMessage(content=ai_text_response),
            )
            return {
                "ai_response": ai_text_response,  # plain text only, never JSON
                "product_ids": all_product_ids,  # array of product objects
            }
    except Exception as e:
        logger.error(f"💥 Chat error for user {user_id}: {e}", exc_info=True)
        raise
def _extract_product_ids(messages: list) -> list[dict]:
    """
    Extract full product info from tool messages (data_retrieval_tool results).

    Returns:
        List of product objects with keys: sku, name, price, sale_price,
        url, thumbnail_image_url. Duplicate SKUs are dropped.
    """
    products = []
    seen_skus = set()
    for msg in messages:
        if not isinstance(msg, ToolMessage):
            continue
        try:
            # Tool result is a JSON string.
            tool_result = json.loads(msg.content)
            # FIX: json.loads may yield a list/str/number; calling .get() on
            # those raised AttributeError, which escaped the except below.
            if not isinstance(tool_result, dict):
                continue
            if tool_result.get("status") == "success" and "products" in tool_result:
                for product in tool_result["products"]:
                    sku = product.get("internal_ref_code")
                    if sku and sku not in seen_skus:
                        seen_skus.add(sku)
                        products.append(
                            {
                                "sku": sku,
                                "name": product.get("magento_product_name", ""),
                                "price": product.get("price_vnd", 0),
                                "sale_price": product.get("sale_price_vnd"),  # None when not on sale
                                "url": product.get("magento_url_key", ""),
                                "thumbnail_image_url": product.get("thumbnail_image_url", ""),
                            }
                        )
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            logger.debug(f"Could not parse tool message for products: {e}")
            continue
    return products
def _prepare_execution_context(query: str, user_id: str, history: list, images: list | None):
    """Build the initial graph state and the RunnableConfig for one run."""
    initial_state: AgentState = {
        "user_query": HumanMessage(content=query),
        "messages": [HumanMessage(content=query)],
        "history": history,
        "user_id": user_id,
        "images_embedding": [],
        "ai_response": None,
    }

    run_id = str(uuid.uuid4())
    # Tags/metadata are picked up by LangChain for logging & filtering.
    metadata = {
        "run_id": run_id,
        "tags": "chatbot,production",
    }

    # Langfuse callback handler — user_id propagation is handled by
    # langfuse_trace_context at invocation time (per Langfuse docs).
    handler = get_callback_handler()
    exec_config = RunnableConfig(
        configurable={
            "user_id": user_id,
            "transient_images": images or [],
            "run_id": run_id,
        },
        run_id=run_id,
        metadata=metadata,
        callbacks=[handler] if handler else [],
    )
    return initial_state, exec_config
async def _handle_post_chat_async(
    memory: ConversationManager, user_id: str, human_query: str, ai_msg: AIMessage | None
):
    """Save one conversation turn from a background task (after response)."""
    if not ai_msg:
        return
    try:
        await memory.save_conversation_turn(user_id, human_query, ai_msg.content)
        logger.debug(f"Saved conversation for user {user_id}")
    except Exception as e:
        # Best-effort persistence: log, never crash the background worker.
        logger.error(f"Failed to save conversation for user {user_id}: {e}", exc_info=True)
"""
Fashion Q&A Agent Graph
LangGraph workflow với clean architecture.
Tất cả resources (LLM, Tools) khởi tạo trong __init__.
Sử dụng ConversationManager (Postgres) để lưu history thay vì checkpoint.
"""
import logging
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableConfig
from langgraph.cache.memory import InMemoryCache
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.types import CachePolicy
from common.llm_factory import create_llm
from .models import AgentConfig, AgentState, get_config
from .prompt import get_system_prompt
from .tools.get_tools import get_all_tools, get_collection_tools
logger = logging.getLogger(__name__)
class CANIFAGraph:
    """
    Fashion Q&A Agent Graph Manager.

    Owns all long-lived resources (LLM, tools, prompt template, node cache)
    and compiles the LangGraph workflow once, on first access.
    """

    def __init__(
        self,
        config: AgentConfig | None = None,
        llm: BaseChatModel | None = None,
        tools: list | None = None,
    ):
        """Initialize all resources up front.

        Args:
            config: Agent configuration; defaults to `get_config()`.
            llm: Chat model; defaults to a streaming model from the factory.
            tools: Tool list; defaults to `get_all_tools()`.
        """
        self.config = config or get_config()
        self._compiled_graph: Any | None = None
        self.llm: BaseChatModel = llm or create_llm(
            model_name=self.config.model_name, api_key=self.config.openai_api_key, streaming=True
        )
        self.all_tools = tools or get_all_tools()
        # Collection tool names are still needed for routing decisions.
        self.collection_tools = get_collection_tools()
        self.retrieval_tools = self.all_tools
        self.llm_with_tools = self.llm.bind_tools(self.all_tools, strict=True)
        self.system_prompt = get_system_prompt()
        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                MessagesPlaceholder(variable_name="history"),
                MessagesPlaceholder(variable_name="user_query"),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        self.chain = self.prompt_template | self.llm_with_tools
        self.cache = InMemoryCache()

    async def _agent_node(self, state: AgentState, config: RunnableConfig) -> dict:
        """Agent node — pour the per-run data into the prebuilt prompt chain."""
        messages = state.get("messages", [])
        history = state.get("history", [])
        user_query = state.get("user_query")
        # FIX: removed a dead no-op branch
        # (`if transient_images and messages: pass`) that read
        # `transient_images` from the run config and did nothing with it.
        # TODO: inject transient images into the prompt when supported.
        response = await self.chain.ainvoke({
            "user_query": [user_query] if user_query else [],
            "history": history,
            "messages": messages
        })
        return {"messages": [response], "ai_response": response}

    def _should_continue(self, state: AgentState) -> str:
        """Routing: dispatch to a tool node or end the run."""
        last_message = state["messages"][-1]
        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            logger.info("🏁 Agent finished")
            return "end"
        tool_names = [tc["name"] for tc in last_message.tool_calls]
        collection_names = [t.name for t in self.collection_tools]
        if any(name in collection_names for name in tool_names):
            logger.info(f"🔄 → collect_tools: {tool_names}")
            return "collect_tools"
        logger.info(f"🔄 → retrieve_tools: {tool_names}")
        return "retrieve_tools"

    def build(self) -> Any:
        """Build and compile the LangGraph workflow (idempotent)."""
        if self._compiled_graph is not None:
            return self._compiled_graph
        workflow = StateGraph(AgentState)
        # Nodes — retrieval results are cached for an hour.
        workflow.add_node("agent", self._agent_node)
        workflow.add_node("retrieve_tools", ToolNode(self.retrieval_tools), cache_policy=CachePolicy(ttl=3600))
        workflow.add_node("collect_tools", ToolNode(self.collection_tools))
        # Edges
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent",
            self._should_continue,
            {"retrieve_tools": "retrieve_tools", "collect_tools": "collect_tools", "end": END},
        )
        workflow.add_edge("retrieve_tools", "agent")
        workflow.add_edge("collect_tools", "agent")
        # History lives in Postgres via ConversationManager — no checkpointer.
        self._compiled_graph = workflow.compile(cache=self.cache)
        logger.info("✅ Graph compiled (Langfuse callback will be per-run)")
        return self._compiled_graph

    @property
    def graph(self) -> Any:
        """Compiled graph, building it on first access."""
        return self.build()
# --- Singleton & Public API ---
_instance: list[CANIFAGraph | None] = [None]
def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None) -> Any:
    """Return the compiled graph, creating the singleton manager if needed."""
    return get_graph_manager(config, llm, tools).build()
def get_graph_manager(
    config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None
) -> CANIFAGraph:
    """Return the shared CANIFAGraph, constructing it on first call."""
    manager = _instance[0]
    if manager is None:
        manager = CANIFAGraph(config, llm, tools)
        _instance[0] = manager
    return manager
def reset_graph() -> None:
    """Clear the cached CANIFAGraph so tests get a fresh instance."""
    _instance[0] = None
"""
Agno Tools - Pure Python functions cho Agno Agent
Đã convert từ LangChain @tool decorator sang Agno format
"""
from .data_retrieval_tool import data_retrieval_tool
from .brand_knowledge_tool import canifa_knowledge_search
from .customer_info_tool import collect_customer_info
def get_agno_tools():
    """
    Return the tool functions exposed to the Agno Agent.

    Agno derives tool definitions automatically from plain Python
    functions, so no decorator wrapping is needed here.

    Returns:
        List of Python functions (Agno tools).
    """
    return [data_retrieval_tool, canifa_knowledge_search, collect_customer_info]
import logging import logging
from langchain_core.tools import tool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from common.embedding_service import create_embedding_async from common.embedding_service import create_embedding_async
...@@ -15,7 +14,6 @@ class KnowledgeSearchInput(BaseModel): ...@@ -15,7 +14,6 @@ class KnowledgeSearchInput(BaseModel):
) )
@tool("canifa_knowledge_search", args_schema=KnowledgeSearchInput)
async def canifa_knowledge_search(query: str) -> str: async def canifa_knowledge_search(query: str) -> str:
""" """
Tra cứu TOÀN BỘ thông tin về thương hiệu và dịch vụ của Canifa. Tra cứu TOÀN BỘ thông tin về thương hiệu và dịch vụ của Canifa.
...@@ -35,6 +33,10 @@ async def canifa_knowledge_search(query: str) -> str: ...@@ -35,6 +33,10 @@ async def canifa_knowledge_search(query: str) -> str:
- 'Cho mình xem bảng size áo nam.' - 'Cho mình xem bảng size áo nam.'
- 'Phí vận chuyển đi tỉnh là bao nhiêu?' - 'Phí vận chuyển đi tỉnh là bao nhiêu?'
- 'Canifa thành lập năm nào?' - 'Canifa thành lập năm nào?'
Args:
query: Câu hỏi hoặc nhu cầu tìm kiếm thông tin phi sản phẩm của khách hàng
(ví dụ: tìm cửa hàng, hỏi chính sách, tra bảng size...)
""" """
logger.info(f"🔍 [Semantic Search] Brand Knowledge query: {query}") logger.info(f"🔍 [Semantic Search] Brand Knowledge query: {query}")
......
...@@ -6,16 +6,14 @@ Dùng để đẩy data về CRM hoặc hệ thống lưu trữ khách hàng. ...@@ -6,16 +6,14 @@ Dùng để đẩy data về CRM hoặc hệ thống lưu trữ khách hàng.
import json import json
import logging import logging
from langchain_core.tools import tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@tool async def collect_customer_info(name: str, phone: str, email: str | None = None) -> str:
async def collect_customer_info(name: str, phone: str, email: str | None) -> str:
""" """
Sử dụng tool này để ghi lại thông tin khách hàng khi họ muốn tư vấn sâu hơn, Sử dụng tool này để ghi lại thông tin khách hàng khi họ muốn tư vấn sâu hơn,
nhận khuyến mãi hoặc đăng ký mua hàng. nhận khuyến mãi hoặc đăng ký mua hàng.
Args: Args:
name: Tên của khách hàng name: Tên của khách hàng
phone: Số điện thoại của khách hàng phone: Số điện thoại của khách hàng
......
...@@ -9,7 +9,6 @@ import logging ...@@ -9,7 +9,6 @@ import logging
import time import time
from decimal import Decimal from decimal import Decimal
from langchain_core.tools import tool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from agent.tools.product_search_helpers import build_starrocks_query, save_preview_to_log from agent.tools.product_search_helpers import build_starrocks_query, save_preview_to_log
...@@ -50,8 +49,6 @@ class MultiSearchParams(BaseModel): ...@@ -50,8 +49,6 @@ class MultiSearchParams(BaseModel):
searches: list[SearchItem] = Field(..., description="Danh sách các truy vấn tìm kiếm chạy song song") searches: list[SearchItem] = Field(..., description="Danh sách các truy vấn tìm kiếm chạy song song")
@tool(args_schema=MultiSearchParams)
# @traceable(run_type="tool", name="data_retrieval_tool")
async def data_retrieval_tool(searches: list[SearchItem]) -> str: async def data_retrieval_tool(searches: list[SearchItem]) -> str:
""" """
Siêu công cụ tìm kiếm sản phẩm CANIFA - Hỗ trợ Parallel Multi-Search (Chạy song song nhiều query). Siêu công cụ tìm kiếm sản phẩm CANIFA - Hỗ trợ Parallel Multi-Search (Chạy song song nhiều query).
...@@ -86,6 +83,14 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str: ...@@ -86,6 +83,14 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str:
{"query": "Quần jean nam slim fit năng động"}, {"query": "Quần jean nam slim fit năng động"},
{"query": "Áo khoác nam thể thao trẻ trung"} {"query": "Áo khoác nam thể thao trẻ trung"}
] ]
Args:
searches: Danh sách các truy vấn tìm kiếm chạy song song. Mỗi item là SearchItem với:
- query: Mô tả sản phẩm chi tiết (bắt buộc)
- magento_ref_code: Mã sản phẩm cụ thể (nếu có)
- price_min: Giá thấp nhất (nếu có)
- price_max: Giá cao nhất (nếu có)
- action: 'search' hoặc 'visual_search'
""" """
logger.info("🔧 [DEBUG] data_retrieval_tool STARTED") logger.info("🔧 [DEBUG] data_retrieval_tool STARTED")
try: try:
......
"""
CANIFA Data Retrieval Tool - Tối giản cho Agentic Workflow.
Hỗ trợ Hybrid Search: Semantic (Vector) + Metadata Filter.
"""
import asyncio
import json
import logging
import time
from decimal import Decimal
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from agent.tools.product_search_helpers import build_starrocks_query
from common.starrocks_connection import StarRocksConnection
# from langsmith import traceable
logger = logging.getLogger(__name__)
class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that serializes database `Decimal` values as floats."""

    def default(self, obj):
        # Only Decimal gets special treatment; defer everything else.
        if isinstance(obj, Decimal):
            return float(obj)
        return super().default(obj)
class SearchItem(BaseModel):
    """Cấu trúc một mục tìm kiếm đơn lẻ trong Multi-Search."""

    # NOTE(review): the docstring and every Field description below are
    # emitted into the LLM-facing tool schema at runtime, so their
    # (Vietnamese) text is kept verbatim.
    # NOTE(review): optional-typed fields deliberately use Field(...) —
    # i.e. required — presumably so the generated schema satisfies OpenAI
    # strict tool calling (see bind_tools(..., strict=True) elsewhere),
    # which marks every property as required; confirm before relaxing
    # any of them to default None.

    # Free-form user intent — powers the semantic (vector) leg of the search.
    query: str = Field(
        ...,
        description="Câu hỏi/mục đích tự do của user (đi chơi, dự tiệc, phỏng vấn,...) - dùng cho Semantic Search",
    )
    # Concrete product keywords — powers the SQL LIKE leg of the search.
    keywords: str | None = Field(
        ..., description="Từ khóa sản phẩm cụ thể (áo polo, quần jean,...) - dùng cho LIKE search"
    )
    # Style-level code or color/SKU-level code.
    magento_ref_code: str | None = Field(
        ..., description="Mã sản phẩm hoặc mã màu/SKU (Ví dụ: 8TS24W001 hoặc 8TS24W001-SK010)."
    )
    # Catalog metadata filters (exact/partial match).
    product_line_vn: str | None = Field(..., description="Dòng sản phẩm (Áo phông, Quần short,...)")
    gender_by_product: str | None = Field(..., description="Giới tính: male, female")
    age_by_product: str | None = Field(..., description="Độ tuổi: adult, kids, baby, others")
    master_color: str | None = Field(..., description="Màu sắc chính (Đen/ Black, Trắng/ White,...)")
    material_group: str | None = Field(
        ...,
        description="Nhóm chất liệu. BẮT BUỘC dùng đúng: 'Yarn - Sợi', 'Knit - Dệt Kim', 'Woven - Dệt Thoi', 'Knit/Woven - Dệt Kim/Dệt Thoi'.",
    )
    season: str | None = Field(..., description="Mùa (Spring Summer, Autumn Winter)")
    style: str | None = Field(..., description="Phong cách (Basic Update, Fashion,...)")
    fitting: str | None = Field(..., description="Form dáng (Regular, Slim, Loose,...)")
    form_neckline: str | None = Field(..., description="Kiểu cổ (Crew Neck, V-neck,...)")
    form_sleeve: str | None = Field(..., description="Kiểu tay (Short Sleeve, Long Sleeve,...)")
    # Price bounds in VND.
    price_min: float | None = Field(..., description="Giá thấp nhất")
    price_max: float | None = Field(..., description="Giá cao nhất")
    # Dispatch: plain text search vs. image-driven search.
    action: str = Field(..., description="Hành động: 'search' (tìm kiếm) hoặc 'visual_search' (phân tích ảnh)")
class MultiSearchParams(BaseModel):
    """Tham số cho Parallel Multi-Search."""

    # NOTE(review): docstring and description are emitted into the LLM tool
    # schema at runtime — kept verbatim (Vietnamese).
    # The list of independent search items executed concurrently.
    searches: list[SearchItem] = Field(..., description="Danh sách các truy vấn tìm kiếm chạy song song")
@tool(args_schema=MultiSearchParams)
# @traceable(run_type="tool", name="data_retrieval_tool")
async def data_retrieval_tool(searches: list[SearchItem]) -> str:
    """
    Siêu công cụ tìm kiếm sản phẩm CANIFA - Hỗ trợ Parallel Multi-Search (Chạy song song nhiều query).
    💡 ĐIỂM ĐẶC BIỆT:
    Công cụ này cho phép thực hiện NHIỀU truy vấn tìm kiếm CÙNG LÚC.
    Hãy dùng nó khi cần SO SÁNH sản phẩm hoặc tìm trọn bộ OUTFIT (mix & match).
    ⚠️ QUAN TRỌNG - KHI NÀO DÙNG GÌ:
    1️⃣ DÙNG 'query' (Semantic Search - BUỘC PHẢI CÓ):
    - Áp dụng cho mọi lượt search để cung cấp bối cảnh (context).
    - Ví dụ: "áo thun nam đi biển", "quần tây công sở", "đồ cho bé màu xanh"...
    2️⃣ DÙNG METADATA FILTERS (Exact/Partial Match):
    - Khi khách nói rõ THUỘC TÍNH: Màu sắc, giá, giới tính, độ tuổi, mã sản phẩm.
    - **QUY TẮC MÃ SẢN PHẨM:** Mọi loại mã (VD: `8TS...` hoặc `8TS...-SK...`) → Điền vào `magento_ref_code`.
    - **QUY TẮC CHẤT LIÊU (material_group):** Chỉ dùng: `Yarn - Sợi`, `Knit - Dệt Kim`, `Woven - Dệt Thoi`, `Knit/Woven - Dệt Kim/Dệt Thoi`.
    📝 VÍ DỤ CHI TIẾT (Single Search):
    - Example 1: searches=[{"query": "áo polo nam giá dưới 400k", "keywords": "áo polo", "gender_by_product": "male", "price_max": 400000}]
    - Example 2: searches=[{"query": "sản phẩm mã 8TS24W001", "magento_ref_code": "8TS24W001"}]
    🚀 VÍ DỤ CẤP CAO (Multi-Search Parallel):
    - Example 3 - So sánh: "So sánh áo thun nam đen và áo sơ mi trắng dưới 500k"
    Tool Call: searches=[
    {"query": "áo thun nam màu đen dưới 500k", "keywords": "áo thun", "master_color": "Đen", "gender_by_product": "male", "price_max": 500000},
    {"query": "áo sơ mi nam trắng dưới 500k", "keywords": "áo sơ mi", "master_color": "Trắng", "gender_by_product": "male", "price_max": 500000}
    ]
    - Example 4 - Phối đồ: "Tìm cho mình một cái quần jean và một cái áo khoác để đi chơi"
    Tool Call: searches=[
    {"query": "quần jean đi chơi năng động", "keywords": "quần jean"},
    {"query": "áo khoác đi chơi năng động", "keywords": "áo khoác"}
    ]
    - Example 5 - Cả gia đình: "Tìm áo phông màu xanh cho bố, mẹ và bé trai"
    Tool Call: searches=[
    {"query": "áo phông nam người lớn màu xanh", "keywords": "áo phông", "master_color": "Xanh", "gender_by_product": "male", "age_by_product": "adult"},
    {"query": "áo phông nữ người lớn màu xanh", "keywords": "áo phông", "master_color": "Xanh", "gender_by_product": "female", "age_by_product": "adult"},
    {"query": "áo phông bé trai màu xanh", "keywords": "áo phông", "master_color": "Xanh", "gender_by_product": "male", "age_by_product": "others"}
    ]
    """
    # NOTE: the docstring above is the LLM-facing tool description consumed
    # by the @tool decorator at runtime — it must stay exactly as written.
    logger.info("🔧 [DEBUG] data_retrieval_tool STARTED")
    try:
        logger.info("🔧 [DEBUG] Creating StarRocksConnection instance")
        db = StarRocksConnection()
        logger.info("🔧 [DEBUG] StarRocksConnection created successfully")
        # 0. Log the input parameters for observability.
        logger.info(f"📥 [Tool Input] data_retrieval_tool received {len(searches)} items:")
        for idx, item in enumerate(searches):
            logger.info(f" 🔹 Item [{idx}]: {item.dict(exclude_none=True)}")
        # 1. Build one coroutine per search item and run them in parallel.
        logger.info("🔧 [DEBUG] Creating parallel tasks")
        tasks = []
        for item in searches:
            tasks.append(_execute_single_search(db, item))
        logger.info(f"🚀 [Parallel Search] Executing {len(searches)} queries simultaneously...")
        logger.info("🔧 [DEBUG] About to call asyncio.gather()")
        results = await asyncio.gather(*tasks)
        logger.info(f"🔧 [DEBUG] asyncio.gather() completed with {len(results)} results")
        # 2. Aggregate per-search results, echoing the criteria used so the
        # agent can tell which products answer which sub-query.
        combined_results = []
        for i, products in enumerate(results):
            combined_results.append(
                {
                    "search_index": i,
                    "search_criteria": searches[i].dict(exclude_none=True),
                    "count": len(products),
                    "products": products,
                }
            )
        # DecimalEncoder converts DB Decimal values for JSON serialization.
        return json.dumps({"status": "success", "results": combined_results}, ensure_ascii=False, cls=DecimalEncoder)
    except Exception as e:
        logger.error(f"Error in Multi-Search data_retrieval_tool: {e}")
        return json.dumps({"status": "error", "message": str(e)})
async def _execute_single_search(db: StarRocksConnection, item: SearchItem) -> list[dict]:
    """Execute one search query asynchronously.

    Builds the StarRocks SQL for `item` (including an embedding round-trip
    when semantic search applies), runs it, and returns formatted product
    rows. Errors are logged and swallowed — returning [] — so one failing
    search cannot sink the whole parallel batch.
    """
    try:
        logger.info(f"🔧 [DEBUG] _execute_single_search STARTED for query: {item.query[:50] if item.query else 'None'}")
        # Timer: query build (includes the embedding call, if any).
        query_build_start = time.time()
        logger.info("🔧 [DEBUG] Calling build_starrocks_query()")
        sql = await build_starrocks_query(item)
        query_build_time = (time.time() - query_build_start) * 1000  # Convert to ms
        logger.info(f"🔧 [DEBUG] SQL query built, length: {len(sql)}")
        logger.info(f"⏱️ [TIMER] Query Build Time (bao gồm embedding): {query_build_time:.2f}ms")
        # Timer: database execution.
        db_start = time.time()
        logger.info("🔧 [DEBUG] Calling db.execute_query_async()")
        products = await db.execute_query_async(sql)
        db_time = (time.time() - db_start) * 1000  # Convert to ms
        logger.info(f"🔧 [DEBUG] Query executed, got {len(products)} products")
        logger.info(f"⏱️ [TIMER] DB Query Execution Time: {db_time:.2f}ms")
        logger.info(f"⏱️ [TIMER] Total Time (Build + DB): {query_build_time + db_time:.2f}ms")
        # Trim rows down to the fields the agent is allowed to see.
        return _format_product_results(products)
    except Exception as e:
        logger.error(f"Single search error for item {item}: {e}")
        return []
def _format_product_results(products: list[dict]) -> list[dict]:
"""Lọc và format kết quả trả về cho Agent."""
allowed_fields = {
"internal_ref_code",
"description_text_full",
}
return [{k: v for k, v in p.items() if k in allowed_fields} for p in products[:5]]
# import logging
# from common.embedding_service import create_embedding_async
# logger = logging.getLogger(__name__)
# def _escape(val: str) -> str:
# """Thoát dấu nháy đơn để tránh SQL Injection cơ bản."""
# return val.replace("'", "''")
# def _get_where_clauses(params) -> list[str]:
# """
# Xây dựng WHERE clauses theo thứ tự ưu tiên dựa trên selectivity thực tế
# FILTER PRIORITY (Based on Canifa catalog analysis):
# 🔥 TIER 1 (99% selectivity):
# 1. SKU Code → 1-5 records
# 🎯 TIER 2 (50-70% selectivity):
# 2. Gender → Splits catalog in half
# 3. Age → Kids vs Adults split
# 4. Product Category → 10-15 categories
# 💎 TIER 3 (30-50% selectivity):
# 5. Material Group → Knit vs Woven (2 groups)
# 6. Price Range → Numeric filtering
# 🎨 TIER 4 (10-30% selectivity):
# 7. Season → 4 seasons
# 8. Style/Fitting → Multiple options
# ⚠️ TIER 5 (<10% selectivity):
# 9. Form details → Granular attributes
# 10. Color → LOWEST selectivity (many SKUs share colors)
# Early return: If SKU exists, skip low-selectivity filters
# """
# clauses = []
# # 🔥 TIER 1: SKU/Product Code (Unique identifier)
# # Selectivity: ~99% → 1 SKU = 1 style (3-5 colors max)
# sku_clause = _get_sku_clause(params)
# if sku_clause:
# clauses.append(sku_clause)
# # Early return optimization: SKU đã xác định product rõ ràng
# # CHỈ GIỮ LẠI price filter (nếu có) để verify budget constraint
# # BỎ QUA: gender, color, style, fitting... vì SKU đã unique
# price_clauses = _get_price_clauses(params)
# if price_clauses:
# clauses.extend(price_clauses)
# return clauses # ⚡ STOP - Không thêm filter khác!
# # 🎯 TIER 2: High-level categorization (50-70% reduction)
# # Gender + Age + Category có selectivity cao nhất trong non-SKU filters
# clauses.extend(_get_high_selectivity_clauses(params))
# # 💎 TIER 3: Material & Price (30-50% reduction)
# material_clause = _get_material_clause(params)
# if material_clause:
# clauses.append(material_clause)
# clauses.extend(_get_price_clauses(params))
# # 🎨 TIER 4: Attributes (10-30% reduction)
# clauses.extend(_get_attribute_clauses(params))
# # ⚠️ TIER 5: Granular details & Color (LAST - lowest selectivity)
# clauses.extend(_get_form_detail_clauses(params))
# color_clause = _get_color_clause(params)
# if color_clause:
# clauses.append(color_clause) # Color ALWAYS LAST!
# return clauses
# def _get_sku_clause(params) -> str | None:
# """
# TIER 1: SKU/Product Code (Highest selectivity - 99%)
# 1 SKU code = 1 product style (may have 3-5 color variants)
# WHY SKU is always priority #1:
# - 1 code = 1 unique product design
# - Adding other filters (color, style, gender) is redundant
# - Only price filter may be kept for budget validation
# Example queries:
# - "Mã 6OT25W010" → Only SKU needed
# - "Mã 6OT25W010 màu xám" → Only SKU (color is for display/selection, not filtering)
# - "Mã 6OT25W010 dưới 500k" → SKU + price (validate budget)
# """
# m_code = getattr(params, "magento_ref_code", None)
# if m_code:
# m = _escape(m_code)
# return f"(magento_ref_code = '{m}' OR internal_ref_code = '{m}')"
# return None
# def _get_color_clause(params) -> str | None:
# """
# TIER 5: Color (LOWEST selectivity - 5-10%)
# Multiple SKUs share the same color (e.g., 50+ gray products)
# ALWAYS filter color LAST after other constraints
# """
# color = getattr(params, "master_color", None)
# if color:
# c = _escape(color).lower()
# return f"(LOWER(master_color) LIKE '%{c}%' OR LOWER(product_color_name) LIKE '%{c}%')"
# return None
# def _get_high_selectivity_clauses(params) -> list[str]:
# """
# TIER 2: High-level categorization (50-70% reduction per filter)
# Order: Gender → Age → Product Category
# """
# clauses = []
# # Gender: Male/Female/Unisex split (50-70% reduction)
# gender = getattr(params, "gender_by_product", None)
# if gender:
# clauses.append(f"gender_by_product = '{_escape(gender)}'")
# # Age: Kids/Adults split (50% reduction of remaining)
# age = getattr(params, "age_by_product", None)
# if age:
# clauses.append(f"age_by_product = '{_escape(age)}'")
# # Product Category: Váy/Áo/Quần... (30-50% reduction)
# product_line = getattr(params, "product_line_vn", None)
# if product_line:
# p = _escape(product_line).lower()
# clauses.append(f"LOWER(product_line_vn) LIKE '%{p}%'")
# return clauses
# def _get_material_clause(params) -> str | None:
# """TIER 3: Material Group - Knit vs Woven (50% split)"""
# material = getattr(params, "material_group", None)
# if material:
# m = _escape(material).lower()
# return f"LOWER(material_group) LIKE '%{m}%'"
# return None
# def _get_price_clauses(params) -> list[str]:
# """TIER 3: Price Range - Numeric filtering (30-40% reduction)"""
# clauses = []
# p_min = getattr(params, "price_min", None)
# if p_min is not None:
# clauses.append(f"sale_price >= {p_min}")
# p_max = getattr(params, "price_max", None)
# if p_max is not None:
# clauses.append(f"sale_price <= {p_max}")
# return clauses
# def _get_attribute_clauses(params) -> list[str]:
# """
# TIER 4: Attributes (10-30% reduction)
# Season, Style, Fitting
# """
# clauses = []
# # Season: 4 seasons (~25% each)
# season = getattr(params, "season", None)
# if season:
# s = _escape(season).lower()
# clauses.append(f"LOWER(season) LIKE '%{s}%'")
# # Style: Basic/Feminine/Sporty... (~15-20% reduction)
# style = getattr(params, "style", None)
# if style:
# st = _escape(style).lower()
# clauses.append(f"LOWER(style) LIKE '%{st}%'")
# # Fitting: Regular/Slim/Loose (~15% reduction)
# fitting = getattr(params, "fitting", None)
# if fitting:
# f = _escape(fitting).lower()
# clauses.append(f"LOWER(fitting) LIKE '%{f}%'")
# # Size Scale: S, M, L, 29, 30... (Specific filtering)
# size = getattr(params, "size_scale", None)
# if size:
# sz = _escape(size).lower()
# clauses.append(f"LOWER(size_scale) LIKE '%{sz}%'")
# return clauses
# def _get_form_detail_clauses(params) -> list[str]:
# """
# TIER 5: Granular form details (<10% reduction each)
# Neckline, Sleeve type
# """
# clauses = []
# form_fields = [
# ("form_neckline", "form_neckline"),
# ("form_sleeve", "form_sleeve"),
# ]
# for param_name, col_name in form_fields:
# val = getattr(params, param_name, None)
# if val:
# v = _escape(val).lower()
# clauses.append(f"LOWER({col_name}) LIKE '%{v}%'")
# return clauses
# async def build_starrocks_query(params, query_vector: list[float] | None = None) -> str:
# """
# Build SQL Hybrid tối ưu với Filter Priority:
# 1. Pre-filtering theo độ ưu tiên (SKU → Exact → Price → Partial)
# 2. Vector Search (HNSW Index) - Semantic understanding
# 3. Flexible Keyword Search (OR + Scoring) - Fuzzy matching fallback
# 4. Grouping (Gom màu theo style)
# """
# # --- Process vector in query field ---
# query_text = getattr(params, "query", None)
# # if query_text and query_vector is None:
# # query_vector = await create_embedding_async(query_text)
# # --- Build filter clauses (OPTIMIZED ORDER) ---
# where_clauses = _get_where_clauses(params)
# where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
# # --- Build SQL ---
# if query_vector and len(query_vector) > 0:
# v_str = "[" + ",".join(str(v) for v in query_vector) + "]"
# sql = f"""
# WITH top_sku_candidates AS (
# SELECT
# approx_cosine_similarity(vector, {v_str}) as similarity_score,
# internal_ref_code,
# product_name,
# sale_price,
# original_price,
# master_color,
# product_image_url,
# product_image_url_thumbnail,
# product_web_url,
# description_text,
# material,
# material_group,
# gender_by_product,
# age_by_product,
# season,
# style,
# fitting,
# form_neckline,
# form_sleeve,
# product_line_vn,
# product_color_name
# FROM shared_source.magento_product_dimension_with_text_embedding
# WHERE {where_sql} AND vector IS NOT NULL
# ORDER BY similarity_score DESC
# LIMIT 50
# )
# SELECT
# internal_ref_code,
# ANY_VALUE(product_name) as product_name,
# ANY_VALUE(sale_price) as sale_price,
# ANY_VALUE(original_price) as original_price,
# GROUP_CONCAT(DISTINCT master_color ORDER BY master_color SEPARATOR ', ') as available_colors,
# ANY_VALUE(product_image_url) as product_image_url,
# ANY_VALUE(product_image_url_thumbnail) as product_image_url_thumbnail,
# ANY_VALUE(product_web_url) as product_web_url,
# ANY_VALUE(description_text) as description_text,
# ANY_VALUE(material) as material,
# ANY_VALUE(material_group) as material_group,
# ANY_VALUE(gender_by_product) as gender_by_product,
# ANY_VALUE(age_by_product) as age_by_product,
# ANY_VALUE(season) as season,
# ANY_VALUE(style) as style,
# ANY_VALUE(fitting) as fitting,
# ANY_VALUE(form_neckline) as form_neckline,
# ANY_VALUE(form_sleeve) as form_sleeve,
# ANY_VALUE(product_line_vn) as product_line_vn,
# MAX(similarity_score) as max_score
# FROM top_sku_candidates
# GROUP BY internal_ref_code
# ORDER BY max_score DESC
# LIMIT 10
# """
# else:
# # ⚡ FALLBACK: FLEXIBLE KEYWORD SEARCH (OR + SCORING)
# # Giải quyết case: User search "áo khoác nỉ" → DB có "Áo nỉ nam"
# keywords = getattr(params, "keywords", None)
# keyword_score_sql = ""
# keyword_filter = ""
# if keywords:
# k_clean = _escape(keywords).lower().strip()
# if k_clean:
# words = k_clean.split()
# # Build scoring expression: Each matched word = +1 point
# # Example: "áo khoác nỉ" (3 words)
# # - "Áo nỉ nam" matches 2/3 → Score = 2
# # - "Áo khoác nỉ hoodie" matches 3/3 → Score = 3
# score_terms = [
# f"(CASE WHEN LOWER(product_name) LIKE '%{w}%' THEN 1 ELSE 0 END)"
# for w in words
# ]
# keyword_score_sql = f"({' + '.join(score_terms)}) as keyword_match_score"
# # Minimum threshold: At least 50% of words must match
# # Example: 3 words → need at least 2 matches (66%)
# # 2 words → need at least 1 match (50%)
# min_matches = max(1, len(words) // 2)
# keyword_filter = f" AND ({' + '.join(score_terms)}) >= {min_matches}"
# # Select clause with optional scoring
# select_score = f", {keyword_score_sql}" if keyword_score_sql else ""
# order_by = "keyword_match_score DESC, sale_price ASC" if keyword_score_sql else "sale_price ASC"
# sql = f"""
# SELECT
# internal_ref_code,
# ANY_VALUE(product_name) as product_name,
# ANY_VALUE(sale_price) as sale_price,
# ANY_VALUE(original_price) as original_price,
# GROUP_CONCAT(DISTINCT master_color ORDER BY master_color SEPARATOR ', ') as available_colors,
# ANY_VALUE(product_image_url) as product_image_url,
# ANY_VALUE(product_image_url_thumbnail) as product_image_url_thumbnail,
# ANY_VALUE(product_web_url) as product_web_url,
# ANY_VALUE(description_text) as description_text,
# ANY_VALUE(material) as material,
# ANY_VALUE(material_group) as material_group,
# ANY_VALUE(gender_by_product) as gender_by_product,
# ANY_VALUE(age_by_product) as age_by_product,
# ANY_VALUE(season) as season,
# ANY_VALUE(style) as style,
# ANY_VALUE(fitting) as fitting,
# ANY_VALUE(form_neckline) as form_neckline,
# ANY_VALUE(form_sleeve) as form_sleeve,
# ANY_VALUE(product_line_vn) as product_line_vn
# {select_score}
# FROM shared_source.magento_product_dimension_with_text_embedding
# WHERE {where_sql} {keyword_filter}
# GROUP BY internal_ref_code
# HAVING COUNT(*) > 0
# ORDER BY {order_by}
# LIMIT 10
# """
# # Log filter statistics
# filter_info = f"Mode: {'Vector' if query_vector else 'Keyword'}, Filters: {len(where_clauses)}"
# if where_clauses:
# # Identify high-priority filters used
# has_sku = any('internal_ref_code' in c or 'magento_ref_code' in c for c in where_clauses)
# has_gender = any('gender_by_product' in c for c in where_clauses)
# has_category = any('product_line_vn' in c for c in where_clauses)
# priority_info = []
# if has_sku:
# priority_info.append("SKU")
# if has_gender:
# priority_info.append("Gender")
# if has_category:
# priority_info.append("Category")
# if priority_info:
# filter_info += f", Priority: {'+'.join(priority_info)}"
# logger.info(f"📊 {filter_info}")
# # Write SQL to file for debugging
# try:
# with open(r"d:\cnf\chatbot_canifa\backend\embedding.txt", "w", encoding="utf-8") as f:
# f.write(sql)
# except Exception as e:
# logger.error(f"Failed to write SQL to embedding.txt: {e}")
# return sql
import logging
import time
from common.embedding_service import create_embedding_async
logger = logging.getLogger(__name__)
def _escape(val: str) -> str:
"""Thoát dấu nháy đơn để tránh SQL Injection cơ bản."""
return val.replace("'", "''")
def _get_where_clauses(params) -> list[str]:
    """Collect every WHERE-clause fragment derived from the search params."""
    return [
        *_get_price_clauses(params),
        *_get_metadata_clauses(params),
        *_get_special_clauses(params),
    ]
def _get_price_clauses(params) -> list[str]:
"""Lọc theo giá."""
clauses = []
p_min = getattr(params, "price_min", None)
if p_min is not None:
clauses.append(f"sale_price >= {p_min}")
p_max = getattr(params, "price_max", None)
if p_max is not None:
clauses.append(f"sale_price <= {p_max}")
return clauses
def _get_metadata_clauses(params) -> list[str]:
"""Xây dựng điều kiện lọc từ metadata (Phối hợp Exact và Partial)."""
clauses = []
# 1. Exact Match (Giới tính, Độ tuổi) - Các trường này cần độ chính xác tuyệt đối
exact_fields = [
("gender_by_product", "gender_by_product"),
("age_by_product", "age_by_product"),
]
for param_name, col_name in exact_fields:
val = getattr(params, param_name, None)
if val:
clauses.append(f"{col_name} = '{_escape(val)}'")
# 2. Partial Match (LIKE) - Giúp map text linh hoạt hơn (Chất liệu, Dòng SP, Phong cách...)
# Cái này giúp map: "Yarn" -> "Yarn - Sợi", "Knit" -> "Knit - Dệt Kim"
partial_fields = [
("season", "season"),
("material_group", "material_group"),
("product_line_vn", "product_line_vn"),
("style", "style"),
("fitting", "fitting"),
("form_neckline", "form_neckline"),
("form_sleeve", "form_sleeve"),
]
for param_name, col_name in partial_fields:
val = getattr(params, param_name, None)
if val:
v = _escape(val).lower()
# Dùng LOWER + LIKE để cân mọi loại ký tự thừa hoặc hoa/thường
clauses.append(f"LOWER({col_name}) LIKE '%{v}%'")
return clauses
def _get_special_clauses(params) -> list[str]:
"""Các trường hợp đặc biệt: Mã sản phẩm, Màu sắc."""
clauses = []
# Mã sản phẩm / SKU
m_code = getattr(params, "magento_ref_code", None)
if m_code:
m = _escape(m_code)
clauses.append(f"(magento_ref_code = '{m}' OR internal_ref_code = '{m}')")
# Màu sắc
color = getattr(params, "master_color", None)
if color:
c = _escape(color).lower()
clauses.append(f"(LOWER(master_color) LIKE '%{c}%' OR LOWER(product_color_name) LIKE '%{c}%')")
return clauses
async def build_starrocks_query(params, query_vector: list[float] | None = None) -> str:
    """Build the hybrid StarRocks SQL using a POST-FILTERING strategy.

    Strategy:
      1. Vector search first (LIMIT 100) to exploit the HNSW index.
      2. JOIN back on (code + colour) to avoid row explosion across colour variants.
      3. MAX_BY keeps the description belonging to the highest-scoring row.

    Args:
        params: Search parameter object; attributes are read via getattr
            (query, keywords, price bounds, metadata fields, ...).
        query_vector: Pre-computed embedding. When None and ``params.query``
            is set, an embedding is generated here via create_embedding_async.

    Returns:
        A complete SQL string — vector-search mode when an embedding is
        available, otherwise a keyword-LIKE fallback.
    """
    logger.info("🔧 [DEBUG] build_starrocks_query STARTED")
    # --- 1. Resolve the query vector (generate an embedding on demand) ---
    query_text = getattr(params, "query", None)
    if query_text and query_vector is None:
        emb_start = time.time()
        query_vector = await create_embedding_async(query_text)
        emb_time = (time.time() - emb_start) * 1000
        logger.info(f"⏱️ [TIMER] Embedding Generation: {emb_time:.2f}ms")
    # --- 2. Build filter clauses for POST-FILTERING ---
    # NOTE(review): price bounds are interpolated into the SQL unescaped;
    # assumes price_min/price_max are validated as numeric upstream — confirm.
    where_clauses = _get_where_clauses(params)
    post_filter_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
    # --- 3. Build the SQL ---
    if query_vector and len(query_vector) > 0:
        v_str = "[" + ",".join(str(v) for v in query_vector) + "]"
        # Prefix filtered columns with the t2 alias to avoid "ambiguous column"
        # errors in the JOIN below.
        # NOTE(review): plain str.replace also rewrites matches *inside quoted
        # literal values* (e.g. a user-supplied colour containing a column
        # name), which would corrupt the SQL — consider an identifier-aware
        # (word-boundary) substitution instead.
        post_filter_aliased = post_filter_sql
        fields_to_alias = [
            "sale_price",
            "gender_by_product",
            "age_by_product",
            "material_group",
            "season",
            "style",
            "fitting",
            "form_neckline",
            "form_sleeve",
            "product_line_vn",
            "magento_ref_code",
            "internal_ref_code",
            "master_color",
            "product_color_name",
        ]
        for field in fields_to_alias:
            post_filter_aliased = post_filter_aliased.replace(field, f"t2.{field}")
        sql = f"""
    WITH top_candidates AS (
        SELECT /*+ SET_VAR(ann_params='{{"ef_search":64}}') */
            internal_ref_code,
            product_color_code,
            approx_cosine_similarity(vector, {v_str}) as similarity_score
        FROM shared_source.magento_product_dimension_with_text_embedding
        WHERE vector IS NOT NULL
        ORDER BY similarity_score DESC
        LIMIT 100
    )
    SELECT
        t1.internal_ref_code,
        -- MAX_BY đảm bảo mô tả đi kèm đúng với thằng cao điểm nhất (Data Integrity)
        MAX_BY(t2.description_text_full, t1.similarity_score) as description_text_full,
        MAX(t1.similarity_score) as max_score
    FROM top_candidates t1
    JOIN shared_source.magento_product_dimension_with_text_embedding t2
        ON t1.internal_ref_code = t2.internal_ref_code
        AND t1.product_color_code = t2.product_color_code -- QUAN TRỌNG: Tránh nhân bản dòng theo màu
    WHERE {post_filter_aliased}
    GROUP BY t1.internal_ref_code
    ORDER BY max_score DESC
    LIMIT 10
    """
    else:
        # FALLBACK: keyword search — single LIKE over product_name.
        keywords = getattr(params, "keywords", None)
        k_filter = ""
        if keywords:
            k = _escape(keywords).lower()
            k_filter = f" AND LOWER(product_name) LIKE '%{k}%'"
        where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
        sql = f"""
    SELECT
        internal_ref_code,
        -- Lấy đại diện 1 mô tả cho keyword search
        MAX(description_text_full) as description_text_full,
        MIN(sale_price) as min_price
    FROM shared_source.magento_product_dimension_with_text_embedding
    WHERE {where_sql} {k_filter}
    GROUP BY internal_ref_code
    ORDER BY min_price ASC
    LIMIT 10
    """
    # --- 4. Write the generated SQL to a debug file (best effort) ---
    # NOTE(review): hard-coded absolute Windows path — fails (silently, logged)
    # on any other machine; consider making this configurable or dev-only.
    try:
        debug_path = r"d:\cnf\chatbot_canifa\backend\query.txt"
        with open(debug_path, "w", encoding="utf-8") as f:
            f.write(sql)
        logger.info(f"💾 SQL saved to: {debug_path}")
    except Exception as e:
        logger.error(f"Save log failed: {e}")
    return sql
...@@ -9,7 +9,7 @@ import logging ...@@ -9,7 +9,7 @@ import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException from fastapi import APIRouter, BackgroundTasks, HTTPException
from opentelemetry import trace from opentelemetry import trace
from agent.controller import chat_controller from agent.agno_controller import chat_controller
from agent.models import QueryRequest from agent.models import QueryRequest
from config import DEFAULT_MODEL from config import DEFAULT_MODEL
......
...@@ -2,54 +2,106 @@ import logging ...@@ -2,54 +2,106 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from psycopg import sql
from psycopg_pool import AsyncConnectionPool from psycopg_pool import AsyncConnectionPool
from config import CHECKPOINT_POSTGRES_URL from config import CHECKPOINT_POSTGRES_URL
# Runtime imports with fallback
try:
from agno.db.base import BaseDb
from agno.models import Message # type: ignore[import-untyped]
except ImportError:
# Create stub class if agno not installed
class BaseDbStub: # type: ignore
pass
# Create a simple Message-like class for when Agno is not available
class MessageStub: # type: ignore
def __init__(self, role: str, content: str, created_at: Any = None):
self.role = role
self.content = content
self.created_at = created_at
BaseDb = BaseDbStub # type: ignore
Message = MessageStub # type: ignore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ConversationManager: # Use composition instead of inheritance to avoid implementing all BaseDb methods
class ConversationManager: # Don't inherit BaseDb directly
"""
Conversation Manager với Agno BaseDb interface.
Hỗ trợ cả legacy methods và Agno Agent.
"""
def __init__( def __init__(
self, self,
connection_url: str = CHECKPOINT_POSTGRES_URL, connection_url: str | None = None,
table_name: str = "langgraph_chat_histories", table_name: str = "langgraph_chat_histories",
): ):
self.connection_url = connection_url self.connection_url: str = connection_url or CHECKPOINT_POSTGRES_URL or ""
if not self.connection_url:
raise ValueError("connection_url is required")
self.table_name = table_name self.table_name = table_name
self._pool: AsyncConnectionPool | None = None self._pool: AsyncConnectionPool | None = None
async def _get_pool(self) -> AsyncConnectionPool: async def _get_pool(self) -> AsyncConnectionPool:
"""Get or create async connection pool.""" """Get or create async connection pool với config hợp lý."""
if self._pool is None: if self._pool is None:
self._pool = AsyncConnectionPool(self.connection_url, open=False) # Pool config: min_size=1, max_size=5, timeout=10s
self._pool = AsyncConnectionPool(
self.connection_url,
min_size=1,
max_size=5,
timeout=10.0, # 10s timeout thay vì default 30s
open=False,
)
try:
await self._pool.open() await self._pool.open()
logger.info(f"✅ PostgreSQL connection pool opened: {self.connection_url.split('@')[-1] if '@' in self.connection_url else '***'}")
except Exception as e:
logger.error(f"❌ Failed to open PostgreSQL pool: {e}")
self._pool = None
raise
return self._pool return self._pool
async def initialize_table(self): async def initialize_table(self):
"""Create the chat history table if it doesn't exist""" """Create the chat history table if it doesn't exist"""
try: try:
logger.info(f"🔌 Initializing PostgreSQL table: {self.table_name}")
pool = await self._get_pool() pool = await self._get_pool()
async with pool.connection() as conn:
# Use connection với timeout ngắn hơn
async with pool.connection(timeout=5.0) as conn: # 5s timeout cho connection
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute(f""" await cursor.execute(
CREATE TABLE IF NOT EXISTS {self.table_name} ( sql.SQL("""
CREATE TABLE IF NOT EXISTS {} (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
user_id VARCHAR(255) NOT NULL, user_id VARCHAR(255) NOT NULL,
message TEXT NOT NULL, message TEXT NOT NULL,
is_human BOOLEAN NOT NULL, is_human BOOLEAN NOT NULL,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) )
""") """).format(sql.Identifier(self.table_name))
)
await cursor.execute(f""" await cursor.execute(
CREATE INDEX IF NOT EXISTS idx_{self.table_name}_user_timestamp sql.SQL("""
ON {self.table_name} (user_id, timestamp) CREATE INDEX IF NOT EXISTS idx_{}_user_timestamp
""") ON {} (user_id, timestamp)
""").format(
sql.Identifier(self.table_name),
sql.Identifier(self.table_name),
)
)
await conn.commit() await conn.commit()
logger.info(f"Table {self.table_name} initialized successfully") logger.info(f"Table {self.table_name} initialized successfully")
except Exception as e: except Exception as e:
logger.error(f"Error initializing table: {e}") logger.error(f"❌ Error initializing table: {e}")
logger.error(f" Connection URL: {self.connection_url.split('@')[-1] if '@' in self.connection_url else '***'}")
raise raise
async def save_conversation_turn(self, user_id: str, human_message: str, ai_message: str): async def save_conversation_turn(self, user_id: str, human_message: str, ai_message: str):
...@@ -60,8 +112,10 @@ class ConversationManager: ...@@ -60,8 +112,10 @@ class ConversationManager:
async with pool.connection() as conn: async with pool.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
f"""INSERT INTO {self.table_name} (user_id, message, is_human, timestamp) sql.SQL("""
VALUES (%s, %s, %s, %s), (%s, %s, %s, %s)""", INSERT INTO {} (user_id, message, is_human, timestamp)
VALUES (%s, %s, %s, %s), (%s, %s, %s, %s)
""").format(sql.Identifier(self.table_name)),
( (
user_id, user_id,
human_message, human_message,
...@@ -84,23 +138,25 @@ class ConversationManager: ...@@ -84,23 +138,25 @@ class ConversationManager:
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Retrieve chat history for a user using cursor-based pagination.""" """Retrieve chat history for a user using cursor-based pagination."""
try: try:
query = f""" base_query = sql.SQL("SELECT message, is_human, timestamp, id FROM {} WHERE user_id = %s").format(
SELECT message, is_human, timestamp, id sql.Identifier(self.table_name)
FROM {self.table_name} )
WHERE user_id = %s params: list[Any] = [user_id]
"""
params = [user_id] query_parts: list[sql.Composable] = [base_query]
if before_id: if before_id:
query += " AND id < %s" query_parts.append(sql.SQL(" AND id < %s"))
params.append(before_id) params.append(before_id)
query += " ORDER BY id DESC" query_parts.append(sql.SQL(" ORDER BY id DESC"))
if limit: if limit:
query += " LIMIT %s" query_parts.append(sql.SQL(" LIMIT %s"))
params.append(limit) params.append(limit)
query = sql.Composed(query_parts)
pool = await self._get_pool() pool = await self._get_pool()
async with pool.connection() as conn, conn.cursor() as cursor: async with pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(query, tuple(params)) await cursor.execute(query, tuple(params))
...@@ -125,7 +181,12 @@ class ConversationManager: ...@@ -125,7 +181,12 @@ class ConversationManager:
pool = await self._get_pool() pool = await self._get_pool()
async with pool.connection() as conn: async with pool.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute(f"DELETE FROM {self.table_name} WHERE user_id = %s", (user_id,)) await cursor.execute(
sql.SQL("DELETE FROM {} WHERE user_id = %s").format(
sql.Identifier(self.table_name)
),
(user_id,),
)
await conn.commit() await conn.commit()
logger.info(f"Cleared chat history for user {user_id}") logger.info(f"Cleared chat history for user {user_id}")
except Exception as e: except Exception as e:
...@@ -136,7 +197,11 @@ class ConversationManager: ...@@ -136,7 +197,11 @@ class ConversationManager:
try: try:
pool = await self._get_pool() pool = await self._get_pool()
async with pool.connection() as conn, conn.cursor() as cursor: async with pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(f"SELECT COUNT(DISTINCT user_id) FROM {self.table_name}") await cursor.execute(
sql.SQL("SELECT COUNT(DISTINCT user_id) FROM {}").format(
sql.Identifier(self.table_name)
)
)
result = await cursor.fetchone() result = await cursor.fetchone()
return result[0] if result else 0 return result[0] if result else 0
except Exception as e: except Exception as e:
...@@ -147,6 +212,132 @@ class ConversationManager: ...@@ -147,6 +212,132 @@ class ConversationManager:
"""Close the connection pool""" """Close the connection pool"""
if self._pool: if self._pool:
await self._pool.close() await self._pool.close()
self._pool = None
# ========== Agno BaseDb Interface Methods ==========
# Giữ nguyên methods cũ ở trên để backward compatible
async def initialize(self):
"""Agno interface: Initialize table (alias của initialize_table)"""
return await self.initialize_table()
async def load_history(self, session_id: str, limit: int = 20) -> list[Any]:
"""
Agno interface: Load history và convert sang Agno Message format.
Reuse code từ get_chat_history().
Args:
session_id: User ID (Agno dùng session_id, map với user_id)
limit: Số messages tối đa
Returns:
List of Agno Message objects
"""
try:
# Reuse method cũ
history_dicts = await self.get_chat_history(user_id=session_id, limit=limit)
# Convert từ DB format → Agno Message format
messages = []
for h in reversed(history_dicts): # Reverse để chronological order
role = "user" if h["is_human"] else "assistant"
agno_message = Message(
role=role,
content=h["message"],
created_at=h["timestamp"],
)
messages.append(agno_message)
logger.debug(f"📥 [Agno] Loaded {len(messages)} messages for session {session_id}")
return messages
except Exception as e:
logger.error(f"❌ [Agno] Error loading history for {session_id}: {e}")
return []
async def save_message(self, session_id: str, message: Any):
"""
Agno interface: Save single message.
Args:
session_id: User ID
message: Agno Message object
"""
try:
pool = await self._get_pool()
is_human = message.role == "user"
async with pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
sql.SQL("""
INSERT INTO {} (user_id, message, is_human, timestamp)
VALUES (%s, %s, %s, %s)
""").format(sql.Identifier(self.table_name)),
(
session_id,
message.content,
is_human,
message.created_at or datetime.now(),
),
)
await conn.commit()
logger.debug(f"💾 [Agno] Saved message for session {session_id}")
except Exception as e:
logger.error(f"❌ [Agno] Error saving message for {session_id}: {e}", exc_info=True)
raise
async def save_session(self, session_id: str, messages: list[Any]):
"""
Agno interface: Save multiple messages (batch).
Args:
session_id: User ID
messages: List of Agno Message objects
"""
try:
pool = await self._get_pool()
timestamp = datetime.now()
async with pool.connection() as conn:
async with conn.cursor() as cursor:
# Batch insert
values = []
for msg in messages:
is_human = msg.role == "user"
values.append(
(
session_id,
msg.content,
is_human,
msg.created_at or timestamp,
)
)
await cursor.executemany(
sql.SQL("""
INSERT INTO {} (user_id, message, is_human, timestamp)
VALUES (%s, %s, %s, %s)
""").format(sql.Identifier(self.table_name)),
values,
)
await conn.commit()
logger.debug(f"💾 [Agno] Saved {len(messages)} messages for session {session_id}")
except Exception as e:
logger.error(f"❌ [Agno] Error saving session for {session_id}: {e}", exc_info=True)
raise
async def get_session_messages(self, session_id: str) -> list[Any]:
"""Agno interface: Get all messages for a session"""
return await self.load_history(session_id, limit=1000)
async def clear_session(self, session_id: str):
"""Agno interface: Clear session (alias của clear_history)"""
return await self.clear_history(session_id)
# ConversationManager implements BaseDb interface methods
# but doesn't inherit BaseDb to avoid implementing all abstract methods
# Agno will accept it as long as it has the required methods
# --- Singleton --- # --- Singleton ---
...@@ -157,6 +348,12 @@ async def get_conversation_manager() -> ConversationManager: ...@@ -157,6 +348,12 @@ async def get_conversation_manager() -> ConversationManager:
"""Get or create async ConversationManager singleton""" """Get or create async ConversationManager singleton"""
global _instance global _instance
if _instance is None: if _instance is None:
try:
_instance = ConversationManager() _instance = ConversationManager()
await _instance.initialize_table() await _instance.initialize_table()
except Exception as e:
logger.error(f"❌ Failed to initialize ConversationManager: {e}")
# Reset instance để retry lần sau
_instance = None
raise
return _instance return _instance
""" """
Simple Langfuse Client Wrapper Langfuse Client với OpenInference instrumentation cho Agno
Minimal setup using langfuse.langchain module Tự động trace tất cả Agno calls (LLM, tools, agent runs)
With propagate_attributes for proper user_id tracking
""" """
import asyncio import asyncio
import base64
import logging import logging
import os import os
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
...@@ -19,6 +19,21 @@ from config import ( ...@@ -19,6 +19,21 @@ from config import (
LANGFUSE_SECRET_KEY, LANGFUSE_SECRET_KEY,
) )
# OpenInference imports (optional - only if available)
_OPENINFERENCE_AVAILABLE = False
AgnoInstrumentor = None # type: ignore
try:
from openinference.instrumentation.agno import AgnoInstrumentor # type: ignore[import-untyped]
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
_OPENINFERENCE_AVAILABLE = True
except ImportError:
pass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ⚡ Global state for async batch export # ⚡ Global state for async batch export
...@@ -31,9 +46,10 @@ _batch_lock = asyncio.Lock if hasattr(asyncio, "Lock") else None ...@@ -31,9 +46,10 @@ _batch_lock = asyncio.Lock if hasattr(asyncio, "Lock") else None
def initialize_langfuse() -> bool: def initialize_langfuse() -> bool:
""" """
1. Set environment variables 1. Setup OpenInference instrumentation cho Agno (nếu available)
2. Initialize Langfuse client 2. Configure OTLP exporter để gửi traces đến Langfuse
3. Setup thread pool for async batch export 3. Initialize Langfuse client (fallback)
4. Register shutdown handler
""" """
global _langfuse_client, _export_executor global _langfuse_client, _export_executor
...@@ -44,27 +60,95 @@ def initialize_langfuse() -> bool: ...@@ -44,27 +60,95 @@ def initialize_langfuse() -> bool:
# Set environment # Set environment
os.environ["LANGFUSE_PUBLIC_KEY"] = LANGFUSE_PUBLIC_KEY os.environ["LANGFUSE_PUBLIC_KEY"] = LANGFUSE_PUBLIC_KEY
os.environ["LANGFUSE_SECRET_KEY"] = LANGFUSE_SECRET_KEY os.environ["LANGFUSE_SECRET_KEY"] = LANGFUSE_SECRET_KEY
os.environ["LANGFUSE_BASE_URL"] = LANGFUSE_BASE_URL or "https://cloud.langfuse.com" base_url = LANGFUSE_BASE_URL or "https://cloud.langfuse.com"
os.environ["LANGFUSE_TIMEOUT"] = "10" # 10s timeout, not blocking os.environ["LANGFUSE_BASE_URL"] = base_url
os.environ["LANGFUSE_TIMEOUT"] = "10"
# Disable default flush to prevent blocking os.environ["LANGFUSE_FLUSHINTERVAL"] = "300"
os.environ["LANGFUSE_FLUSHINTERVAL"] = "300" # 5 min, very infrequent
try: try:
# ========== Setup OpenInference cho Agno ==========
global _OPENINFERENCE_AVAILABLE
if _OPENINFERENCE_AVAILABLE:
try:
# Determine Langfuse OTLP endpoint
if "localhost" in base_url or "127.0.0.1" in base_url:
otlp_endpoint = f"{base_url}/api/public/otel"
elif "us.cloud" in base_url:
otlp_endpoint = "https://us.cloud.langfuse.com/api/public/otel"
elif "eu.cloud" in base_url:
otlp_endpoint = "https://eu.cloud.langfuse.com/api/public/otel"
else:
# Custom deployment
otlp_endpoint = f"{base_url}/api/public/otel"
# Create auth header
langfuse_auth = base64.b64encode(
f"{LANGFUSE_PUBLIC_KEY}:{LANGFUSE_SECRET_KEY}".encode()
).decode()
# Set OTLP environment variables
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"Authorization=Basic {langfuse_auth}"
# Configure TracerProvider
tracer_provider = TracerProvider()
tracer_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter()))
trace_api.set_tracer_provider(tracer_provider=tracer_provider)
# Instrument Agno
if AgnoInstrumentor:
AgnoInstrumentor().instrument()
logger.info(f"✅ OpenInference instrumentation enabled for Agno")
logger.info(f" → Sending traces to: {otlp_endpoint}")
except Exception as e:
logger.warning(f"⚠️ Failed to setup OpenInference: {e}. Falling back to Langfuse SDK.")
_OPENINFERENCE_AVAILABLE = False
# ========== Fallback: Langfuse SDK ==========
if not _OPENINFERENCE_AVAILABLE:
_langfuse_client = get_client() _langfuse_client = get_client()
_export_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="langfuse_export") _export_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="langfuse_export")
if _langfuse_client.auth_check(): # Register shutdown handler
logger.info("✅ Langfuse Ready! (async batch export)") import atexit
atexit.register(shutdown_langfuse)
logger.info(f"✅ Langfuse initialized (BASE_URL: {base_url})")
return True return True
logger.error("❌ Langfuse auth failed")
return False
except Exception as e: except Exception as e:
logger.error(f"❌ Langfuse init error: {e}") logger.error(f"❌ Langfuse init error: {e}")
return False return False
def shutdown_langfuse():
"""Shutdown Langfuse client gracefully để tránh nghẽn khi exit"""
global _langfuse_client, _export_executor
try:
if _langfuse_client:
# Flush pending traces trước khi shutdown
try:
_langfuse_client.flush()
except Exception as e:
logger.debug(f"Langfuse flush error during shutdown: {e}")
# Shutdown client (non-blocking với timeout)
try:
if hasattr(_langfuse_client, "shutdown"):
_langfuse_client.shutdown()
except Exception as e:
logger.debug(f"Langfuse shutdown error: {e}")
if _export_executor:
_export_executor.shutdown(wait=False) # Non-blocking shutdown
logger.debug("🔒 Langfuse client shutdown completed")
except Exception as e:
logger.debug(f"Error during Langfuse shutdown: {e}")
async def async_flush_langfuse(): async def async_flush_langfuse():
""" """
Async wrapper to flush Langfuse without blocking event loop. Async wrapper to flush Langfuse without blocking event loop.
......
# services/law_db.py
import asyncio
import os
from typing import Any
import httpx
# Support absolute import when run as module, and fallback when run as script
try:
from common.supabase_client import (
close_supabase_client,
init_supabase_client,
supabase_rpc_call,
)
from common.openai_client import get_openai_client
except ImportError:
import os as _os
import sys
_ROOT = _os.path.dirname(_os.path.dirname(_os.path.dirname(_os.path.abspath(__file__))))
if _ROOT not in sys.path:
sys.path.append(_ROOT)
from common.supabase_client import (
init_supabase_client,
supabase_rpc_call,
)
from common.openai_client import get_openai_client
# ====================== CONFIG ======================
def get_supabase_config():
    """Build the Supabase RPC endpoint config lazily.

    Reading env vars inside the function (instead of at import time) avoids
    crashing when SUPABASE_* variables are not yet set.
    """
    base_url = os.environ["SUPABASE_URL"]
    anon_key = os.environ["SUPABASE_ANON_KEY"]
    headers = {
        "apikey": anon_key,
        "Authorization": f"Bearer {anon_key}",
        "Content-Type": "application/json; charset=utf-8",
        "Accept": "application/json",
    }
    return {
        "url": f"{base_url}/rest/v1/rpc/hoi_phap_luat_all_in_one",
        "headers": headers,
    }
# ====================== HTTP HELPERS ======================
async def _post_with_retry(client: httpx.AsyncClient, url: str, **kw) -> httpx.Response:
    """POST with up to 3 attempts and exponential backoff (0.4s, 0.8s)."""
    attempts = 3
    last_exc: Exception | None = None
    for attempt in range(attempts):
        try:
            resp = await client.post(url, **kw)
            resp.raise_for_status()
        except Exception as exc:
            last_exc = exc
            if attempt == attempts - 1:
                raise
            await asyncio.sleep(0.4 * (2**attempt))
        else:
            return resp
    raise last_exc  # unreachable in practice
# ====================== LABEL HELPER ======================
def _label_for_call(call: dict[str, Any], index: int) -> str:
"""Tạo nhãn hiển thị cho mỗi lệnh gọi dựa trên tham số (để phân biệt kết quả)."""
params = call.get("params") or {}
if params.get("p_so_hieu"):
return str(params["p_so_hieu"])
if params.get("p_vb_pattern"):
return str(params["p_vb_pattern"])
return f"Truy vấn {index}"
# ====================== RAW FETCHERS ======================
async def _get_embedding(text: str) -> list[float]:
    """Return the OpenAI embedding vector for *text*.

    Returns an empty list for empty input or on any API failure.
    """
    if not text:
        return []
    # already imported get_openai_client above with fallback
    client = get_openai_client()
    # Normalize and cap the input length before sending.
    payload = (text or "").strip().lower()[:8000]
    try:
        response = await client.embeddings.create(
            model="text-embedding-3-small",
            input=payload,
        )
    except Exception as e:
        print(f"❌ Lỗi OpenAI embedding: {e}")
        return []
    return response.data[0].embedding
# Integer RPC parameters that must be validated/coerced.
_INT_KEYS = {"p_nam_from", "p_nam_to", "p_limit", "p_ef_search"}
# Parameters forwarded to the RPC unchanged.
_PASSTHROUGH_KEYS = {
    "p_vb_pattern",
    "p_so_hieu",
    "p_trang_thai",
    "p_co_quan",
    "p_loai_vb",
    "p_only_source",
    "p_chapter",
    "p_article",
    "p_phu_luc",
}


def _coerce_int(key: str, value: Any) -> int | None:
    """Coerce *value* to an int for an integer RPC param.

    Returns None (with a warning) for non-integral values. Unlike the previous
    inline checks, NaN/inf no longer raise from `int(value)` — they are treated
    as invalid and skipped.
    """
    try:
        if isinstance(value, (int, float)) and int(value) == value:
            return int(value)
    except (ValueError, OverflowError):
        # float('nan') / float('inf'): int() raises -> fall through to warning
        pass
    print(f"⚠️ Warning: {key} has invalid type/value: {value}, skipping")
    return None


def _validate_vector(value: Any) -> Any | None:
    """Return the embedding vector if usable, else None (with a warning).

    Rejects empty lists and lists containing NaN *or* inf components
    (the old check only caught NaN despite the warning message).
    """
    import math

    if isinstance(value, list):
        if not value:
            print("⚠️ Warning: p_vector is empty, skipping")
            return None
        if not all(isinstance(v, (int, float)) and math.isfinite(v) for v in value):
            print("⚠️ Warning: p_vector contains invalid values (NaN/inf), skipping")
            return None
        return value
    # Non-list values are passed through unchanged (legacy behavior).
    return value


async def _resolve_vector_params(params: dict[str, Any]) -> dict[str, Any]:
    """Expand `p_vector_text` into a `p_vector` embedding; default mode to semantic."""
    processed = dict(params)
    if processed.get("p_vector_text"):
        vector_text = processed.pop("p_vector_text")
        embedding = await _get_embedding(vector_text)
        if embedding:
            processed["p_vector"] = embedding
        # Set mode to semantic if the caller did not specify one.
        if "p_mode" not in processed and "_mode" not in processed:
            processed["p_mode"] = "semantic"
    return processed


def _map_rpc_params(processed: dict[str, Any]) -> dict[str, Any]:
    """Map/validate caller params onto the RPC signature, dropping None and invalid values."""
    mapped: dict[str, Any] = {}
    for key, value in processed.items():
        if value is None:
            continue
        if key == "p_vector":
            vec = _validate_vector(value)
            if vec is not None:
                mapped["p_vector"] = vec
        elif key in {"_mode", "mode", "p_mode"}:
            # BUGFIX: "p_mode" itself was previously not handled by the mapping
            # chain, so the semantic default set in _resolve_vector_params was
            # silently dropped and never reached the RPC.
            mapped["p_mode"] = value
        elif key in _INT_KEYS:
            coerced = _coerce_int(key, value)
            if coerced is not None:
                mapped[key] = coerced
        elif key in _PASSTHROUGH_KEYS:
            mapped[key] = value
        # Unknown keys are silently dropped (same as the original elif chain).
    return mapped


async def law_db_fetch_one(params: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Execute one `hoi_phap_luat_all_in_one` RPC call via the shared Supabase client.

    Converts `p_vector_text` into an embedding, validates/normalizes the
    parameter set, and returns the result rows (empty list on any failure).
    """
    processed_params = await _resolve_vector_params(params)
    mapped_params = _map_rpc_params(processed_params)

    try:
        await init_supabase_client()
        # DEBUG: dump the outgoing JSON payload (truncated).
        print("📤 GỬI JSON PAYLOAD:")
        import json

        debug_payload = json.dumps(mapped_params, ensure_ascii=False, indent=2)
        print(debug_payload[:500])  # first 500 chars only
        rows = await supabase_rpc_call("hoi_phap_luat_all_in_one", mapped_params)
        print(f"✅ NHẬN RESULT: {len(rows)} rows")
        if rows:
            print(f"First row keys: {list(rows[0].keys())}")
            _nd = rows[0].get("NoiDung") or rows[0].get("NoiDungDieu") or rows[0].get("NoiDungPhuLuc") or ""
            try:
                print(f"NoiDung length: {len(_nd)}")
            except Exception:
                print("NoiDung length: (unavailable)")
        return rows or []
    except Exception as e:
        print(f"❌ HTTPX call failed: {e}")
        import traceback

        traceback.print_exc()
        return []
async def law_db_fetch_plan(calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """
    Run every call in *calls* concurrently and return labeled results:
    [{"label": "...", "rows": [...]}, ...] — order matches *calls*.
    Failed calls yield empty row lists instead of raising.
    """
    if not calls:
        return []

    async def run_one(call: dict[str, Any], idx: int) -> dict[str, Any]:
        label = _label_for_call(call, idx)
        params = call.get("params", {})
        # Pass the mode down so law_db_fetch_one can map it to p_mode.
        mode_value = call.get("mode")
        print(f"DEBUG: call.mode = {mode_value}")
        if mode_value is not None and "_mode" not in params and "mode" not in params:
            # Use "_mode" to avoid clobbering any other field name.
            params = dict(params)
            params["_mode"] = mode_value
            print(f"DEBUG: added _mode = {mode_value} to params")
        print(f"DEBUG: final params keys = {list(params.keys())}")
        try:
            rows = await law_db_fetch_one(params)
        except Exception:
            rows = []
        return {"label": label, "rows": rows}

    return await asyncio.gather(*(run_one(c, i) for i, c in enumerate(calls, start=1)))
# ====================== PREVIEW BUILDERS ======================
def build_db_preview(rows: list[dict[str, Any]]) -> str:
    """
    Render full-content result rows as text, grouped by legal document and,
    within each document, by unit (appendix -> article -> other), with the
    text segments of each unit ordered by chunk index.
    """
    if not rows:
        return ""

    _KIND_ORDER = {"PL": 0, "DIEU": 1, "KHAC": 2}
    docs: dict[str, dict[str, Any]] = {}

    for row in rows:
        so_hieu = (row.get("SoHieu") or "").strip() or "(Không rõ số hiệu)"
        title = (row.get("TieuDe") or row.get("TieuDeVanBan") or row.get("TieuDeDieu") or so_hieu).strip()
        # First row for a document wins the title; entry exists even without content.
        doc = docs.setdefault(so_hieu, {"title": title, "groups": {}})

        content = (row.get("NoiDung") or row.get("NoiDungDieu") or row.get("NoiDungPhuLuc") or "").strip()
        if not content:
            continue

        chunk_idx = int(row.get("chunk_idx") or row.get("ChunkIdx") or row.get("ChunkIndex") or 1)
        phu_luc = row.get("PhuLuc") or row.get("Phu_Luc") or row.get("Phu_luc")
        dieu = row.get("Dieu")
        chuong = row.get("Chuong")

        if phu_luc is not None:
            key = ("PL", str(phu_luc))
            group_title = f"Phụ lục {phu_luc}"
            subtitle = row.get("TieuDePhuLuc") or row.get("TenPhuLuc") or ""
        elif dieu is not None:
            key = ("DIEU", str(dieu))
            group_title = f"Điều {dieu}"
            subtitle = row.get("TieuDeDieu") or ""
        else:
            bucket = f"Chương {chuong}" if chuong is not None else "Khác"
            key = ("KHAC", bucket)
            group_title = bucket
            subtitle = ""

        group = doc["groups"].setdefault(key, {"title": group_title, "subtitle": subtitle, "segs": []})
        group["segs"].append((chunk_idx, content))

    def _group_sort_key(item):
        (kind, second), _group = item
        # Appendices first, then articles, then everything else; numeric keys ascending.
        return (_KIND_ORDER.get(kind, 3), int(second) if second.isdigit() else float("inf"))

    parts: list[str] = []
    for so_hieu, doc in docs.items():
        parts.append(f"=== {doc['title']} ({so_hieu}) ===")
        for _key, group in sorted(doc["groups"].items(), key=_group_sort_key):
            suffix = f" — {group['subtitle']}" if group["subtitle"] else ""
            parts.append(group["title"] + suffix)
            parts.extend(text for _idx, text in sorted(group["segs"], key=lambda seg: seg[0]))
            parts.append("")  # blank line between groups
        parts.append("")  # blank line between documents
    return "\n".join(parts).strip()
def build_meta_preview(rows: list[dict[str, Any]]) -> str:
    """
    Render meta-only result rows (no document body) as a numbered listing.
    At most 50 entries are shown.
    """
    if not rows:
        return ""
    lines: list[str] = []
    for idx, row in enumerate(rows[:50], start=1):
        title = row.get("TieuDe") or row.get("TieuDeVanBan") or "(Không rõ tiêu đề)"
        so_hieu = row.get("SoHieu") or "—"
        nam = row.get("NamBanHanh") or ""
        meta_fields = [
            row.get("LoaiVanBan") or "",
            row.get("CoQuanBanHanh") or "",
            f"Năm {nam}" if nam else "",
            row.get("TrangThaiVB") or row.get("TrangThai") or "",
        ]
        meta = " • ".join(filter(None, meta_fields))
        suffix = f" ({meta})" if meta else ""
        lines.append(f"{idx}. {title} — {so_hieu}{suffix}")
    return "\n".join(lines)
def _build_multi_preview(labeled_results: list[dict[str, Any]]) -> str:
    """
    Join several labeled query results into one text, one `###` section per
    query. Rows with any body content get a full-content preview; otherwise
    a meta listing is used.
    """
    sections: list[str] = []
    for result in labeled_results:
        label = result.get("label", "Truy vấn")
        rows = result.get("rows") or []
        has_content = any(
            r.get("NoiDung") or r.get("NoiDungDieu") or r.get("NoiDungPhuLuc") for r in rows
        )
        preview = build_db_preview(rows) if has_content else build_meta_preview(rows)
        if not preview:
            preview = "(Không có dữ liệu)"
        sections.append(f"### {label}\n{preview}")
    return "\n\n".join(sections).strip()
# ====================== PUBLIC APIS ======================
async def fetch_data_db_law(calls: list[dict[str, Any]]) -> str:
    """
    Main entry: run the query plan and render each result by its call's mode
    ("meta" -> numbered listing; "content"/"semantic"/anything else -> full
    content preview).

    BUGFIX: each result is paired with its originating call positionally —
    law_db_fetch_plan uses asyncio.gather, which preserves input order —
    instead of recovering the mode by label substring matching, which could
    attribute the wrong mode when labels collide or overlap.
    """
    calls = calls or []
    # Step 1: fetch all raw data concurrently (order preserved).
    labeled = await law_db_fetch_plan(calls)

    # Step 2: build one preview section per result.
    sections: list[str] = []
    for call, item in zip(calls, labeled):
        label = item.get("label", "Truy vấn")
        rows = item.get("rows") or []
        mode = call.get("mode", "content")
        # Only "meta" renders differently; content/semantic both use the full preview.
        preview_text = build_meta_preview(rows) if mode == "meta" else build_db_preview(rows)
        if not preview_text:
            preview_text = "(Không có dữ liệu)"
        sections.append(f"### {label}\n{preview_text}")
    return "\n\n".join(sections).strip()
# Public API of this module: preview builders and fetch helpers.
__all__ = [
    "build_db_preview",
    "build_meta_preview",
    "fetch_data_db_law",
    "law_db_fetch_one",
    "law_db_fetch_plan",
]
# if __name__ == "__main__":
# import argparse
# import json as _json
# async def main():
# parser = argparse.ArgumentParser(description="Test fetch_data_db_law nhanh qua CLI")
# parser.add_argument("--mode", "-m", type=str, default="content", help="content|semantic|meta (có thể phân tách bằng dấu phẩy)")
# parser.add_argument("--so_hieu", type=str, default=None, help="Giá trị cho p_so_hieu")
# parser.add_argument("--vb_pattern", type=str, default=None, help="Regex/từ khóa cho p_vb_pattern")
# parser.add_argument("--vector_text", type=str, default=None, help="Văn bản để embedding semantic (p_vector_text)")
# parser.add_argument("--loai_vb", type=str, default=None)
# parser.add_argument("--co_quan", type=str, default=None)
# parser.add_argument("--nam_from", type=int, default=None)
# parser.add_argument("--nam_to", type=int, default=None)
# parser.add_argument("--only_source", type=str, default=None, help="chinh_thong|dia_phuong")
# parser.add_argument("--limit", type=int, default=50)
# parser.add_argument("--article", type=int, default=None)
# parser.add_argument("--chapter", type=int, default=None)
# parser.add_argument("--phu_luc", type=str, default=None)
# parser.add_argument("--multi", action="store_true", help="Nếu bật, tạo nhiều calls mẫu để so sánh")
# args = parser.parse_args()
# modes = [m.strip() for m in (args.mode or "content").split(",") if m.strip()]
# def _build_params():
# return {
# "p_so_hieu": args.so_hieu,
# "p_vb_pattern": args.vb_pattern,
# "p_co_quan": args.co_quan,
# "p_loai_vb": args.loai_vb,
# "p_nam_from": args.nam_from,
# "p_nam_to": args.nam_to,
# "p_only_source": args.only_source,
# "p_article": args.article,
# "p_chapter": args.chapter,
# "p_phu_luc": args.phu_luc,
# "p_limit": args.limit,
# "p_vector_text": args.vector_text,
# }
# calls = []
# if args.multi:
# # Tạo một số tổ hợp mẫu để tiện so sánh
# sample_targets = []
# if args.so_hieu:
# sample_targets.append({"p_so_hieu": args.so_hieu})
# if args.vb_pattern:
# sample_targets.append({"p_vb_pattern": args.vb_pattern})
# if args.vector_text:
# sample_targets.append({"p_vector_text": args.vector_text})
# if not sample_targets:
# # nếu không có gì, dùng một pattern mặc định
# sample_targets = [
# {"p_vb_pattern": "BVMT|bảo vệ môi trường|môi trường"},
# {"p_vector_text": "nội dung nghị định chưa được xác định theo yêu cầu của người hỏi"},
# ]
# base_params = _build_params()
# for mode in modes:
# for target in sample_targets:
# params = dict(base_params)
# params.update({k: v for k, v in target.items() if v is not None})
# calls.append({"mode": mode, "params": {k: v for k, v in params.items() if v is not None}})
# else:
# params = {k: v for k, v in _build_params().items() if v is not None}
# if not params:
# # nếu không truyền gì, tạo ví dụ tối thiểu
# params = {"p_vb_pattern": "BVMT|bảo vệ môi trường|môi trường", "p_limit": 30}
# for mode in modes:
# calls.append({"mode": mode, "params": params})
# print("CLI calls:")
# print(_json.dumps(calls, ensure_ascii=False, indent=2))
# # Init Supabase client trước khi gọi
# await init_supabase_client()
# try:
# preview = await fetch_data_db_law(calls)
# print("\n===== PREVIEW =====\n")
# print(preview)
# finally:
# await close_supabase_client()
# asyncio.run(main())
"""
LLM Factory - OpenAI LLM creation with caching.
Manages initialization and caching of OpenAI models.
"""
import contextlib
import logging
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from config import OPENAI_API_KEY
logger = logging.getLogger(__name__)
class LLMFactory:
    """Singleton factory for managing OpenAI LLM instances with caching."""

    # Models eligible for optional warm-up in initialize().
    COMMON_MODELS: list[str] = [
        "gpt-4o-mini",
        "gpt-4o",
        "gpt-5-nano",
        "gpt-5-mini",
    ]

    def __init__(self):
        """Initialize the factory with an empty instance cache."""
        # Cache key: (clean model name, streaming flag, json_mode flag, api key override)
        self._cache: dict[tuple[str, bool, bool, str | None], BaseChatModel] = {}

    def get_model(
        self,
        model_name: str,
        streaming: bool = True,
        json_mode: bool = False,
        api_key: str | None = None,
    ) -> BaseChatModel:
        """
        Get an LLM instance, reusing a cached one when available.

        Args:
            model_name: Model identifier; a "provider/model" prefix is stripped.
            streaming: Enable streaming responses.
            json_mode: Enable JSON output format.
            api_key: Optional API key override.

        Returns:
            Configured LLM instance.
        """
        clean_model = model_name.rsplit("/", 1)[-1]
        cache_key = (clean_model, streaming, json_mode, api_key)
        cached = self._cache.get(cache_key)
        if cached is not None:
            logger.debug(f"♻️ Using cached model: {clean_model}")
            return cached
        logger.info(f"Creating new LLM instance: {clean_model}")
        return self._create_instance(clean_model, streaming, json_mode, api_key)

    def _create_instance(
        self,
        model_name: str,
        streaming: bool = False,
        json_mode: bool = False,
        api_key: str | None = None,
    ) -> BaseChatModel:
        """Create a new OpenAI LLM instance and store it in the cache."""
        try:
            instance = self._create_openai(model_name, streaming, json_mode, api_key)
        except Exception as e:
            logger.error(f"❌ Failed to create model {model_name}: {e}")
            raise
        self._cache[(model_name, streaming, json_mode, api_key)] = instance
        return instance

    def _create_openai(self, model_name: str, streaming: bool, json_mode: bool, api_key: str | None) -> BaseChatModel:
        """Build a ChatOpenAI instance; raises ValueError when no API key is available."""
        key = api_key or OPENAI_API_KEY
        if not key:
            raise ValueError("OPENAI_API_KEY is required")

        kwargs: dict = {
            "model": model_name,
            "streaming": streaming,
            "api_key": key,
            "temperature": 0,
        }
        if json_mode:
            # JSON mode is injected directly into the constructor.
            kwargs["model_kwargs"] = {"response_format": {"type": "json_object"}}
            logger.info(f"⚙️ Initializing OpenAI in JSON mode: {model_name}")

        model = ChatOpenAI(**kwargs)
        logger.info(f"✅ Created OpenAI: {model_name}")
        return model

    def _enable_json_mode(self, llm: BaseChatModel, model_name: str) -> BaseChatModel:
        """Bind JSON response format onto *llm*; returns it unchanged when unsupported."""
        try:
            llm = llm.bind(response_format={"type": "json_object"})
        except Exception as e:
            logger.warning(f"⚠️ JSON mode not supported: {e}")
        else:
            logger.debug(f"⚙️ JSON mode enabled for {model_name}")
        return llm

    def initialize(self, skip_warmup: bool = True) -> None:
        """
        Pre-create the COMMON_MODELS instances.

        No-op when skip_warmup is True or the cache is already populated.
        """
        if skip_warmup or self._cache:
            return
        logger.info("🔥 Warming up LLM Factory...")
        for name in self.COMMON_MODELS:
            with contextlib.suppress(Exception):
                self.get_model(name, streaming=True)
# --- Singleton Instance & Public API ---
_factory = LLMFactory()
def create_llm(
    model_name: str,
    streaming: bool = True,
    json_mode: bool = False,
    api_key: str | None = None,
) -> BaseChatModel:
    """Create (or fetch from cache) an LLM instance via the module-level factory."""
    return _factory.get_model(
        model_name,
        streaming=streaming,
        json_mode=json_mode,
        api_key=api_key,
    )
def init_llm_factory(skip_warmup: bool = True) -> None:
    """Warm up the shared LLM factory (a no-op by default, since skip_warmup=True)."""
    _factory.initialize(skip_warmup=skip_warmup)
def create_embedding_model() -> OpenAIEmbeddings:
    """Build an OpenAI embeddings client using text-embedding-3-small."""
    embedder = OpenAIEmbeddings(model="text-embedding-3-small", api_key=OPENAI_API_KEY)
    return embedder
...@@ -3,8 +3,8 @@ StarRocks Database Connection Utility ...@@ -3,8 +3,8 @@ StarRocks Database Connection Utility
Based on chatbot-rsa pattern Based on chatbot-rsa pattern
""" """
import logging
import asyncio import asyncio
import logging
from typing import Any from typing import Any
import aiomysql import aiomysql
...@@ -34,11 +34,11 @@ class StarRocksConnection: ...@@ -34,11 +34,11 @@ class StarRocksConnection:
password: str | None = None, password: str | None = None,
port: int | None = None, port: int | None = None,
): ):
self.host = host or STARROCKS_HOST self.host = host or STARROCKS_HOST or ""
self.database = database or STARROCKS_DB self.database = database or STARROCKS_DB or ""
self.user = user or STARROCKS_USER self.user = user or STARROCKS_USER or ""
self.password = password or STARROCKS_PASSWORD self.password = password or STARROCKS_PASSWORD or ""
self.port = port or STARROCKS_PORT self.port = port or STARROCKS_PORT or 3306
# self.conn references the shared connection # self.conn references the shared connection
self.conn = None self.conn = None
...@@ -61,11 +61,15 @@ class StarRocksConnection: ...@@ -61,11 +61,15 @@ class StarRocksConnection:
print(f" [DB] 🔌 Đang kết nối StarRocks (New Session): {self.host}:{self.port}...") print(f" [DB] 🔌 Đang kết nối StarRocks (New Session): {self.host}:{self.port}...")
logger.info(f"🔌 Connecting to StarRocks at {self.host}:{self.port} (DB: {self.database})...") logger.info(f"🔌 Connecting to StarRocks at {self.host}:{self.port} (DB: {self.database})...")
try: try:
# Ensure all required parameters are strings (not None)
if not all([self.host, self.user, self.password, self.database]):
raise ValueError("Missing required StarRocks connection parameters")
new_conn = pymysql.connect( new_conn = pymysql.connect(
host=self.host, host=self.host,
port=self.port, port=self.port,
user=self.user, user=self.user,
password=self.password, password=self.password, # Now guaranteed to be str, not None
database=self.database, database=self.database,
charset="utf8mb4", charset="utf8mb4",
cursorclass=DictCursor, cursorclass=DictCursor,
...@@ -121,11 +125,15 @@ class StarRocksConnection: ...@@ -121,11 +125,15 @@ class StarRocksConnection:
# Double-check inside lock to prevent multiple pools # Double-check inside lock to prevent multiple pools
if StarRocksConnection._shared_pool is None: if StarRocksConnection._shared_pool is None:
logger.info(f"🔌 Creating Async Pool to {self.host}:{self.port}...") logger.info(f"🔌 Creating Async Pool to {self.host}:{self.port}...")
# Ensure all required parameters are strings (not None)
if not all([self.host, self.user, self.password, self.database]):
raise ValueError("Missing required StarRocks connection parameters")
StarRocksConnection._shared_pool = await aiomysql.create_pool( StarRocksConnection._shared_pool = await aiomysql.create_pool(
host=self.host, host=self.host,
port=self.port, port=self.port,
user=self.user, user=self.user,
password=self.password, password=self.password, # Now guaranteed to be str, not None
db=self.database, db=self.database,
charset="utf8mb4", charset="utf8mb4",
cursorclass=aiomysql.DictCursor, cursorclass=aiomysql.DictCursor,
...@@ -160,15 +168,16 @@ class StarRocksConnection: ...@@ -160,15 +168,16 @@ class StarRocksConnection:
# Nếu StarRocks OOM, đợi một chút rồi thử lại # Nếu StarRocks OOM, đợi một chút rồi thử lại
await asyncio.sleep(0.5 * (attempt + 1)) await asyncio.sleep(0.5 * (attempt + 1))
continue continue
elif "Disconnected" in str(e) or "Lost connection" in str(e): if "Disconnected" in str(e) or "Lost connection" in str(e):
# Nếu mất kết nối, có thể pool bị stale, thử lại ngay # Nếu mất kết nối, có thể pool bị stale, thử lại ngay
continue continue
else:
# Các lỗi khác (cú pháp,...) thì raise luôn # Các lỗi khác (cú pháp,...) thì raise luôn
raise raise
logger.error(f"❌ Failed after {max_retries} attempts: {last_error}") logger.error(f"❌ Failed after {max_retries} attempts: {last_error}")
if last_error:
raise last_error raise last_error
raise RuntimeError("Failed to execute query after multiple attempts")
def close(self): def close(self):
"""Explicitly close if needed (e.g. app shutdown)""" """Explicitly close if needed (e.g. app shutdown)"""
......
...@@ -49,7 +49,9 @@ langchain==1.2.0 ...@@ -49,7 +49,9 @@ langchain==1.2.0
langchain-core==1.2.3 langchain-core==1.2.3
langchain-google-genai==4.1.2 langchain-google-genai==4.1.2
langchain-openai==1.1.6 langchain-openai==1.1.6
agno==2.3.24
langfuse==3.11.0 langfuse==3.11.0
openinference-instrumentation-agno==1.0.0
langgraph==1.0.5 langgraph==1.0.5
langgraph-checkpoint==3.0.1 langgraph-checkpoint==3.0.1
langgraph-checkpoint-postgres==3.0.2 langgraph-checkpoint-postgres==3.0.2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment