Commit c998c05f authored by Vũ Hoàng Anh

feat: optimize logging, enable embedding cache, fix rate limiting hanging, and improve pagination security
parent c2ff6158
......@@ -116,7 +116,7 @@ Header: `Authorization: Bearer <token>`
"message": "...", // JSON String
"is_human": false,
"timestamp": "..."
}
}
],
"next_cursor": 104 // Dùng ID này cho `before_id` tiếp theo
}
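For illustration, a minimal pagination sketch (assuming the `GET /api/history/{identity_key}` endpoint shown later in this commit, a hypothetical local base URL, and the `requests` library):

```python
import requests

BASE_URL = "http://localhost:8000"  # assumed local dev URL
HEADERS = {"Authorization": "Bearer <token>"}

def fetch_history_page(identity_key: str, before_id: int | None = None, limit: int = 50) -> dict:
    """Fetch one page of chat history; pass the previous next_cursor as before_id."""
    params = {"limit": limit}
    if before_id is not None:
        params["before_id"] = before_id
    resp = requests.get(f"{BASE_URL}/api/history/{identity_key}", headers=HEADERS, params=params)
    resp.raise_for_status()
    return resp.json()

page1 = fetch_history_page("user123")                                  # newest page
page2 = fetch_history_page("user123", before_id=page1["next_cursor"])  # older messages
```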
......@@ -127,9 +127,10 @@ Header: `Authorization: Bearer <token>`
---
## 3. Reset (Delete and create new)
**POST** `/api/history/archive`
*(Note: only for logged-in Users)*
### Request
Send Header `device_id` (Guest) or `Authorization` (User).
Send Header `Authorization` (User).
Empty body `{}`.
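A request sketch (assuming the same hypothetical base URL and `requests` client as in the pagination example above):

```python
import requests

resp = requests.post(
    "http://localhost:8000/api/history/archive",  # assumed local dev URL
    headers={"Authorization": "Bearer <token>"},  # logged-in user token
    json={},                                      # empty body
)
print(resp.status_code, resp.json())
```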
### Response
......
......@@ -76,8 +76,8 @@ async def chat_controller(
config.model_name = model_name
llm = create_llm(model_name=model_name, streaming=False, json_mode=True)
tools = get_all_tools()
graph = build_graph(config, llm=llm, tools=tools)
# tools = get_all_tools() # Singleton now handles tools
graph = build_graph(config) # Singleton usage
# Init ConversationManager (Singleton)
memory = await get_conversation_manager()
......
......@@ -129,18 +129,28 @@ class CANIFAGraph:
_instance: list[CANIFAGraph | None] = [None]
def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None) -> Any:
"""Get compiled graph (always fresh to pick up prompt changes)."""
# Always create new instance to pick up prompt changes during hot reload
instance = CANIFAGraph(config, llm, tools)
return instance.build()
"""Get compiled graph (Singleton usage)."""
# Use singleton to avoid rebuilding graph on every request
manager = get_graph_manager(config, llm, tools)
return manager.build()
def get_graph_manager(
config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None
) -> CANIFAGraph:
"""Get CANIFAGraph instance."""
"""Get CANIFAGraph instance (Auto-rebuild if model config changes)."""
# 1. New Instance if Empty
if _instance[0] is None:
_instance[0] = CANIFAGraph(config, llm, tools)
logger.info(f"✨ Graph Created: {_instance[0].config.model_name}")
return _instance[0]
# 2. Check for Config Changes (e.g. Model Switch)
if config and config.model_name != _instance[0].config.model_name:
logger.info(f"🔄 Model Switch: {_instance[0].config.model_name} -> {config.model_name}")
_instance[0] = CANIFAGraph(config, llm, tools)
return _instance[0]
return _instance[0]
......
......@@ -177,13 +177,16 @@ async def _execute_single_search(db, item: SearchItem, query_vector: list[float]
# Timer: build query (reuse the existing vector or build a new one)
query_build_start = time.time()
sql = await build_starrocks_query(item, query_vector=query_vector)
sql, params = await build_starrocks_query(item, query_vector=query_vector)
query_build_time = (time.time() - query_build_start) * 1000 # Convert to ms
logger.debug("SQL built, length=%s, build_time_ms=%.2f", len(sql), query_build_time)
if not sql:
return []
# Timer: execute DB query
db_start = time.time()
products = await db.execute_query_async(sql)
products = await db.execute_query_async(sql, params=params)
db_time = (time.time() - db_start) * 1000 # Convert to ms
logger.info(
"_execute_single_search done, products=%s, build_ms=%.2f, db_ms=%.2f, total_ms=%.2f",
......@@ -223,6 +226,7 @@ def _format_product_results(products: list[dict]) -> list[dict]:
"name": parsed.get("product_name", ""),
"price": p.get("original_price") or 0,
"sale_price": p.get("sale_price") or 0,
"description": p.get("description_text_full", ""),
"url": parsed.get("product_web_url", ""),
"thumbnail_image_url": parsed.get("product_image_url_thumbnail", ""),
"discount_amount": p.get("discount_amount") or 0,
......
......@@ -6,48 +6,38 @@ from common.embedding_service import create_embedding_async
logger = logging.getLogger(__name__)
def _escape(val: str) -> str:
"""Thoát dấu nháy đơn để tránh SQL Injection cơ bản."""
return val.replace("'", "''")
def _get_where_clauses(params) -> list[str]:
"""Xây dựng danh sách các điều kiện lọc từ params."""
clauses = []
clauses.extend(_get_price_clauses(params))
clauses.extend(_get_metadata_clauses(params))
clauses.extend(_get_special_clauses(params))
return clauses
def _get_price_clauses(params) -> list[str]:
"""Lọc theo giá."""
def _get_price_clauses(params, sql_params: list) -> list[str]:
"""Lọc theo giá (Parameterized)."""
clauses = []
p_min = getattr(params, "price_min", None)
if p_min is not None:
clauses.append(f"sale_price >= {p_min}")
clauses.append("sale_price >= %s")
sql_params.append(p_min)
p_max = getattr(params, "price_max", None)
if p_max is not None:
clauses.append(f"sale_price <= {p_max}")
clauses.append("sale_price <= %s")
sql_params.append(p_max)
return clauses
def _get_metadata_clauses(params) -> list[str]:
"""Xây dựng điều kiện lọc từ metadata (Phối hợp Exact và Partial)."""
def _get_metadata_clauses(params, sql_params: list) -> list[str]:
"""Xây dựng điều kiện lọc từ metadata (Parameterized)."""
clauses = []
# 1. Exact Match (gender, age group) - these fields require an exact value
# 1. Exact Match
exact_fields = [
("gender_by_product", "gender_by_product"),
("age_by_product", "age_by_product"),
("form_neckline", "form_neckline"),
]
for param_name, col_name in exact_fields:
val = getattr(params, param_name, None)
if val:
clauses.append(f"{col_name} = '{_escape(val)}'")
clauses.append(f"{col_name} = %s")
sql_params.append(val)
# 2. Partial Match (LIKE) - allows more flexible text mapping (material, product line, style, ...)
# This helps map e.g. "Yarn" -> "Yarn - Sợi", "Knit" -> "Knit - Dệt Kim"
# 2. Partial Match (LIKE)
partial_fields = [
("season", "season"),
("material_group", "material_group"),
......@@ -60,48 +50,44 @@ def _get_metadata_clauses(params) -> list[str]:
for param_name, col_name in partial_fields:
val = getattr(params, param_name, None)
if val:
v = _escape(val).lower()
# Use LOWER + LIKE to tolerate stray characters and case differences
clauses.append(f"LOWER({col_name}) LIKE '%{v}%'")
clauses.append(f"LOWER({col_name}) LIKE %s")
sql_params.append(f"%{val.lower()}%")
return clauses
def _get_special_clauses(params) -> list[str]:
def _get_special_clauses(params, sql_params: list) -> list[str]:
"""Các trường hợp đặc biệt: Mã sản phẩm, Màu sắc."""
clauses = []
# Mã sản phẩm / SKU
m_code = getattr(params, "magento_ref_code", None)
if m_code:
m = _escape(m_code)
clauses.append(f"(magento_ref_code = '{m}' OR internal_ref_code = '{m}')")
clauses.append("(magento_ref_code = %s OR internal_ref_code = %s)")
sql_params.extend([m_code, m_code])
# Color
color = getattr(params, "master_color", None)
if color:
c = _escape(color).lower()
clauses.append(f"(LOWER(master_color) LIKE '%{c}%' OR LOWER(product_color_name) LIKE '%{c}%')")
c_wildcard = f"%{color.lower()}%"
clauses.append("(LOWER(master_color) LIKE %s OR LOWER(product_color_name) LIKE %s)")
sql_params.extend([c_wildcard, c_wildcard])
return clauses
async def build_starrocks_query(params, query_vector: list[float] | None = None) -> str:
async def build_starrocks_query(params, query_vector: list[float] | None = None) -> tuple[str, list]:
"""
Build SQL for Product Search with 2 strategies:
1. CODE SEARCH: if magento_ref_code is present → look up directly by code (NO vector)
2. HYDE SEARCH: semantic search with a HyDE vector (pure vector approach)
Build the SQL query with parameterized queries to prevent SQL Injection.
Returns: (sql_string, params_list)
"""
# ============================================================
# CASE 1: CODE SEARCH - look up by product code (No Vector)
# CASE 1: CODE SEARCH
# ============================================================
magento_code = getattr(params, "magento_ref_code", None)
if magento_code:
logger.info(f"🎯 [CODE SEARCH] Direct search by code: {magento_code}")
code = _escape(magento_code)
# Look up directly by code + deduplicate (GROUP BY internal_ref_code)
# Exact lookup by code (returns all records/colors/sizes for that code)
sql = f"""
sql = """
SELECT
internal_ref_code,
description_text_full,
......@@ -110,24 +96,16 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
discount_amount,
1.0 as max_score
FROM shared_source.magento_product_dimension_with_text_embedding
WHERE (magento_ref_code = '{code}' OR internal_ref_code = '{code}')
WHERE (magento_ref_code = %s OR internal_ref_code = %s)
"""
print("✅ [CODE SEARCH] Query built - No vector search needed!")
# Log the FULL debug query via a Background Task (does not slow down the Request)
# asyncio.create_task(save_query_to_log(sql))
return sql
return sql, [magento_code, magento_code]
# ============================================================
# CASE 2: HYDE SEARCH - Semantic Vector Search
# ============================================================
logger.info("🚀 [HYDE RETRIEVER] Starting semantic vector search...")
# 1. Get the vector from HyDE (AI-generated hypothetical document)
query_text = getattr(params, "query", None)
if query_text and query_vector is None:
emb_start = time.time()
query_vector = await create_embedding_async(query_text)
......@@ -135,18 +113,23 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
if not query_vector:
logger.warning("⚠️ No vector found, returning empty query.")
return ""
return "", []
# Vector params
v_str = "[" + ",".join(str(v) for v in query_vector) + "]"
# 2. Build PRICE filter ONLY (filter only on price; let the vector handle semantic matching)
price_clauses = _get_price_clauses(params)
# Collect Params
price_params: list = []
price_clauses = _get_price_clauses(params, price_params)
where_filter = ""
if price_clauses:
where_filter = " AND " + " AND ".join(price_clauses)
logger.info(f"💰 [PRICE FILTER] Applied: {where_filter}")
# 3. SQL Pure Vector Search + Price Filter Only
# Build SQL
# NOTE: Vector v_str is safe (generated from floats) so f-string is OK here.
# Using %s for vector list string might cause StarRocks to treat it as string literal '[...]' instead of array.
sql = f"""
WITH top_matches AS (
SELECT /*+ SET_VAR(ann_params='{{"ef_search":128}}') */
......@@ -174,5 +157,7 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
ORDER BY max_score DESC
LIMIT 20
"""
# Return sql and params (params only contains filter values now, not the vector)
return sql, price_params
return sql
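As a side illustration of why returning `(sql, params)` matters (a sketch, not part of the commit; the table name below is hypothetical): with `%s` placeholders the driver binds values separately, so quotes in user input remain data instead of becoming SQL.

```python
# Hypothetical demonstration of f-string interpolation vs. parameter binding.
malicious = "AO2024' OR '1'='1"

# Interpolated: the quote terminates the SQL literal and injects a condition.
unsafe = f"SELECT * FROM products WHERE magento_ref_code = '{malicious}'"
print(unsafe)  # ... WHERE magento_ref_code = 'AO2024' OR '1'='1'

# Parameterized: the SQL text never contains the value; the driver binds it safely.
safe_sql = "SELECT * FROM products WHERE (magento_ref_code = %s OR internal_ref_code = %s)"
safe_params = (malicious, malicious)
# await db.execute_query_async(safe_sql, params=safe_params)  # as in _execute_single_search
```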
......@@ -23,7 +23,10 @@ tracer = trace.get_tracer(__name__)
router = APIRouter()
from common.rate_limit import rate_limit_service
@router.post("/api/agent/chat", summary="Fashion Q&A Chat (Non-streaming)")
@rate_limit_service.limiter.limit("50/minute")
async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks: BackgroundTasks):
"""
Non-streaming chat endpoint - returns the full JSON response in one go.
......
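For context, a minimal sketch of the SlowAPI pattern these decorators follow (assuming `rate_limit_service.limiter` wraps a standard `slowapi.Limiter`; the wrapper itself is project code not shown in this diff). SlowAPI's `limit` decorator requires the handler to accept the `Request` object, which is why `request: Request` is added to each endpoint signature in this commit.

```python
from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

limiter = Limiter(key_func=get_remote_address)  # rate-limit key: client IP
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@app.get("/ping")
@limiter.limit("50/minute")
async def ping(request: Request):  # the Request parameter is required by SlowAPI
    return {"ok": True}
```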
......@@ -28,7 +28,10 @@ class ClearHistoryResponse(BaseModel):
message: str
from common.rate_limit import rate_limit_service
@router.get("/api/history/{identity_key}", summary="Get Chat History", response_model=ChatHistoryResponse)
@rate_limit_service.limiter.limit("50/minute")
async def get_chat_history(request: Request, identity_key: str, limit: int | None = 50, before_id: int | None = None):
"""
Get chat history by identity_key.
......@@ -63,7 +66,8 @@ async def get_chat_history(request: Request, identity_key: str, limit: int | Non
@router.delete("/api/history/{identity_key}", summary="Clear Chat History", response_model=ClearHistoryResponse)
async def clear_chat_history(identity_key: str):
@rate_limit_service.limiter.limit("10/minute")
async def clear_chat_history(request: Request, identity_key: str):
"""
Delete the entire chat history for an identity_key.
Logic: the middleware has already parsed the token -> if the user is logged in, use user_id; otherwise use device_id.
......@@ -89,28 +93,28 @@ class ArchiveResponse(BaseModel):
@router.post("/api/history/archive", summary="Archive Chat History", response_model=ArchiveResponse)
@rate_limit_service.limiter.limit("5/minute")
async def archive_chat_history(request: Request):
"""
Archive the current chat history (rename its key) and start a fresh chat.
Limited to 5 times per day.
"""
try:
# Automatically resolve identity
# NOTE: for Reset, we re-extract device_id from the body if present (for the case where a Guest taps reset)
try:
req_json = await request.json()
body_device_id = req_json.get("device_id")
except:
body_device_id = None
identity = get_user_identity(request)
# If not logged in but body_device_id is present -> prefer it as the key
if not identity.is_authenticated and body_device_id:
logger.info("Archive: Using device_id from Body for Guest")
identity_key = body_device_id
else:
identity_key = identity.history_key
# Only for logged-in Users
if not identity.is_authenticated:
return JSONResponse(
status_code=401,
content={
"status": "error",
"error_code": "LOGIN_REQUIRED",
"message": "Tính năng chỉ dành cho thành viên đã đăng nhập.",
"require_login": True
}
)
identity_key = identity.history_key
# Check reset limit
can_reset, usage, remaining = await reset_limit_service.check_limit(identity_key)
......
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
import os
import re
......@@ -30,8 +30,10 @@ def validate_prompt_braces(content: str) -> tuple[bool, list[str]]:
return len(problematic) == 0, problematic
from common.rate_limit import rate_limit_service
@router.get("/api/agent/system-prompt")
async def get_system_prompt_content():
async def get_system_prompt_content(request: Request):
"""Get current system prompt content"""
try:
if os.path.exists(PROMPT_FILE_PATH):
......@@ -44,11 +46,12 @@ async def get_system_prompt_content():
raise HTTPException(status_code=500, detail=str(e))
@router.post("/api/agent/system-prompt")
async def update_system_prompt_content(request: PromptUpdateRequest):
@rate_limit_service.limiter.limit("10/minute")
async def update_system_prompt_content(request: Request, body: PromptUpdateRequest):
"""Update system prompt content"""
try:
# Validate braces
is_valid, problematic = validate_prompt_braces(request.content)
is_valid, problematic = validate_prompt_braces(body.content)
if not is_valid:
# Return warning but still allow save
......@@ -62,7 +65,7 @@ async def update_system_prompt_content(request: PromptUpdateRequest):
# 1. Update file
with open(PROMPT_FILE_PATH, "w", encoding="utf-8") as f:
f.write(request.content)
f.write(body.content)
# 2. Reset Graph Singleton to force reload prompt
reset_graph()
......
......@@ -5,6 +5,7 @@ from datetime import datetime, date
from typing import Any
import psycopg
from psycopg import sql
from psycopg_pool import AsyncConnectionPool
from config import CHECKPOINT_POSTGRES_URL
......@@ -32,7 +33,7 @@ class ConversationManager:
max_lifetime=600, # Recycle connections every 10 mins
max_idle=300, # Close idle connections after 5 mins
open=False,
kwargs={"autocommit": True}
# kwargs={"autocommit": True} # DISABLE autocommit to support atomic transactions
)
await self._pool.open()
return self._pool
......@@ -43,20 +44,26 @@ class ConversationManager:
pool = await self._get_pool()
async with pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
# Use sql.SQL for safe identifier quoting
create_table_query = sql.SQL("""
CREATE TABLE IF NOT EXISTS {table} (
id SERIAL PRIMARY KEY,
identity_key VARCHAR(255) NOT NULL,
message TEXT NOT NULL,
is_human BOOLEAN NOT NULL,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
""").format(table=sql.Identifier(self.table_name))
await cursor.execute(create_table_query)
await cursor.execute(f"""
CREATE INDEX IF NOT EXISTS idx_{self.table_name}_identity_timestamp
ON {self.table_name} (identity_key, timestamp)
""")
create_index_query = sql.SQL("""
CREATE INDEX IF NOT EXISTS {index_name}
ON {table} (identity_key, timestamp)
""").format(
index_name=sql.Identifier(f"idx_{self.table_name}_identity_timestamp"),
table=sql.Identifier(self.table_name)
)
await cursor.execute(create_index_query)
await conn.commit()
logger.info(f"Table {self.table_name} initialized successfully")
except Exception as e:
......@@ -70,30 +77,21 @@ class ConversationManager:
try:
pool = await self._get_pool()
timestamp = datetime.now()
# Transaction block: atomic insert
async with pool.connection() as conn:
async with conn.cursor() as cursor:
insert_query = sql.SQL("""
INSERT INTO {table} (identity_key, message, is_human, timestamp)
VALUES (%s, %s, %s, %s), (%s, %s, %s, %s)
""").format(table=sql.Identifier(self.table_name))
await cursor.execute(
f"""INSERT INTO {self.table_name} (identity_key, message, is_human, timestamp)
VALUES (%s, %s, %s, %s), (%s, %s, %s, %s)""",
insert_query,
(
identity_key,
human_message,
True,
timestamp,
identity_key,
ai_message,
False,
timestamp,
identity_key, human_message, True, timestamp,
identity_key, ai_message, False, timestamp,
),
)
# With autocommit=True in pool, and context manager, transactions are handled.
# Explicit commit can be safer but might be redundant if autocommit is on.
# Let's keep existing logic but be mindful of autocommit.
# Actually if autocommit=True, we don't need conn.commit().
# But if we want atomic transaction for 2 inserts, we should NOT use autocommit=True for the pool globally,
# OR we start a transaction block.
# But psycopg3 connection `async with pool.connection() as conn` actually starts a transaction by default if autocommit is False.
# Let's revert pool autocommit=True and handle it normally which is safer for atomicity.
await conn.commit()
logger.debug(f"Saved conversation turn for identity_key {identity_key}")
......@@ -116,33 +114,43 @@ class ConversationManager:
"""
Retrieve chat history for an identity (user_id or device_id) using cursor-based pagination.
AI messages are parsed from their JSON string to extract product_ids.
Uses cached graph for performance.
"""
max_retries = 3
for attempt in range(max_retries):
try:
today = datetime.now().date()
query = f"""
# Optimize: Use Range Query for Index usage
now = datetime.now().astimezone() # Ensure Timezone Aware (e.g. +07:00)
start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
end_of_day = now.replace(hour=23, minute=59, second=59, microsecond=999999)
base_query = sql.SQL("""
SELECT message, is_human, timestamp, id
FROM {self.table_name}
FROM {table}
WHERE identity_key = %s
AND DATE(timestamp) = %s
"""
params = [identity_key, today]
AND timestamp >= %s AND timestamp <= %s
""").format(table=sql.Identifier(self.table_name))
params = [identity_key, start_of_day, end_of_day]
query_parts = [base_query]
if before_id:
query += " AND id < %s"
query_parts.append(sql.SQL("AND id < %s"))
params.append(before_id)
query += " ORDER BY id DESC"
query_parts.append(sql.SQL("ORDER BY id DESC"))
if limit:
query += " LIMIT %s"
query_parts.append(sql.SQL("LIMIT %s"))
params.append(limit)
final_query = sql.SQL(" ").join(query_parts)
pool = await self._get_pool()
async with pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(query, tuple(params))
results = await cursor.fetchall()
async with pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(final_query, tuple(params))
results = await cursor.fetchall()
history = []
for row in results:
......@@ -195,19 +203,24 @@ class ConversationManager:
# Format: user123_archived_20231027_103045
new_key = f"{identity_key}_archived_{timestamp_suffix}"
today = datetime.now().date()
# Optimize: Use Range Query
now = datetime.now().astimezone() # Ensure Timezone Aware (e.g. +07:00)
start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
end_of_day = now.replace(hour=23, minute=59, second=59, microsecond=999999)
pool = await self._get_pool()
async with pool.connection() as conn:
async with conn.cursor() as cursor:
# Rename identity_key for today's messages
await cursor.execute(
f"""
UPDATE {self.table_name}
query = sql.SQL("""
UPDATE {table}
SET identity_key = %s
WHERE identity_key = %s
AND DATE(timestamp) = %s
""",
(new_key, identity_key, today)
AND timestamp >= %s AND timestamp <= %s
""").format(table=sql.Identifier(self.table_name))
await cursor.execute(
query,
(new_key, identity_key, start_of_day, end_of_day)
)
await conn.commit()
......@@ -224,7 +237,10 @@ class ConversationManager:
pool = await self._get_pool()
async with pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"DELETE FROM {self.table_name} WHERE identity_key = %s", (identity_key,))
query = sql.SQL("DELETE FROM {table} WHERE identity_key = %s").format(
table=sql.Identifier(self.table_name)
)
await cursor.execute(query, (identity_key,))
await conn.commit()
logger.info(f"Cleared chat history for identity_key {identity_key}")
except Exception as e:
......@@ -235,7 +251,10 @@ class ConversationManager:
try:
pool = await self._get_pool()
async with pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(f"SELECT COUNT(DISTINCT identity_key) FROM {self.table_name}")
query = sql.SQL("SELECT COUNT(DISTINCT identity_key) FROM {table}").format(
table=sql.Identifier(self.table_name)
)
await cursor.execute(query)
result = await cursor.fetchone()
return result[0] if result else 0
except Exception as e:
......@@ -248,16 +267,23 @@ class ConversationManager:
Only count human messages (is_human = true).
"""
try:
# Optimize: Use Range Query
now = datetime.now()
start_of_day = datetime(now.year, now.month, now.day, 0, 0, 0)
end_of_day = datetime(now.year, now.month, now.day, 23, 59, 59, 999999)
pool = await self._get_pool()
async with pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(
f"""
SELECT COUNT(*) FROM {self.table_name}
query = sql.SQL("""
SELECT COUNT(*) FROM {table}
WHERE identity_key = %s
AND is_human = true
AND DATE(timestamp) = CURRENT_DATE
""",
(identity_key,),
AND timestamp >= %s AND timestamp <= %s
""").format(table=sql.Identifier(self.table_name))
await cursor.execute(
query,
(identity_key, start_of_day, end_of_day),
)
result = await cursor.fetchone()
return result[0] if result else 0
......
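A side note on the range-query change above (an illustration, not from the commit): wrapping the column in `DATE(timestamp)` keeps Postgres from using the `(identity_key, timestamp)` index created in `_init_table` for the date filter, while a plain range comparison on the raw column is sargable. A sketch of the two forms:

```python
from datetime import datetime

now = datetime.now().astimezone()
start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
end_of_day = now.replace(hour=23, minute=59, second=59, microsecond=999999)

# Before: the function call on the column defeats the composite index for the date filter.
old_where = "WHERE identity_key = %s AND DATE(timestamp) = %s"
old_params = ("user123", now.date())  # "user123" is a placeholder key

# After: a plain range on the raw column can be served by the index directly.
new_where = "WHERE identity_key = %s AND timestamp >= %s AND timestamp <= %s"
new_params = ("user123", start_of_day, end_of_day)
```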
......@@ -43,10 +43,7 @@ class EmbeddingClientManager:
logger = logging.getLogger(__name__)
# NOTE:
# - TEMPORARILY NOT USING THE REDIS CACHE FOR EMBEDDINGS, to avoid a Redis/aioredis dependency.
# - To re-enable the cache, import `redis_cache` from `common.cache`
# and use it as in the old code (get_embedding / set_embedding).
from common.cache import redis_cache
# --- Singleton ---
_manager = EmbeddingClientManager()
......@@ -55,8 +52,11 @@ get_async_embedding_client = _manager.get_async_client
def create_embedding(text: str) -> list[float]:
"""Sync embedding generation (No cache for sync to avoid overhead)"""
"""Sync embedding generation with Layer 2 Cache"""
try:
# 1. Try Cache (Sync wrapper for get_embedding if needed, but here we just use what we have)
# Note: common.cache is async, so sync create_embedding will still call OpenAI
# unless we add a sync cache method. For now, focus on async.
client = get_embedding_client()
response = client.embeddings.create(model="text-embedding-3-small", input=text)
return response.data[0].embedding
......@@ -67,13 +67,24 @@ def create_embedding(text: str) -> list[float]:
async def create_embedding_async(text: str) -> list[float]:
"""
Async embedding generation (does NOT use a cache).
If caching is needed later, redis_cache.get_embedding / set_embedding can be added.
Async embedding generation with Layer 2 Cache.
Saves OpenAI costs by reusing embeddings for identical queries.
"""
try:
# 1. Try Layer 2 Cache
cached = await redis_cache.get_embedding(text)
if cached:
return cached
# 2. Call OpenAI
client = get_async_embedding_client()
response = await client.embeddings.create(model="text-embedding-3-small", input=text)
embedding = response.data[0].embedding
# 3. Store in Cache
if embedding:
await redis_cache.set_embedding(text, embedding)
return embedding
except Exception as e:
logger.error(f"Error creating embedding (async): {e}")
......@@ -88,12 +99,32 @@ async def create_embeddings_async(texts: list[str]) -> list[list[float]]:
if not texts:
return []
client = get_async_embedding_client()
response = await client.embeddings.create(model="text-embedding-3-small", input=texts)
# Preserve the embedding order to match the input order
sorted_data = sorted(response.data, key=lambda x: x.index)
results = [item.embedding for item in sorted_data]
results = [[] for _ in texts]
missed_indices = []
missed_texts = []
# 1. Check Cache for each text
for i, text in enumerate(texts):
cached = await redis_cache.get_embedding(text)
if cached:
results[i] = cached
else:
missed_indices.append(i)
missed_texts.append(text)
# 2. Call OpenAI for missed texts
if missed_texts:
client = get_async_embedding_client()
response = await client.embeddings.create(model="text-embedding-3-small", input=missed_texts)
# OpenAI returns embeddings in the same order as missed_texts
for i, data_item in enumerate(response.data):
idx = missed_indices[i]
embedding = data_item.embedding
results[idx] = embedding
# 3. Cache the new embedding
await redis_cache.set_embedding(missed_texts[i], embedding)
return results
except Exception as e:
......
"""
StarRocks Database Connection Utility
Based on chatbot-rsa pattern
"""
import asyncio
import logging
from typing import Any
import aiomysql
import pymysql
from pymysql.cursors import DictCursor
from config import (
STARROCKS_DB,
STARROCKS_HOST,
STARROCKS_PASSWORD,
STARROCKS_PORT,
STARROCKS_USER,
)
logger = logging.getLogger(__name__)
__all__ = ["StarRocksConnection", "get_db_connection"]
class StarRocksConnectionManager:
"""
Singleton class that manages the StarRocks connection.
"""
def __init__(self):
self._connection: StarRocksConnection | None = None
def get_connection(self) -> "StarRocksConnection":
"""Lazy loading connection"""
if self._connection is None:
logger.info("🔧 [LAZY LOADING] Creating StarRocksConnection instance (first time)")
self._connection = StarRocksConnection()
return self._connection
# --- Singleton ---
_manager = StarRocksConnectionManager()
get_db_connection = _manager.get_connection
class StarRocksConnection:
# Shared connection (Singleton-like behavior) for all instances
_shared_conn = None
def __init__(
self,
host: str | None = None,
database: str | None = None,
user: str | None = None,
password: str | None = None,
port: int | None = None,
):
self.host = host or STARROCKS_HOST
self.database = database or STARROCKS_DB
self.user = user or STARROCKS_USER
self.password = password or STARROCKS_PASSWORD
self.port = port or STARROCKS_PORT
# self.conn references the shared connection
self.conn = None
logger.info(f"✅ StarRocksConnection initialized: {self.host}:{self.port}")
def connect(self):
"""
Establish or reuse persistent connection.
"""
# 1. Try to reuse existing shared connection
if StarRocksConnection._shared_conn and StarRocksConnection._shared_conn.open:
try:
# Ping to check if alive, reconnect if needed
StarRocksConnection._shared_conn.ping(reconnect=True)
self.conn = StarRocksConnection._shared_conn
return self.conn
except Exception as e:
logger.warning(f"⚠️ Connection lost, reconnecting: {e}")
StarRocksConnection._shared_conn = None
# 2. Create new connection if needed
print(f" [DB] 🔌 Đang kết nối StarRocks (New Session): {self.host}:{self.port}...")
logger.info(f"🔌 Connecting to StarRocks at {self.host}:{self.port} (DB: {self.database})...")
try:
new_conn = pymysql.connect(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
database=self.database,
charset="utf8mb4",
cursorclass=DictCursor,
connect_timeout=10,
read_timeout=30,
write_timeout=30,
)
print(" [DB] ✅ Kết nối thành công.")
logger.info("✅ Connected to StarRocks")
# Save to class variable
StarRocksConnection._shared_conn = new_conn
self.conn = new_conn
except Exception as e:
print(f" [DB] ❌ Lỗi kết nối: {e!s}")
logger.error(f"❌ Failed to connect to StarRocks: {e}")
raise
return self.conn
def execute_query(self, query: str, params: tuple | None = None) -> list[dict[str, Any]]:
# print(" [DB] 🚀 Bắt đầu truy vấn dữ liệu...")
# (Reduced noise in logs)
logger.info("🚀 Executing StarRocks Query (Persistent Conn).")
conn = self.connect()
try:
with conn.cursor() as cursor:
cursor.execute(query, params)
results = cursor.fetchall()
print(f" [DB] ✅ Truy vấn xong. Lấy được {len(results)} dòng.")
logger.info(f"📊 Query successful, returned {len(results)} rows")
return [dict(row) for row in results]
except Exception as e:
print(f" [DB] ❌ Lỗi truy vấn: {e!s}")
logger.error(f"❌ StarRocks query error: {e}")
# In case of a query error due to the connection, invalidate it
StarRocksConnection._shared_conn = None
raise
# FINALLY BLOCK REMOVED: Do NOT close connection
# Async pool shared
_shared_pool = None
_pool_lock = asyncio.Lock()
@classmethod
async def clear_pool(cls):
"""Clear and close existing pool (force recreate fresh connections)"""
async with cls._pool_lock:
if cls._shared_pool is not None:
logger.warning("🔄 Clearing StarRocks connection pool...")
cls._shared_pool.close()
await cls._shared_pool.wait_closed()
cls._shared_pool = None
logger.info("✅ Pool cleared successfully")
async def get_pool(self):
"""
Get or create shared async connection pool (Thread-safe singleton)
"""
if StarRocksConnection._shared_pool is None:
async with StarRocksConnection._pool_lock:
# Double-check inside lock to prevent multiple pools
if StarRocksConnection._shared_pool is None:
logger.info(f"🔌 Creating Async Pool to {self.host}:{self.port}...")
StarRocksConnection._shared_pool = await aiomysql.create_pool(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
db=self.database,
charset="utf8mb4",
cursorclass=aiomysql.DictCursor,
minsize=5, # ← from 10 → 5
maxsize=30, # ← from 200 → 30 (IMPORTANT!)
connect_timeout=10,
pool_recycle=1800, # ← from 3600 → 1800 (30 minutes)
autocommit=True,
)
return StarRocksConnection._shared_pool
async def execute_query_async(self, query: str, params: tuple | None = None) -> list[dict[str, Any]]:
"""
Execute query asynchronously using aiomysql pool with Retry Logic.
"""
max_retries = 3
last_error = None
for attempt in range(max_retries):
try:
pool = await self.get_pool()
# logger.info(f"🚀 Executing Async Query (Attempt {attempt+1}).")
# Increase the timeout to 30s for load tests with 300 users
conn = await asyncio.wait_for(pool.acquire(), timeout=30)
try:
async with conn.cursor() as cursor:
await cursor.execute(query, params)
results = await cursor.fetchall()
# logger.info(f"📊 Async Query successful, returned {len(results)} rows")
return [dict(row) for row in results]
finally:
pool.release(conn)
except TimeoutError as e:
last_error = e
logger.warning(f"⏱️ Pool acquire timeout (Attempt {attempt + 1}/{max_retries})")
# Timeout while acquiring a connection → pool is full, wait and retry
await asyncio.sleep(0.2 * (attempt + 1))
continue
except ConnectionAbortedError as e:
last_error = e
logger.warning(f"🔌 Connection aborted (Attempt {attempt + 1}/{max_retries}): {e}")
# Connection was aborted → clear the pool and retry with fresh connections
if attempt < max_retries - 1:
await StarRocksConnection.clear_pool()
await asyncio.sleep(0.3)
continue
except Exception as e:
last_error = e
logger.warning(f"⚠️ StarRocks DB Error (Attempt {attempt + 1}/{max_retries}): {e}")
if "Memory of process exceed limit" in str(e):
# If StarRocks hits OOM, wait a moment and retry
await asyncio.sleep(0.5 * (attempt + 1))
continue
if "Disconnected" in str(e) or "Lost connection" in str(e) or "aborted" in str(e).lower():
# If the connection was lost, clear the pool and retry
if attempt < max_retries - 1:
await StarRocksConnection.clear_pool()
await asyncio.sleep(0.3)
continue
# Other errors (syntax, ...) are raised immediately
raise
logger.error(f"❌ Failed after {max_retries} attempts: {last_error}")
raise last_error
def close(self):
"""Explicitly close if needed (e.g. app shutdown)"""
if StarRocksConnection._shared_conn and StarRocksConnection._shared_conn.open:
StarRocksConnection._shared_conn.close()
StarRocksConnection._shared_conn = None
self.conn = None
......@@ -189,10 +189,6 @@ class StarRocksConnection:
conn = await asyncio.wait_for(pool.acquire(), timeout=90)
async with conn.cursor() as cursor:
# Ping to check connection health
await conn.ping()
# Run the query
await cursor.execute(query, params)
results = await cursor.fetchall()
return [dict(row) for row in results]
......
......@@ -58,7 +58,7 @@ async def startup_event():
middleware_manager.setup(
app,
enable_auth=True, # 👈 Re-enable Auth to test the Guest/User logic
enable_rate_limit=False, # 👈 Disable slowapi since business rate limiting already exists
enable_rate_limit=True, # 👈 Re-enable SlowAPI as requested
enable_cors=True, # 👈 Enable CORS
cors_origins=["*"], # 👈 Origins should be restricted in production
)
......