Commit 3baf7736 authored by Hoanganhvu123's avatar Hoanganhvu123

feat: integrate Claude API via zunef gateway with JSON parsing

Cherry-pick commit 88341cc4 with conflict resolution in graph.py.
Replace with_structured_output with manual JSON parsing for Claude API.
Co-Authored-By: 's avatarClaude Sonnet 4.6 <noreply@anthropic.com>
parent 7659d548
...@@ -269,13 +269,15 @@ class LeadStageGraph: ...@@ -269,13 +269,15 @@ class LeadStageGraph:
"add_to_cart": add_to_cart, "add_to_cart": add_to_cart,
} }
# AI 1: Classifier — structured output, NHẸ (streaming=False) # AI 1: Classifier — structured output bypass for Claude API (zunef gateway)
_classifier_base = create_llm(model_name=self.model_name, streaming=False) _classifier_base = create_llm(model_name=self.model_name, streaming=False)
self.classifier_llm = _classifier_base.with_structured_output(ClassifierOutput, method="function_calling") self.classifier_llm = _classifier_base
# Note: Claude models bypass structured output, parse JSON manually
# AI 2: Stylist — structured output, NẶNG (streaming=False vì structured) # AI 2: Stylist — structured output bypass for Claude API (zunef gateway)
_stylist_base = create_llm(model_name=self.model_name, streaming=False) _stylist_base = create_llm(model_name=self.model_name, streaming=False)
self.stylist_llm = _stylist_base.with_structured_output(StylistOutput, method="function_calling") self.stylist_llm = _stylist_base
# Note: Claude models bypass structured output, parse JSON manually
self._compiled = None self._compiled = None
logger.info(f"✅ LeadStageGraph initialized | model: {self.model_name}") logger.info(f"✅ LeadStageGraph initialized | model: {self.model_name}")
...@@ -386,7 +388,33 @@ class LeadStageGraph: ...@@ -386,7 +388,33 @@ class LeadStageGraph:
start = time.time() start = time.time()
try: try:
output: ClassifierOutput = await self.classifier_llm.ainvoke(classifier_messages, config=config) response = await self.classifier_llm.ainvoke(
classifier_messages, config=config
)
# Parse JSON from text response (Claude models return text, not structured output)
import json as json_lib
text = response.content if hasattr(response, "content") else str(response)
logger.info(f"[CLASSIFIER] Raw response repr: {repr(text[:500])}")
# Extract JSON from text (may have extra text)
text = text.strip()
if text.startswith('```json'):
text = text[7:]
if text.startswith('```'):
text = text[3:]
if text.endswith('```'):
text = text[:-3]
text = text.strip()
try:
output = json_lib.loads(text)
except json_lib.JSONDecodeError as je:
logger.error(f"[CLASSIFIER] JSON parse error: {je}. Text was: {repr(text[:500])}")
raise
# Convert dict to object-like access
class DictToObj:
def __init__(self, data):
self.__dict__.update(data)
output = DictToObj(output)
except Exception as e: except Exception as e:
logger.error(f"❌ Classifier error: {e}") logger.error(f"❌ Classifier error: {e}")
return { return {
...@@ -403,7 +431,13 @@ class LeadStageGraph: ...@@ -403,7 +431,13 @@ class LeadStageGraph:
# ── Resolve tool_args: typed lead_search_args hoac raw dict ── # ── Resolve tool_args: typed lead_search_args hoac raw dict ──
if tool_name == "lead_search_tool" and output.lead_search_args: if tool_name == "lead_search_tool" and output.lead_search_args:
tool_args = output.lead_search_args.model_dump(exclude_none=False) # Handle both dict and Pydantic model cases
if isinstance(output.lead_search_args, dict):
tool_args = output.lead_search_args
elif hasattr(output.lead_search_args, 'model_dump'):
tool_args = output.lead_search_args.model_dump(exclude_none=False)
else:
tool_args = dict(output.lead_search_args) if output.lead_search_args else {}
elif tool_name == "lead_search_tool" and output.tool_args: elif tool_name == "lead_search_tool" and output.tool_args:
# fallback: model dung tool_args thay vi lead_search_args # fallback: model dung tool_args thay vi lead_search_args
tool_args = output.tool_args tool_args = output.tool_args
...@@ -557,13 +591,39 @@ class LeadStageGraph: ...@@ -557,13 +591,39 @@ class LeadStageGraph:
HumanMessage(content="=== KHÔNG CÓ TOOL RESULT === (Khách chào hỏi hoặc câu hỏi chung)") HumanMessage(content="=== KHÔNG CÓ TOOL RESULT === (Khách chào hỏi hoặc câu hỏi chung)")
) )
# ── LLM Call: Structured Output ── # ── LLM Call: Parse JSON from text (Claude returns text, not structured) ──
# ── LLM Call: Parse JSON from text (Claude returns text, not structured) ──
start = time.time() start = time.time()
try: try:
output: StylistOutput = await self.stylist_llm.ainvoke(stylist_messages, config=config) response = await self.stylist_llm.ainvoke(
stylist_messages, config=config
)
# Parse JSON from text response
text = response.content if hasattr(response, "content") else str(response)
logger.info(f"[STYLIST] Raw response text (first 500 chars): {repr(text[:500])}")
text = text.strip()
if text.startswith('```json'):
text = text[7:]
if text.startswith('```'):
text = text[3:]
if text.endswith('```'):
text = text[:-3]
text = text.strip()
output = json.loads(text)
# Convert dict to object-like access
class DictToObj:
def __init__(self, data):
self.__dict__.update(data)
def __getattr__(self, name):
return self.__dict__.get(name)
output = DictToObj(output)
# Handle nested user_insight dict
if hasattr(output, "user_insight") and isinstance(output.user_insight, dict):
output.user_insight = DictToObj(output.user_insight)
except Exception as e: except Exception as e:
logger.error(f"❌ Stylist structured output error: {e}") logger.error(f"❌ Stylist JSON parse error: {e}")
# Fallback: trả response text thuần
fallback = "Dạ bạn cho mình hỏi thêm để tư vấn chính xác hơn nhé!" fallback = "Dạ bạn cho mình hỏi thêm để tư vấn chính xác hơn nhé!"
return { return {
"messages": [AIMessage(content=fallback)], "messages": [AIMessage(content=fallback)],
...@@ -582,8 +642,18 @@ class LeadStageGraph: ...@@ -582,8 +642,18 @@ class LeadStageGraph:
from_text = _extract_skus_from_text(output.ai_response) from_text = _extract_skus_from_text(output.ai_response)
merged_product_ids = _dedupe_preserve_order((output.product_ids or []) + from_text) merged_product_ids = _dedupe_preserve_order((output.product_ids or []) + from_text)
# Extract insight # Extract insight - handle both dict and object cases
insight_dict = output.user_insight.model_dump() merged_product_ids = _dedupe_preserve_order((output.product_ids or []) + from_text)
# Extract insight - handle both dict and object cases
insight_raw = getattr(output, "user_insight", None)
if isinstance(insight_raw, dict):
insight_dict = insight_raw
elif hasattr(insight_raw, "__dict__"):
insight_dict = insight_raw.__dict__
else:
insight_dict = {}
lead_stage = { lead_stage = {
"stage": insight_dict.get("STAGE_NUM", 1), "stage": insight_dict.get("STAGE_NUM", 1),
"stage_name": insight_dict.get("STAGE", "BROWSE"), "stage_name": insight_dict.get("STAGE", "BROWSE"),
...@@ -592,21 +662,24 @@ class LeadStageGraph: ...@@ -592,21 +662,24 @@ class LeadStageGraph:
} }
stage_injection = format_insight_injection(insight_dict) stage_injection = format_insight_injection(insight_dict)
ai_response = getattr(output, "ai_response", "Dạ mình chưa hiểu, bạn nói lại nhé!")
product_ids = getattr(output, "product_ids", []) or []
logger.info( logger.info(
f"💬 Stylist: stage={insight_dict.get('STAGE')} | goal={insight_dict.get('GOAL')} | " f"💬 Stylist: stage={insight_dict.get('STAGE')} | goal={insight_dict.get('GOAL')} | "
f"response_len={len(output.ai_response)} | products={merged_product_ids} | {elapsed_ms:.0f}ms" f"response_len={len(ai_response)} | products={merged_product_ids} | {elapsed_ms:.0f}ms"
) )
diag = { diag = {
"step": "stylist", "step": "stylist",
"label": "💬 Stylist (Response + Insight)", "label": "💬 Stylist (Response + Insight)",
"content": output.ai_response[:500], "content": ai_response[:500],
"elapsed_ms": round(elapsed_ms), "elapsed_ms": round(elapsed_ms),
"raw_json": json.dumps(output.model_dump(), ensure_ascii=False, indent=2), "raw_json": json.dumps(insight_dict, ensure_ascii=False, indent=2),
} }
return { return {
"messages": [AIMessage(content=output.ai_response)], "messages": [AIMessage(content=ai_response)],
"updated_insight": insight_dict, "updated_insight": insight_dict,
"lead_stage": lead_stage, "lead_stage": lead_stage,
"stage_injection": stage_injection, "stage_injection": stage_injection,
......
...@@ -225,6 +225,18 @@ Query: "combo gì cho ngày 2/9" ...@@ -225,6 +225,18 @@ Query: "combo gì cho ngày 2/9"
→ inferred.tags: ["Màu đỏ", "Đi chơi / dạo phố"] → inferred.tags: ["Màu đỏ", "Đi chơi / dạo phố"]
→ inferred.keywords: ["quốc khánh", "2/9", "lễ lớn"] → inferred.keywords: ["quốc khánh", "2/9", "lễ lớn"]
``` ```
**QUAN TRỌNG: BẮT BUỘC OUTPUT FORMAT**
Bạn PHẢI trả về DUY NHẤT một JSON object hợp lệ, KHÔNG text thêm thắt.
JSON schema:
{
"reasoning": "lý luận ngắn gọn",
"tool_name": null hoặc "lead_search_tool",
"lead_search_args": {"literal": {"raw_text": "..."}, "inferred": {"product_line_vn": [], "gender_by_product": null, ...}, "magento_ref_code": null, "reasoning": "..."},
"tool_args": null,
"ai_response": null hoặc "câu chào hỏi",
"product_ids": []
}
""" """
# ═══════════════════════════════════════════════ # ═══════════════════════════════════════════════
...@@ -420,6 +432,15 @@ Cập nhật 12 trường dưới đây sau mỗi turn. Cộng dồn thông tin. ...@@ -420,6 +432,15 @@ Cập nhật 12 trường dưới đây sau mỗi turn. Cộng dồn thông tin.
- KHÔNG dán link URL - KHÔNG dán link URL
- KHÔNG tự bịa tính năng không có trong tool results - KHÔNG tự bịa tính năng không có trong tool results
- Mở đầu bằng 1 câu đồng cảm dựa trên GOAL - Mở đầu bằng 1 câu đồng cảm dựa trên GOAL
**QUAN TRỌNG: BẮT BUỘC OUTPUT FORMAT**
Bạn PHẢI trả về DUY NHẤT một JSON object hợp lệ, KHÔNG text thêm thắt.
JSON schema:
{
"ai_response": "câu trả lời cho khách (max 250 từ)",
"product_ids": [],
"user_insight": {"USER": "...", "TARGET": "...", "GOAL": "...", ...}
}
""" """
# ═══════════════════════════════════════════════ # ═══════════════════════════════════════════════
......
This diff is collapsed.
...@@ -81,7 +81,14 @@ OPENAI_API_KEY: str | None = os.getenv("OPENAI_API_KEY") ...@@ -81,7 +81,14 @@ OPENAI_API_KEY: str | None = os.getenv("OPENAI_API_KEY")
GOOGLE_API_KEY: str | None = os.getenv("GOOGLE_API_KEY") GOOGLE_API_KEY: str | None = os.getenv("GOOGLE_API_KEY")
GROQ_API_KEY: str | None = os.getenv("GROQ_API_KEY") GROQ_API_KEY: str | None = os.getenv("GROQ_API_KEY")
DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "gemini-3.1-flash-lite-preview") # DS2API (DeepSeek via OpenAI-compatible wrapper)
DS2API_BASE_URL: str | None = os.getenv("DS2API_BASE_URL")
DS2API_API_KEY: str | None = os.getenv("DS2API_API_KEY")
# Claude API (via zunef gateway)
CLAUDE_BASE_URL: str | None = os.getenv("CLAUDE_BASE_URL", "https://claude-api.zunef.com/v1/ai")
DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "claude-sonnet-4-6")
# DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "gpt-5.1-codex-mini") # DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "gpt-5.1-codex-mini")
# ====================== JWT CONFIGURATION ====================== # ====================== JWT CONFIGURATION ======================
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment