Commit 37643d07 authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

fix: resolve empty product_ids, fix sql bug, handle decimal json error & add check_is_stock tool

parent 1090ad4e
......@@ -17,7 +17,7 @@ from config import DEFAULT_MODEL, REDIS_CACHE_TURN_ON
from langfuse import propagate_attributes
from .graph import build_graph
from .helper import extract_product_ids, handle_post_chat_async, parse_ai_response
from .helper import extract_product_ids, handle_post_chat_async, parse_ai_response_async
from .models import AgentState, get_config
logger = logging.getLogger(__name__)
......@@ -44,6 +44,7 @@ async def chat_controller(
"""
effective_identity_key = identity_key or user_id
logger.info(
"chat_controller start: model=%s, user_id=%s, identity_key=%s",
model_name, user_id, effective_identity_key
......@@ -129,10 +130,13 @@ async def chat_controller(
result = await graph.ainvoke(initial_state, config=exec_config)
# Parse Response
all_product_ids = extract_product_ids(result.get("messages", []))
final_messages = result.get("messages", [])
# Combine history + current messages to find ALL products mentioned in conversation
full_conversation = messages + final_messages
all_product_ids = extract_product_ids(full_conversation)
ai_raw_content = result.get("ai_response").content if result.get("ai_response") else ""
# Unpack 3 values now
ai_text_response, final_product_ids, new_insight = parse_ai_response(ai_raw_content, all_product_ids)
ai_text_response, final_product_ids, new_insight = await parse_ai_response_async(ai_raw_content, all_product_ids)
# Save new insight to Redis if available
if new_insight and effective_identity_key:
......
......@@ -6,17 +6,31 @@ Các hàm tiện ích cho chat controller.
import json
import logging
import uuid
from decimal import Decimal
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from common.conversation_manager import ConversationManager
from common.langfuse_client import get_callback_handler
from common.langfuse_client import get_callback_handler
from common.starrocks_connection import get_db_connection
from agent.tools.data_retrieval_filter import format_product_results
from .models import AgentState
logger = logging.getLogger(__name__)
def decimal_default(obj):
"""
JSON serializer for objects not serializable by default json code.
Handles Decimal objects.
"""
if isinstance(obj, Decimal):
return float(obj)
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
def extract_product_ids(messages: list) -> list[dict]:
"""
Extract full product info from tool messages (data_retrieval_tool results).
......@@ -37,13 +51,23 @@ def extract_product_ids(messages: list) -> list[dict]:
product_list = []
if "results" in tool_result:
# New format: {"results": [{"products": [...]}]}
for result_item in tool_result["results"]:
product_list.extend(result_item.get("products", []))
results_data = tool_result["results"]
if results_data and isinstance(results_data, list):
# Check first item to determine format
first_item = results_data[0] if len(results_data) > 0 else {}
if isinstance(first_item, dict) and "products" in first_item:
# Nested format: {"results": [{"products": [...]}]}
for result_item in results_data:
product_list.extend(result_item.get("products", []))
else:
# Flat format: {"results": [product1, product2]} (Current)
product_list = results_data
elif "products" in tool_result:
# Legacy format: {"products": [...]}
product_list = tool_result["products"]
logger.warning(f"🛠️ [EXTRACT] Extracted {len(product_list)} products")
for product in product_list:
sku = product.get("sku") or product.get("internal_ref_code")
if sku and sku not in seen_skus:
......@@ -66,24 +90,65 @@ def extract_product_ids(messages: list) -> list[dict]:
return products
def parse_ai_response(ai_raw_content: str, all_products: list) -> tuple[str, list, str | None]:
async def fetch_products_by_skus(skus: list[str]) -> list[dict]:
"""
Fetch product details from DB for a list of SKUs.
Used when AI mentions products that are not in the current tool output context.
"""
if not skus:
return []
db = get_db_connection()
if not db:
logger.error("❌ DB Connection failed in fetch_products_by_skus")
return []
# Format SKUs for SQL IN clause
placeholders = ",".join(["%s"] * len(skus))
sql = f"""
SELECT
internal_ref_code,
description_text_full,
sale_price,
original_price,
discount_amount,
product_line_vn,
product_line_en,
1.0 as max_score
FROM shared_source.magento_product_dimension_with_text_embedding
WHERE internal_ref_code IN ({placeholders}) OR magento_ref_code IN ({placeholders})
"""
# Params: Pass SKUs twice (once for internal_ref, once for magento_ref)
params = skus + skus
try:
results = await db.execute_query_async(sql, params=params)
logger.info(f"🔄 Fetched {len(results)} fallback products from DB for SKUs: {skus}")
return format_product_results(results)
except Exception as e:
logger.error(f"❌ Error fetching fallback products: {e}")
return []
async def parse_ai_response_async(ai_raw_content: str, all_products: list) -> tuple[str, list, str | None]:
"""
Async version of parse_ai_response with DB fallback.
Parse AI response từ LLM output và map SKUs với product data.
Nếu SKU được mention nhưng không có trong all_products (context hiện tại),
sẽ query trực tiếp DB để lấy thông tin.
Flow:
- LLM trả về: {"ai_response": "...", "product_ids": ["SKU1", "SKU2"], "user_insight": "..."}
- all_products: List products enriched từ tool messages
- Map SKUs → enriched products
Args:
ai_raw_content: Raw content từ AI response
all_products: Products extracted từ tool messages (đã có đầy đủ info)
Returns:
tuple: (ai_text_response, final_products, user_insight)
- LLM trả về: {"ai_response": "...", "product_ids": ["SKU1"], ...}
- Map SKUs → enriched products từ context
- Nếu thiếu → Query DB
"""
from .structured_models import ChatResponse, UserInsight
import re
ai_text_response = ai_raw_content
final_products = all_products # Default: trả về tất cả products từ tool
final_products = []
user_insight = None
logger.info(f"🤖 Raw AI JSON: {ai_raw_content}")
......@@ -91,37 +156,86 @@ def parse_ai_response(ai_raw_content: str, all_products: list) -> tuple[str, lis
try:
# Try to parse if it's a JSON string from LLM
ai_json = json.loads(ai_raw_content)
ai_text_response = ai_json.get("ai_response", ai_raw_content)
explicit_skus = ai_json.get("product_ids", [])
user_insight = ai_json.get("user_insight")
# === PYDANTIC VALIDATION ===
try:
# Try strict Pydantic validation
parsed_response = ChatResponse.model_validate(ai_json)
ai_text_response = parsed_response.ai_response
explicit_skus = parsed_response.product_ids
# Convert user_insight to dict/string for storage
if parsed_response.user_insight:
user_insight = parsed_response.user_insight.model_dump_json(indent=2)
logger.info("✅ Pydantic validation passed for ChatResponse")
except Exception as validation_error:
# Fallback to manual parsing if Pydantic fails
logger.warning(f"⚠️ Pydantic validation failed, using fallback: {validation_error}")
ai_text_response = ai_json.get("ai_response", ai_raw_content)
explicit_skus = ai_json.get("product_ids", [])
raw_insight = ai_json.get("user_insight")
if raw_insight:
if isinstance(raw_insight, dict):
user_insight = json.dumps(raw_insight, ensure_ascii=False, indent=2)
elif isinstance(raw_insight, str):
user_insight = raw_insight
# === CRITICAL: Filter/Fetch products ===
# Extract SKUs mentioned in ai_response text using regex pattern [SKU]
mentioned_skus_in_text = set(re.findall(r'\[([A-Z0-9]+)\]', ai_text_response))
logger.info(f"📝 SKUs mentioned in ai_response: {mentioned_skus_in_text}")
# Determine target SKUs
target_skus = set()
# 1. Use explicit SKUs if available and confirmed by text, OR just explicit
if explicit_skus and isinstance(explicit_skus, list):
# LLM trả về list SKUs → Map với products đã có
# Build lookup dict từ all_products
# Optional: Filter explicit SKUs to only those actually in text to reduce hallucination
# But if explicit list is provided, we generally trust it unless we want strict text-match
if mentioned_skus_in_text:
explicit_set = set(str(s) for s in explicit_skus)
target_skus = explicit_set.intersection(mentioned_skus_in_text)
if not target_skus: # If intersection empty, fallback to text mentions
target_skus = mentioned_skus_in_text
else:
target_skus = set(str(s) for s in explicit_skus)
elif mentioned_skus_in_text:
# 2. If no explicit SKUs, use text mentions
target_skus = mentioned_skus_in_text
logger.info(f"🎯 Target SKUs to return: {target_skus}")
if target_skus:
# Build lookup from current context
product_lookup = {p["sku"]: p for p in all_products if p.get("sku")}
# Map SKUs → enriched products
mapped_products = []
for sku in explicit_skus:
if isinstance(sku, str) and sku in product_lookup:
mapped_products.append(product_lookup[sku])
elif isinstance(sku, dict):
# LLM có thể trả về dict (legacy) → giữ nguyên
mapped_products.append(sku)
found_products = []
missing_skus = []
if mapped_products:
final_products = mapped_products
else:
# If explicit SKUs provided but none found in DB, return empty list
# This prevents showing unrelated products when AI hallucinates or references old/invalid SKUs
final_products = []
for sku in target_skus:
if sku in product_lookup:
found_products.append(product_lookup[sku])
else:
missing_skus.append(sku)
except (json.JSONDecodeError, TypeError):
pass
# Fetch missing SKUs from DB
if missing_skus:
logger.info(f"🕵️ Missing SKUs in context, fetching from DB: {missing_skus}")
fallback_products = await fetch_products_by_skus(missing_skus)
found_products.extend(fallback_products)
final_products = found_products
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"⚠️ Failed to parse AI response as JSON: {e}")
return ai_text_response, final_products, user_insight
def prepare_execution_context(query: str, user_id: str, history: list, images: list | None):
"""
Prepare initial state and execution config for the graph run.
......@@ -173,7 +287,8 @@ async def handle_post_chat_async(
if ai_response:
try:
# Convert dict thành JSON string để lưu vào TEXT field
ai_response_json = json.dumps(ai_response, ensure_ascii=False)
# Use decimal_default to handle Decimal types from DB
ai_response_json = json.dumps(ai_response, ensure_ascii=False, default=decimal_default)
await memory.save_conversation_turn(identity_key, human_query, ai_response_json)
logger.debug(f"Saved conversation for identity_key {identity_key}")
except Exception as e:
......
import logging
import httpx
from langchain_core.tools import tool
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class StockCheckInput(BaseModel):
skus: str = Field(
description="Danh sách mã SKU sản phẩm cần kiểm tra tồn kho, phân cách bằng dấu phẩy. Ví dụ: '6ST25W005-SE091-L,6ST25W005-SE091-M'"
)
@tool("check_is_stock", args_schema=StockCheckInput)
async def check_is_stock(skus: str) -> str:
"""
Kiểm tra tình trạng tồn kho của các mã sản phẩm (SKU) thực tế từ hệ thống Canifa.
Sử dụng tool này khi người dùng hỏi về tình trạng còn hàng, hết hàng của sản phẩm cụ thể.
Input nhận vào là chuỗi các SKU phân cách bởi dấu phẩy.
"""
logger.info(f"🔍 [Stock Check] Checking stock for SKUs: {skus}")
url = "https://canifa.com/v1/middleware/stock_get_stock_list"
params = {"skus": skus}
try:
async with httpx.AsyncClient() as client:
response = await client.get(url, params=params, timeout=10.0)
response.raise_for_status()
data = response.json()
logger.info(f"✅ Stock Check response: {str(data)[:200]}...")
# Trả về raw JSON để LLM tự xử lý thông tin
return str(data)
except httpx.RequestError as e:
logger.error(f"❌ Network error checking stock: {e}")
return f"Lỗi kết nối khi kiểm tra tồn kho: {str(e)}"
except httpx.HTTPStatusError as e:
logger.error(f"❌ HTTP error {e.response.status_code}: {e}")
return f"Lỗi server khi kiểm tra tồn kho (Status {e.response.status_code})"
except Exception as e:
logger.error(f"❌ Unexpected error in check_is_stock: {e}")
return f"Lỗi không xác định khi kiểm tra tồn kho: {str(e)}"
......@@ -9,11 +9,12 @@ from .brand_knowledge_tool import canifa_knowledge_search
from .customer_info_tool import collect_customer_info
from .data_retrieval_tool import data_retrieval_tool
from .promotion_canifa_tool import canifa_get_promotions
from .check_is_stock import check_is_stock
def get_retrieval_tools() -> list[Tool]:
"""Các tool chỉ dùng để đọc/truy vấn dữ liệu (Có thể cache)"""
return [data_retrieval_tool, canifa_knowledge_search, canifa_get_promotions]
return [data_retrieval_tool, canifa_knowledge_search, canifa_get_promotions, check_is_stock]
def get_collection_tools() -> list[Tool]:
......
......@@ -99,6 +99,7 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
product_line_en,
1.0 as max_score
FROM shared_source.magento_product_dimension_with_text_embedding
WHERE internal_ref_code = %s OR magento_ref_code = %s
"""
return sql, [magento_code, magento_code]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment