Commit 479c7fdb authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

fix: Agno framework integration - fix session management, JSON output, and async methods

parent 28274420
...@@ -34,10 +34,17 @@ async def chat_controller( ...@@ -34,10 +34,17 @@ async def chat_controller(
with langfuse_trace_context(user_id=user_id, session_id=user_id): with langfuse_trace_context(user_id=user_id, session_id=user_id):
# Agno tự động load history và save sau khi respond (memory enabled) # Agno tự động load history và save sau khi respond (memory enabled)
result = agent.run(query, session_id=user_id) result = await agent.arun(query, session_id=user_id)
# Extract response # Extract response
ai_content = str(result.content if hasattr(result, "content") and result.content else str(result)) ai_content = str(result.content if hasattr(result, "content") and result.content else str(result))
# Strip markdown JSON wrapper if present (```json ... ```)
if ai_content.startswith("```json"):
ai_content = ai_content.replace("```json", "").replace("```", "").strip()
elif ai_content.startswith("```"):
ai_content = ai_content.replace("```", "").strip()
logger.info(f"💾 AI Response: {ai_content[:200]}...") logger.info(f"💾 AI Response: {ai_content[:200]}...")
# Parse response và extract products # Parse response và extract products
...@@ -70,7 +77,7 @@ def _parse_agno_response(result: Any, ai_content: str) -> tuple[str, list[dict]] ...@@ -70,7 +77,7 @@ def _parse_agno_response(result: Any, ai_content: str) -> tuple[str, list[dict]]
logger.debug(f"Response is not JSON, using raw text: {e}") logger.debug(f"Response is not JSON, using raw text: {e}")
# Extract products từ tool results # Extract products từ tool results
if hasattr(result, "messages"): if hasattr(result, "messages") and result.messages is not None:
tool_products = _extract_products_from_messages(result.messages) tool_products = _extract_products_from_messages(result.messages)
# Merge và deduplicate # Merge và deduplicate
seen_skus = {p.get("sku") for p in product_ids if isinstance(p, dict) and "sku" in p} seen_skus = {p.get("sku") for p in product_ids if isinstance(p, dict) and "sku" in p}
...@@ -122,13 +129,14 @@ def _parse_products(products: list[dict], seen_skus: set[str]) -> list[dict]: ...@@ -122,13 +129,14 @@ def _parse_products(products: list[dict], seen_skus: set[str]) -> list[dict]:
continue continue
seen_skus.add(sku) seen_skus.add(sku)
parsed.append({ parsed.append(
"sku": sku, {
"name": product.get("magento_product_name", ""), "sku": sku,
"price": product.get("price_vnd", 0), "name": product.get("magento_product_name", ""),
"sale_price": product.get("sale_price_vnd"), "price": product.get("price_vnd", 0),
"url": product.get("magento_url_key", ""), "sale_price": product.get("sale_price_vnd"),
"thumbnail_image_url": product.get("thumbnail_image_url", ""), "url": product.get("magento_url_key", ""),
}) "thumbnail_image_url": product.get("thumbnail_image_url", ""),
}
)
return parsed return parsed
import asyncio
import time import time
from common.embedding_service import create_embedding_async from common.embedding_service import create_embedding_async
...@@ -112,10 +111,10 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None) ...@@ -112,10 +111,10 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
""" """
print("✅ [CODE SEARCH] Query built - No vector search needed!") print("✅ [CODE SEARCH] Query built - No vector search needed!")
# Ghi log debug query FULL vào Background Task (Không làm chậm Request) # Ghi log debug query FULL vào Background Task (Không làm chậm Request)
asyncio.create_task(save_query_to_log(sql)) # asyncio.create_task(save_query_to_log(sql))
return sql return sql
# ============================================================ # ============================================================
...@@ -174,7 +173,7 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None) ...@@ -174,7 +173,7 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
""" """
# Ghi log debug query FULL vào Background Task (Không làm chậm Request) # Ghi log debug query FULL vào Background Task (Không làm chậm Request)
asyncio.create_task(save_query_to_log(sql)) # asyncio.create_task(save_query_to_log(sql))
return sql return sql
...@@ -182,6 +181,7 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None) ...@@ -182,6 +181,7 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
async def save_query_to_log(sql: str): async def save_query_to_log(sql: str):
"""Lưu query full vào file hyde_pure_query.txt.""" """Lưu query full vào file hyde_pure_query.txt."""
import os import os
log_path = r"D:\cnf\chatbot_canifa\backend\logs\hyde_pure_query.txt" log_path = r"D:\cnf\chatbot_canifa\backend\logs\hyde_pure_query.txt"
try: try:
log_dir = os.path.dirname(log_path) log_dir = os.path.dirname(log_path)
...@@ -196,8 +196,8 @@ async def save_query_to_log(sql: str): ...@@ -196,8 +196,8 @@ async def save_query_to_log(sql: str):
async def save_preview_to_log(search_query: str, products: list[dict]): async def save_preview_to_log(search_query: str, products: list[dict]):
"""Lưu kết quả DB trả về vào db_preview.txt (Format đẹp cho AI).""" """Lưu kết quả DB trả về vào db_preview.txt (Format đẹp cho AI)."""
import json
import os import os
preview_path = r"D:\cnf\chatbot_canifa\backend\logs\db_preview.txt" preview_path = r"D:\cnf\chatbot_canifa\backend\logs\db_preview.txt"
try: try:
log_dir = os.path.dirname(preview_path) log_dir = os.path.dirname(preview_path)
...@@ -205,12 +205,12 @@ async def save_preview_to_log(search_query: str, products: list[dict]): ...@@ -205,12 +205,12 @@ async def save_preview_to_log(search_query: str, products: list[dict]):
os.makedirs(log_dir) os.makedirs(log_dir)
with open(preview_path, "a", encoding="utf-8") as f: with open(preview_path, "a", encoding="utf-8") as f:
f.write(f"\n{'='*60}\n") f.write(f"\n{'=' * 60}\n")
f.write(f"⏰ TIME: {time.strftime('%Y-%m-%d %H:%M:%S')}\n") f.write(f"⏰ TIME: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"🔍 SEARCH: {search_query}\n") f.write(f"🔍 SEARCH: {search_query}\n")
f.write(f"📊 RESULTS COUNT: {len(products)}\n") f.write(f"📊 RESULTS COUNT: {len(products)}\n")
f.write(f"{'-'*60}\n") f.write(f"{'-' * 60}\n")
if not products: if not products:
f.write("❌ NO PRODUCTS FOUND\n") f.write("❌ NO PRODUCTS FOUND\n")
else: else:
...@@ -221,12 +221,12 @@ async def save_preview_to_log(search_query: str, products: list[dict]): ...@@ -221,12 +221,12 @@ async def save_preview_to_log(search_query: str, products: list[dict]):
disc = p.get("discount_amount", "0") disc = p.get("discount_amount", "0")
score = p.get("max_score", p.get("similarity_score", "N/A")) score = p.get("max_score", p.get("similarity_score", "N/A"))
desc = p.get("description_text_full", "No Description") desc = p.get("description_text_full", "No Description")
f.write(f"{idx}. [{code}] Score: {score}\n") f.write(f"{idx}. [{code}] Score: {score}\n")
f.write(f" 💰 Price: {sale} (Orig: {orig}, Disc: {disc}%)\n") f.write(f" 💰 Price: {sale} (Orig: {orig}, Disc: {disc}%)\n")
f.write(f" 📝 Desc: {desc}\n") f.write(f" 📝 Desc: {desc}\n")
f.write(f"{'='*60}\n") f.write(f"{'=' * 60}\n")
print(f"💾 DB Preview (Results) saved to: {preview_path}") print(f"💾 DB Preview (Results) saved to: {preview_path}")
except Exception as e: except Exception as e:
print(f"Save preview log failed: {e}") print(f"Save preview log failed: {e}")
...@@ -29,6 +29,28 @@ except ImportError: ...@@ -29,6 +29,28 @@ except ImportError:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SessionData:
"""Simple Session object để Agno framework có thể access .metadata và .session_data"""
def __init__(self, session_id: str, metadata: Any = None, session_data: Any = None, manager: Any = None):
self.session_id = session_id
self.metadata = metadata # Agno expects this attribute
self.session_data = session_data # Agno expects this attribute
self._manager = manager # Reference to ConversationManager for async operations
def get_messages(self, *args, **kwargs) -> list[Any]:
"""Agno calls this to get messages from session"""
# This is called synchronously but we have async data
# Return empty list - messages will be loaded via load_history
return []
def upsert_run(self, run: Any = None) -> bool:
"""Agno calls this to save run data"""
# This is a sync method, just acknowledge for now
# Actual message saving happens via save_message/save_session
return True
# Use composition instead of inheritance to avoid implementing all BaseDb methods # Use composition instead of inheritance to avoid implementing all BaseDb methods
class ConversationManager: # Don't inherit BaseDb directly class ConversationManager: # Don't inherit BaseDb directly
""" """
...@@ -60,7 +82,9 @@ class ConversationManager: # Don't inherit BaseDb directly ...@@ -60,7 +82,9 @@ class ConversationManager: # Don't inherit BaseDb directly
) )
try: try:
await self._pool.open() await self._pool.open()
logger.info(f"✅ PostgreSQL connection pool opened: {self.connection_url.split('@')[-1] if '@' in self.connection_url else '***'}") logger.info(
f"✅ PostgreSQL connection pool opened: {self.connection_url.split('@')[-1] if '@' in self.connection_url else '***'}"
)
except Exception as e: except Exception as e:
logger.error(f"❌ Failed to open PostgreSQL pool: {e}") logger.error(f"❌ Failed to open PostgreSQL pool: {e}")
self._pool = None self._pool = None
...@@ -72,7 +96,7 @@ class ConversationManager: # Don't inherit BaseDb directly ...@@ -72,7 +96,7 @@ class ConversationManager: # Don't inherit BaseDb directly
try: try:
logger.info(f"🔌 Initializing PostgreSQL table: {self.table_name}") logger.info(f"🔌 Initializing PostgreSQL table: {self.table_name}")
pool = await self._get_pool() pool = await self._get_pool()
# Use connection với timeout ngắn hơn # Use connection với timeout ngắn hơn
async with pool.connection(timeout=5.0) as conn: # 5s timeout cho connection async with pool.connection(timeout=5.0) as conn: # 5s timeout cho connection
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
...@@ -90,10 +114,10 @@ class ConversationManager: # Don't inherit BaseDb directly ...@@ -90,10 +114,10 @@ class ConversationManager: # Don't inherit BaseDb directly
await cursor.execute( await cursor.execute(
sql.SQL(""" sql.SQL("""
CREATE INDEX IF NOT EXISTS idx_{}_user_timestamp CREATE INDEX IF NOT EXISTS {}
ON {} (user_id, timestamp) ON {} (user_id, timestamp)
""").format( """).format(
sql.Identifier(self.table_name), sql.Identifier(f"idx_{self.table_name}_user_timestamp"),
sql.Identifier(self.table_name), sql.Identifier(self.table_name),
) )
) )
...@@ -101,7 +125,9 @@ class ConversationManager: # Don't inherit BaseDb directly ...@@ -101,7 +125,9 @@ class ConversationManager: # Don't inherit BaseDb directly
logger.info(f"✅ Table {self.table_name} initialized successfully") logger.info(f"✅ Table {self.table_name} initialized successfully")
except Exception as e: except Exception as e:
logger.error(f"❌ Error initializing table: {e}") logger.error(f"❌ Error initializing table: {e}")
logger.error(f" Connection URL: {self.connection_url.split('@')[-1] if '@' in self.connection_url else '***'}") logger.error(
f" Connection URL: {self.connection_url.split('@')[-1] if '@' in self.connection_url else '***'}"
)
raise raise
async def save_conversation_turn(self, user_id: str, human_message: str, ai_message: str): async def save_conversation_turn(self, user_id: str, human_message: str, ai_message: str):
...@@ -182,9 +208,7 @@ class ConversationManager: # Don't inherit BaseDb directly ...@@ -182,9 +208,7 @@ class ConversationManager: # Don't inherit BaseDb directly
async with pool.connection() as conn: async with pool.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
sql.SQL("DELETE FROM {} WHERE user_id = %s").format( sql.SQL("DELETE FROM {} WHERE user_id = %s").format(sql.Identifier(self.table_name)),
sql.Identifier(self.table_name)
),
(user_id,), (user_id,),
) )
await conn.commit() await conn.commit()
...@@ -198,9 +222,7 @@ class ConversationManager: # Don't inherit BaseDb directly ...@@ -198,9 +222,7 @@ class ConversationManager: # Don't inherit BaseDb directly
pool = await self._get_pool() pool = await self._get_pool()
async with pool.connection() as conn, conn.cursor() as cursor: async with pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
sql.SQL("SELECT COUNT(DISTINCT user_id) FROM {}").format( sql.SQL("SELECT COUNT(DISTINCT user_id) FROM {}").format(sql.Identifier(self.table_name))
sql.Identifier(self.table_name)
)
) )
result = await cursor.fetchone() result = await cursor.fetchone()
return result[0] if result else 0 return result[0] if result else 0
...@@ -334,6 +356,50 @@ class ConversationManager: # Don't inherit BaseDb directly ...@@ -334,6 +356,50 @@ class ConversationManager: # Don't inherit BaseDb directly
"""Agno interface: Clear session (alias của clear_history)""" """Agno interface: Clear session (alias của clear_history)"""
return await self.clear_history(session_id) return await self.clear_history(session_id)
def get_session(self, session_id: str, session_type: str = "default"):
"""
Agno interface: Get session data (SYNC method - Agno calls this synchronously).
Returns SessionData object with required attributes for Agno framework.
"""
try:
# Return SessionData object with required attributes: metadata, session_data
session = SessionData(
session_id=session_id,
metadata=None,
session_data={"session_type": session_type, "created_at": datetime.now()},
manager=self,
)
logger.debug(f"📋 [Agno] Get session: {session_id}")
return session
except Exception as e:
logger.error(f"❌ [Agno] Error getting session {session_id}: {e}")
return None
def upsert_session(self, session: Any):
"""
Agno interface: Save/update session (SYNC method - Agno calls this synchronously).
This is a placeholder since actual message saving happens via save_message/save_session.
Args:
session: SessionData object or dict with 'session_id' key
"""
try:
# Handle both SessionData object and dict
if isinstance(session, SessionData):
session_id = session.session_id
else:
session_id = session.get("session_id") if isinstance(session, dict) else None
if not session_id:
logger.error("❌ [Agno] upsert_session: session_id is required")
return False
logger.debug(f"💾 [Agno] Upserted session {session_id}")
return True
except Exception as e:
logger.error(f"❌ [Agno] Error upserting session: {e}", exc_info=True)
return False
# ConversationManager implements BaseDb interface methods # ConversationManager implements BaseDb interface methods
# but doesn't inherit BaseDb to avoid implementing all abstract methods # but doesn't inherit BaseDb to avoid implementing all abstract methods
......
.\.venv\Scripts\activate .\.venv\Scripts\activate
uvicorn server:app --host 0.0.0.0 --port 5000 --reload uvicorn server:app --host 0.0.0.0 --port 5001 --reload
uvicorn server:app --host 0.0.0.0 --port 5000 uvicorn server:app --host 0.0.0.0 --port 5001
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment