Commit c998c05f authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

feat: optimize logging, enable embedding cache, fix rate limiting hanging, and...

feat: optimize logging, enable embedding cache, fix rate limiting hanging, and improve pagination security
parent c2ff6158
...@@ -127,9 +127,10 @@ Header: `Authorization: Bearer <token>` ...@@ -127,9 +127,10 @@ Header: `Authorization: Bearer <token>`
--- ---
## 3. Reset (Xóa và tạo mới) ## 3. Reset (Xóa và tạo mới)
**POST** `/api/history/archive` **POST** `/api/history/archive`
*(Lưu ý: Chỉ dành cho User đã đăng nhập)*
### Request ### Request
Gửi Header `device_id` (Guest) hoặc `Authorization` (User). Header `Authorization` (User).
Body rỗng `{}`. Body rỗng `{}`.
### Response ### Response
......
...@@ -76,8 +76,8 @@ async def chat_controller( ...@@ -76,8 +76,8 @@ async def chat_controller(
config.model_name = model_name config.model_name = model_name
llm = create_llm(model_name=model_name, streaming=False, json_mode=True) llm = create_llm(model_name=model_name, streaming=False, json_mode=True)
tools = get_all_tools() # tools = get_all_tools() # Singleton now handles tools
graph = build_graph(config, llm=llm, tools=tools) graph = build_graph(config) # Singleton usage
# Init ConversationManager (Singleton) # Init ConversationManager (Singleton)
memory = await get_conversation_manager() memory = await get_conversation_manager()
......
...@@ -129,18 +129,28 @@ class CANIFAGraph: ...@@ -129,18 +129,28 @@ class CANIFAGraph:
_instance: list[CANIFAGraph | None] = [None] _instance: list[CANIFAGraph | None] = [None]
def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None) -> Any: def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None) -> Any:
"""Get compiled graph (always fresh to pick up prompt changes).""" """Get compiled graph (Singleton usage)."""
# Always create new instance to pick up prompt changes during hot reload # Use singleton to avoid rebuilding graph on every request
instance = CANIFAGraph(config, llm, tools) manager = get_graph_manager(config, llm, tools)
return instance.build() return manager.build()
def get_graph_manager( def get_graph_manager(
config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None
) -> CANIFAGraph: ) -> CANIFAGraph:
"""Get CANIFAGraph instance.""" """Get CANIFAGraph instance (Auto-rebuild if model config changes)."""
# 1. New Instance if Empty
if _instance[0] is None: if _instance[0] is None:
_instance[0] = CANIFAGraph(config, llm, tools) _instance[0] = CANIFAGraph(config, llm, tools)
logger.info(f"✨ Graph Created: {_instance[0].config.model_name}")
return _instance[0]
# 2. Check for Config Changes (e.g. Model Switch)
if config and config.model_name != _instance[0].config.model_name:
logger.info(f"🔄 Model Switch: {_instance[0].config.model_name} -> {config.model_name}")
_instance[0] = CANIFAGraph(config, llm, tools)
return _instance[0]
return _instance[0] return _instance[0]
......
...@@ -177,13 +177,16 @@ async def _execute_single_search(db, item: SearchItem, query_vector: list[float] ...@@ -177,13 +177,16 @@ async def _execute_single_search(db, item: SearchItem, query_vector: list[float]
# Timer: build query (sử dụng vector đã có hoặc build mới) # Timer: build query (sử dụng vector đã có hoặc build mới)
query_build_start = time.time() query_build_start = time.time()
sql = await build_starrocks_query(item, query_vector=query_vector) sql, params = await build_starrocks_query(item, query_vector=query_vector)
query_build_time = (time.time() - query_build_start) * 1000 # Convert to ms query_build_time = (time.time() - query_build_start) * 1000 # Convert to ms
logger.debug("SQL built, length=%s, build_time_ms=%.2f", len(sql), query_build_time) logger.debug("SQL built, length=%s, build_time_ms=%.2f", len(sql), query_build_time)
if not sql:
return []
# Timer: execute DB query # Timer: execute DB query
db_start = time.time() db_start = time.time()
products = await db.execute_query_async(sql) products = await db.execute_query_async(sql, params=params)
db_time = (time.time() - db_start) * 1000 # Convert to ms db_time = (time.time() - db_start) * 1000 # Convert to ms
logger.info( logger.info(
"_execute_single_search done, products=%s, build_ms=%.2f, db_ms=%.2f, total_ms=%.2f", "_execute_single_search done, products=%s, build_ms=%.2f, db_ms=%.2f, total_ms=%.2f",
...@@ -223,6 +226,7 @@ def _format_product_results(products: list[dict]) -> list[dict]: ...@@ -223,6 +226,7 @@ def _format_product_results(products: list[dict]) -> list[dict]:
"name": parsed.get("product_name", ""), "name": parsed.get("product_name", ""),
"price": p.get("original_price") or 0, "price": p.get("original_price") or 0,
"sale_price": p.get("sale_price") or 0, "sale_price": p.get("sale_price") or 0,
"description": p.get("description_text_full", ""),
"url": parsed.get("product_web_url", ""), "url": parsed.get("product_web_url", ""),
"thumbnail_image_url": parsed.get("product_image_url_thumbnail", ""), "thumbnail_image_url": parsed.get("product_image_url_thumbnail", ""),
"discount_amount": p.get("discount_amount") or 0, "discount_amount": p.get("discount_amount") or 0,
......
...@@ -6,48 +6,38 @@ from common.embedding_service import create_embedding_async ...@@ -6,48 +6,38 @@ from common.embedding_service import create_embedding_async
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _escape(val: str) -> str:
"""Thoát dấu nháy đơn để tránh SQL Injection cơ bản."""
return val.replace("'", "''")
def _get_price_clauses(params, sql_params: list) -> list[str]:
def _get_where_clauses(params) -> list[str]: """Lọc theo giá (Parameterized)."""
"""Xây dựng danh sách các điều kiện lọc từ params."""
clauses = []
clauses.extend(_get_price_clauses(params))
clauses.extend(_get_metadata_clauses(params))
clauses.extend(_get_special_clauses(params))
return clauses
def _get_price_clauses(params) -> list[str]:
"""Lọc theo giá."""
clauses = [] clauses = []
p_min = getattr(params, "price_min", None) p_min = getattr(params, "price_min", None)
if p_min is not None: if p_min is not None:
clauses.append(f"sale_price >= {p_min}") clauses.append("sale_price >= %s")
sql_params.append(p_min)
p_max = getattr(params, "price_max", None) p_max = getattr(params, "price_max", None)
if p_max is not None: if p_max is not None:
clauses.append(f"sale_price <= {p_max}") clauses.append("sale_price <= %s")
sql_params.append(p_max)
return clauses return clauses
def _get_metadata_clauses(params) -> list[str]: def _get_metadata_clauses(params, sql_params: list) -> list[str]:
"""Xây dựng điều kiện lọc từ metadata (Phối hợp Exact và Partial).""" """Xây dựng điều kiện lọc từ metadata (Parameterized)."""
clauses = [] clauses = []
# 1. Exact Match (Giới tính, Độ tuổi) - Các trường này cần độ chính xác tuyệt đối # 1. Exact Match
exact_fields = [ exact_fields = [
("gender_by_product", "gender_by_product"), ("gender_by_product", "gender_by_product"),
("age_by_product", "age_by_product"), ("age_by_product", "age_by_product"),
("form_neckline", "form_neckline"),
] ]
for param_name, col_name in exact_fields: for param_name, col_name in exact_fields:
val = getattr(params, param_name, None) val = getattr(params, param_name, None)
if val: if val:
clauses.append(f"{col_name} = '{_escape(val)}'") clauses.append(f"{col_name} = %s")
sql_params.append(val)
# 2. Partial Match (LIKE) - Giúp map text linh hoạt hơn (Chất liệu, Dòng SP, Phong cách...) # 2. Partial Match (LIKE)
# Cái này giúp map: "Yarn" -> "Yarn - Sợi", "Knit" -> "Knit - Dệt Kim"
partial_fields = [ partial_fields = [
("season", "season"), ("season", "season"),
("material_group", "material_group"), ("material_group", "material_group"),
...@@ -60,48 +50,44 @@ def _get_metadata_clauses(params) -> list[str]: ...@@ -60,48 +50,44 @@ def _get_metadata_clauses(params) -> list[str]:
for param_name, col_name in partial_fields: for param_name, col_name in partial_fields:
val = getattr(params, param_name, None) val = getattr(params, param_name, None)
if val: if val:
v = _escape(val).lower() clauses.append(f"LOWER({col_name}) LIKE %s")
# Dùng LOWER + LIKE để cân mọi loại ký tự thừa hoặc hoa/thường sql_params.append(f"%{val.lower()}%")
clauses.append(f"LOWER({col_name}) LIKE '%{v}%'")
return clauses return clauses
def _get_special_clauses(params) -> list[str]: def _get_special_clauses(params, sql_params: list) -> list[str]:
"""Các trường hợp đặc biệt: Mã sản phẩm, Màu sắc.""" """Các trường hợp đặc biệt: Mã sản phẩm, Màu sắc."""
clauses = [] clauses = []
# Mã sản phẩm / SKU # Mã sản phẩm / SKU
m_code = getattr(params, "magento_ref_code", None) m_code = getattr(params, "magento_ref_code", None)
if m_code: if m_code:
m = _escape(m_code) clauses.append("(magento_ref_code = %s OR internal_ref_code = %s)")
clauses.append(f"(magento_ref_code = '{m}' OR internal_ref_code = '{m}')") sql_params.extend([m_code, m_code])
# Màu sắc # Màu sắc
color = getattr(params, "master_color", None) color = getattr(params, "master_color", None)
if color: if color:
c = _escape(color).lower() c_wildcard = f"%{color.lower()}%"
clauses.append(f"(LOWER(master_color) LIKE '%{c}%' OR LOWER(product_color_name) LIKE '%{c}%')") clauses.append("(LOWER(master_color) LIKE %s OR LOWER(product_color_name) LIKE %s)")
sql_params.extend([c_wildcard, c_wildcard])
return clauses return clauses
async def build_starrocks_query(params, query_vector: list[float] | None = None) -> str: async def build_starrocks_query(params, query_vector: list[float] | None = None) -> tuple[str, list]:
""" """
Build SQL cho Product Search với 2 chiến lược: Build SQL query với Parameterized Query để tránh SQL Injection.
1. CODE SEARCH: Nếu có magento_ref_code → Tìm trực tiếp theo mã (KHÔNG dùng vector) Returns: (sql_string, params_list)
2. HYDE SEARCH: Semantic search với HyDE vector (Pure vector approach)
""" """
# ============================================================ # ============================================================
# CASE 1: CODE SEARCH - Tìm theo mã sản phẩm (No Vector) # CASE 1: CODE SEARCH
# ============================================================ # ============================================================
magento_code = getattr(params, "magento_ref_code", None) magento_code = getattr(params, "magento_ref_code", None)
if magento_code: if magento_code:
logger.info(f"🎯 [CODE SEARCH] Direct search by code: {magento_code}") logger.info(f"🎯 [CODE SEARCH] Direct search by code: {magento_code}")
code = _escape(magento_code)
# Tìm trực tiếp theo mã + Lọc trùng (GROUP BY internal_ref_code) sql = """
# Tìm chính xác theo mã (Lấy tất cả các bản ghi/màu sắc/size của mã đó)
sql = f"""
SELECT SELECT
internal_ref_code, internal_ref_code,
description_text_full, description_text_full,
...@@ -110,24 +96,16 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None) ...@@ -110,24 +96,16 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
discount_amount, discount_amount,
1.0 as max_score 1.0 as max_score
FROM shared_source.magento_product_dimension_with_text_embedding FROM shared_source.magento_product_dimension_with_text_embedding
WHERE (magento_ref_code = '{code}' OR internal_ref_code = '{code}') WHERE (magento_ref_code = %s OR internal_ref_code = %s)
""" """
return sql, [magento_code, magento_code]
print("✅ [CODE SEARCH] Query built - No vector search needed!")
# Ghi log debug query FULL vào Background Task (Không làm chậm Request)
# asyncio.create_task(save_query_to_log(sql))
return sql
# ============================================================ # ============================================================
# CASE 2: HYDE SEARCH - Semantic Vector Search # CASE 2: HYDE SEARCH - Semantic Vector Search
# ============================================================ # ============================================================
logger.info("🚀 [HYDE RETRIEVER] Starting semantic vector search...") logger.info("🚀 [HYDE RETRIEVER] Starting semantic vector search...")
# 1. Lấy Vector từ HyDE (AI-generated hypothetical document)
query_text = getattr(params, "query", None) query_text = getattr(params, "query", None)
if query_text and query_vector is None: if query_text and query_vector is None:
emb_start = time.time() emb_start = time.time()
query_vector = await create_embedding_async(query_text) query_vector = await create_embedding_async(query_text)
...@@ -135,18 +113,23 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None) ...@@ -135,18 +113,23 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
if not query_vector: if not query_vector:
logger.warning("⚠️ No vector found, returning empty query.") logger.warning("⚠️ No vector found, returning empty query.")
return "" return "", []
# Vector params
v_str = "[" + ",".join(str(v) for v in query_vector) + "]" v_str = "[" + ",".join(str(v) for v in query_vector) + "]"
# 2. Build PRICE filter ONLY (chỉ lọc giá, để vector tự semantic search) # Collect Params
price_clauses = _get_price_clauses(params) price_params: list = []
price_clauses = _get_price_clauses(params, price_params)
where_filter = "" where_filter = ""
if price_clauses: if price_clauses:
where_filter = " AND " + " AND ".join(price_clauses) where_filter = " AND " + " AND ".join(price_clauses)
logger.info(f"💰 [PRICE FILTER] Applied: {where_filter}") logger.info(f"💰 [PRICE FILTER] Applied: {where_filter}")
# 3. SQL Pure Vector Search + Price Filter Only # Build SQL
# NOTE: Vector v_str is safe (generated from floats) so f-string is OK here.
# Using %s for vector list string might cause StarRocks to treat it as string literal '[...]' instead of array.
sql = f""" sql = f"""
WITH top_matches AS ( WITH top_matches AS (
SELECT /*+ SET_VAR(ann_params='{{"ef_search":128}}') */ SELECT /*+ SET_VAR(ann_params='{{"ef_search":128}}') */
...@@ -175,4 +158,6 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None) ...@@ -175,4 +158,6 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
LIMIT 20 LIMIT 20
""" """
return sql # Return sql and params (params only contains filter values now, not the vector)
return sql, price_params
...@@ -23,7 +23,10 @@ tracer = trace.get_tracer(__name__) ...@@ -23,7 +23,10 @@ tracer = trace.get_tracer(__name__)
router = APIRouter() router = APIRouter()
from common.rate_limit import rate_limit_service
@router.post("/api/agent/chat", summary="Fashion Q&A Chat (Non-streaming)") @router.post("/api/agent/chat", summary="Fashion Q&A Chat (Non-streaming)")
@rate_limit_service.limiter.limit("50/minute")
async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks: BackgroundTasks): async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks: BackgroundTasks):
""" """
Endpoint chat không stream - trả về response JSON đầy đủ một lần. Endpoint chat không stream - trả về response JSON đầy đủ một lần.
......
...@@ -28,7 +28,10 @@ class ClearHistoryResponse(BaseModel): ...@@ -28,7 +28,10 @@ class ClearHistoryResponse(BaseModel):
message: str message: str
from common.rate_limit import rate_limit_service
@router.get("/api/history/{identity_key}", summary="Get Chat History", response_model=ChatHistoryResponse) @router.get("/api/history/{identity_key}", summary="Get Chat History", response_model=ChatHistoryResponse)
@rate_limit_service.limiter.limit("50/minute")
async def get_chat_history(request: Request, identity_key: str, limit: int | None = 50, before_id: int | None = None): async def get_chat_history(request: Request, identity_key: str, limit: int | None = 50, before_id: int | None = None):
""" """
Lấy lịch sử chat theo identity_key. Lấy lịch sử chat theo identity_key.
...@@ -63,7 +66,8 @@ async def get_chat_history(request: Request, identity_key: str, limit: int | Non ...@@ -63,7 +66,8 @@ async def get_chat_history(request: Request, identity_key: str, limit: int | Non
@router.delete("/api/history/{identity_key}", summary="Clear Chat History", response_model=ClearHistoryResponse) @router.delete("/api/history/{identity_key}", summary="Clear Chat History", response_model=ClearHistoryResponse)
async def clear_chat_history(identity_key: str): @rate_limit_service.limiter.limit("10/minute")
async def clear_chat_history(request: Request, identity_key: str):
""" """
Xóa toàn bộ lịch sử chat theo identity_key. Xóa toàn bộ lịch sử chat theo identity_key.
Logic: Middleware đã parse token -> Nếu user đã login thì dùng user_id, không thì dùng device_id. Logic: Middleware đã parse token -> Nếu user đã login thì dùng user_id, không thì dùng device_id.
...@@ -89,27 +93,27 @@ class ArchiveResponse(BaseModel): ...@@ -89,27 +93,27 @@ class ArchiveResponse(BaseModel):
@router.post("/api/history/archive", summary="Archive Chat History", response_model=ArchiveResponse) @router.post("/api/history/archive", summary="Archive Chat History", response_model=ArchiveResponse)
@rate_limit_service.limiter.limit("5/minute")
async def archive_chat_history(request: Request): async def archive_chat_history(request: Request):
""" """
Lưu trữ lịch sử chat hiện tại (đổi tên key) và reset chat mới. Lưu trữ lịch sử chat hiện tại (đổi tên key) và reset chat mới.
Giới hạn 5 lần/ngày. Giới hạn 5 lần/ngày.
""" """
try: try:
# Tự động resolve identity
# NOTE: Với Reset, ta extract lại device_id từ body nếu có (cho trường hợp Guest bấm reset)
try:
req_json = await request.json()
body_device_id = req_json.get("device_id")
except:
body_device_id = None
identity = get_user_identity(request) identity = get_user_identity(request)
# Nếu chưa login mà có body_device_id -> ưu tiên dùng nó làm key # Chỉ dành cho User đã đăng nhập
if not identity.is_authenticated and body_device_id: if not identity.is_authenticated:
logger.info("Archive: Using device_id from Body for Guest") return JSONResponse(
identity_key = body_device_id status_code=401,
else: content={
"status": "error",
"error_code": "LOGIN_REQUIRED",
"message": "Tính năng chỉ dành cho thành viên đã đăng nhập.",
"require_login": True
}
)
identity_key = identity.history_key identity_key = identity.history_key
# Check reset limit # Check reset limit
......
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel from pydantic import BaseModel
import os import os
import re import re
...@@ -30,8 +30,10 @@ def validate_prompt_braces(content: str) -> tuple[bool, list[str]]: ...@@ -30,8 +30,10 @@ def validate_prompt_braces(content: str) -> tuple[bool, list[str]]:
return len(problematic) == 0, problematic return len(problematic) == 0, problematic
from common.rate_limit import rate_limit_service
@router.get("/api/agent/system-prompt") @router.get("/api/agent/system-prompt")
async def get_system_prompt_content(): async def get_system_prompt_content(request: Request):
"""Get current system prompt content""" """Get current system prompt content"""
try: try:
if os.path.exists(PROMPT_FILE_PATH): if os.path.exists(PROMPT_FILE_PATH):
...@@ -44,11 +46,12 @@ async def get_system_prompt_content(): ...@@ -44,11 +46,12 @@ async def get_system_prompt_content():
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/api/agent/system-prompt") @router.post("/api/agent/system-prompt")
async def update_system_prompt_content(request: PromptUpdateRequest): @rate_limit_service.limiter.limit("10/minute")
async def update_system_prompt_content(request: Request, body: PromptUpdateRequest):
"""Update system prompt content""" """Update system prompt content"""
try: try:
# Validate braces # Validate braces
is_valid, problematic = validate_prompt_braces(request.content) is_valid, problematic = validate_prompt_braces(body.content)
if not is_valid: if not is_valid:
# Return warning but still allow save # Return warning but still allow save
...@@ -62,7 +65,7 @@ async def update_system_prompt_content(request: PromptUpdateRequest): ...@@ -62,7 +65,7 @@ async def update_system_prompt_content(request: PromptUpdateRequest):
# 1. Update file # 1. Update file
with open(PROMPT_FILE_PATH, "w", encoding="utf-8") as f: with open(PROMPT_FILE_PATH, "w", encoding="utf-8") as f:
f.write(request.content) f.write(body.content)
# 2. Reset Graph Singleton to force reload prompt # 2. Reset Graph Singleton to force reload prompt
reset_graph() reset_graph()
......
This diff is collapsed.
...@@ -43,10 +43,7 @@ class EmbeddingClientManager: ...@@ -43,10 +43,7 @@ class EmbeddingClientManager:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# NOTE: from common.cache import redis_cache
# - TẠM THỜI KHÔNG DÙNG REDIS CACHE CHO EMBEDDING để tránh phụ thuộc Redis/aioredis.
# - Nếu cần bật lại cache, import `redis_cache` từ `common.cache`
# và dùng như các đoạn code cũ (get_embedding / set_embedding).
# --- Singleton --- # --- Singleton ---
_manager = EmbeddingClientManager() _manager = EmbeddingClientManager()
...@@ -55,8 +52,11 @@ get_async_embedding_client = _manager.get_async_client ...@@ -55,8 +52,11 @@ get_async_embedding_client = _manager.get_async_client
def create_embedding(text: str) -> list[float]: def create_embedding(text: str) -> list[float]:
"""Sync embedding generation (No cache for sync to avoid overhead)""" """Sync embedding generation with Layer 2 Cache"""
try: try:
# 1. Try Cache (Sync wrapper for get_embedding if needed, but here we just use what we have)
# Note: common.cache is async, so sync create_embedding will still call OpenAI
# unless we add a sync cache method. For now, focus on async.
client = get_embedding_client() client = get_embedding_client()
response = client.embeddings.create(model="text-embedding-3-small", input=text) response = client.embeddings.create(model="text-embedding-3-small", input=text)
return response.data[0].embedding return response.data[0].embedding
...@@ -67,13 +67,24 @@ def create_embedding(text: str) -> list[float]: ...@@ -67,13 +67,24 @@ def create_embedding(text: str) -> list[float]:
async def create_embedding_async(text: str) -> list[float]: async def create_embedding_async(text: str) -> list[float]:
""" """
Async embedding generation (KHÔNG dùng cache). Async embedding generation with Layer 2 Cache.
Nếu sau này cần cache lại, có thể thêm redis_cache.get_embedding / set_embedding. Saves OpenAI costs by reusing embeddings for identical queries.
""" """
try: try:
# 1. Try Layer 2 Cache
cached = await redis_cache.get_embedding(text)
if cached:
return cached
# 2. Call OpenAI
client = get_async_embedding_client() client = get_async_embedding_client()
response = await client.embeddings.create(model="text-embedding-3-small", input=text) response = await client.embeddings.create(model="text-embedding-3-small", input=text)
embedding = response.data[0].embedding embedding = response.data[0].embedding
# 3. Store in Cache
if embedding:
await redis_cache.set_embedding(text, embedding)
return embedding return embedding
except Exception as e: except Exception as e:
logger.error(f"Error creating embedding (async): {e}") logger.error(f"Error creating embedding (async): {e}")
...@@ -88,12 +99,32 @@ async def create_embeddings_async(texts: list[str]) -> list[list[float]]: ...@@ -88,12 +99,32 @@ async def create_embeddings_async(texts: list[str]) -> list[list[float]]:
if not texts: if not texts:
return [] return []
results = [[] for _ in texts]
missed_indices = []
missed_texts = []
# 1. Check Cache for each text
for i, text in enumerate(texts):
cached = await redis_cache.get_embedding(text)
if cached:
results[i] = cached
else:
missed_indices.append(i)
missed_texts.append(text)
# 2. Call OpenAI for missed texts
if missed_texts:
client = get_async_embedding_client() client = get_async_embedding_client()
response = await client.embeddings.create(model="text-embedding-3-small", input=texts) response = await client.embeddings.create(model="text-embedding-3-small", input=missed_texts)
# OpenAI returns embeddings in the same order as missed_texts
for i, data_item in enumerate(response.data):
idx = missed_indices[i]
embedding = data_item.embedding
results[idx] = embedding
# Giữ nguyên thứ tự embedding theo order input # 3. Cache the new embedding
sorted_data = sorted(response.data, key=lambda x: x.index) await redis_cache.set_embedding(missed_texts[i], embedding)
results = [item.embedding for item in sorted_data]
return results return results
except Exception as e: except Exception as e:
......
"""
StarRocks Database Connection Utility
Based on chatbot-rsa pattern
"""
import asyncio
import logging
from typing import Any
import aiomysql
import pymysql
from pymysql.cursors import DictCursor
from config import (
STARROCKS_DB,
STARROCKS_HOST,
STARROCKS_PASSWORD,
STARROCKS_PORT,
STARROCKS_USER,
)
logger = logging.getLogger(__name__)
__all__ = ["StarRocksConnection", "get_db_connection"]
class StarRocksConnectionManager:
    """Process-wide holder that lazily builds a single StarRocksConnection.

    The underlying StarRocksConnection is created on the first call to
    get_connection() and reused for the lifetime of the process.
    """

    def __init__(self):
        # Populated on first access; None until then.
        self._connection: StarRocksConnection | None = None

    def get_connection(self) -> "StarRocksConnection":
        """Return the shared connection, creating it on first access."""
        if self._connection is not None:
            return self._connection
        logger.info("🔧 [LAZY LOADING] Creating StarRocksConnection instance (first time)")
        self._connection = StarRocksConnection()
        return self._connection
# --- Singleton ---
# One manager per process; `get_db_connection` is the public accessor
# (re-exported via __all__ above) that callers use instead of
# instantiating StarRocksConnection directly.
_manager = StarRocksConnectionManager()
get_db_connection = _manager.get_connection
class StarRocksConnection:
# Shared connection (Singleton-like behavior) for all instances
_shared_conn = None
def __init__(
self,
host: str | None = None,
database: str | None = None,
user: str | None = None,
password: str | None = None,
port: int | None = None,
):
self.host = host or STARROCKS_HOST
self.database = database or STARROCKS_DB
self.user = user or STARROCKS_USER
self.password = password or STARROCKS_PASSWORD
self.port = port or STARROCKS_PORT
# self.conn references the shared connection
self.conn = None
logger.info(f"✅ StarRocksConnection initialized: {self.host}:{self.port}")
def connect(self):
"""
Establish or reuse persistent connection.
"""
# 1. Try to reuse existing shared connection
if StarRocksConnection._shared_conn and StarRocksConnection._shared_conn.open:
try:
# Ping to check if alive, reconnect if needed
StarRocksConnection._shared_conn.ping(reconnect=True)
self.conn = StarRocksConnection._shared_conn
return self.conn
except Exception as e:
logger.warning(f"⚠️ Connection lost, reconnecting: {e}")
StarRocksConnection._shared_conn = None
# 2. Create new connection if needed
print(f" [DB] 🔌 Đang kết nối StarRocks (New Session): {self.host}:{self.port}...")
logger.info(f"🔌 Connecting to StarRocks at {self.host}:{self.port} (DB: {self.database})...")
try:
new_conn = pymysql.connect(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
database=self.database,
charset="utf8mb4",
cursorclass=DictCursor,
connect_timeout=10,
read_timeout=30,
write_timeout=30,
)
print(" [DB] ✅ Kết nối thành công.")
logger.info("✅ Connected to StarRocks")
# Save to class variable
StarRocksConnection._shared_conn = new_conn
self.conn = new_conn
except Exception as e:
print(f" [DB] ❌ Lỗi kết nối: {e!s}")
logger.error(f"❌ Failed to connect to StarRocks: {e}")
raise
return self.conn
def execute_query(self, query: str, params: tuple | None = None) -> list[dict[str, Any]]:
# print(" [DB] 🚀 Bắt đầu truy vấn dữ liệu...")
# (Reduced noise in logs)
logger.info("🚀 Executing StarRocks Query (Persistent Conn).")
conn = self.connect()
try:
with conn.cursor() as cursor:
cursor.execute(query, params)
results = cursor.fetchall()
print(f" [DB] ✅ Truy vấn xong. Lấy được {len(results)} dòng.")
logger.info(f"📊 Query successful, returned {len(results)} rows")
return [dict(row) for row in results]
except Exception as e:
print(f" [DB] ❌ Lỗi truy vấn: {e!s}")
logger.error(f"❌ StarRocks query error: {e}")
# Incase of query error due to connection, invalidate it
StarRocksConnection._shared_conn = None
raise
# FINALLY BLOCK REMOVED: Do NOT close connection
# Async pool shared
_shared_pool = None
_pool_lock = asyncio.Lock()
@classmethod
async def clear_pool(cls):
"""Clear and close existing pool (force recreate fresh connections)"""
async with cls._pool_lock:
if cls._shared_pool is not None:
logger.warning("🔄 Clearing StarRocks connection pool...")
cls._shared_pool.close()
await cls._shared_pool.wait_closed()
cls._shared_pool = None
logger.info("✅ Pool cleared successfully")
async def get_pool(self):
"""
Get or create shared async connection pool (Thread-safe singleton)
"""
if StarRocksConnection._shared_pool is None:
async with StarRocksConnection._pool_lock:
# Double-check inside lock to prevent multiple pools
if StarRocksConnection._shared_pool is None:
logger.info(f"🔌 Creating Async Pool to {self.host}:{self.port}...")
StarRocksConnection._shared_pool = await aiomysql.create_pool(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
db=self.database,
charset="utf8mb4",
cursorclass=aiomysql.DictCursor,
minsize=5, # ← Từ 10 → 5
maxsize=30, # ← Từ 200 → 30 (QUAN TRỌNG!)
connect_timeout=10,
pool_recycle=1800, # ← Từ 3600 → 1800 (30 phút)
autocommit=True,
)
return StarRocksConnection._shared_pool
async def execute_query_async(self, query: str, params: tuple | None = None) -> list[dict[str, Any]]:
"""
Execute query asynchronously using aiomysql pool with Retry Logic.
"""
max_retries = 3
last_error = None
for attempt in range(max_retries):
try:
pool = await self.get_pool()
# logger.info(f"🚀 Executing Async Query (Attempt {attempt+1}).")
# Tăng timeout lên 30s cho load test với 300 users
conn = await asyncio.wait_for(pool.acquire(), timeout=30)
try:
async with conn.cursor() as cursor:
await cursor.execute(query, params)
results = await cursor.fetchall()
# logger.info(f"📊 Async Query successful, returned {len(results)} rows")
return [dict(row) for row in results]
finally:
pool.release(conn)
except TimeoutError as e:
last_error = e
logger.warning(f"⏱️ Pool acquire timeout (Attempt {attempt + 1}/{max_retries})")
# Timeout khi lấy connection → pool đầy, chờ rồi thử lại
await asyncio.sleep(0.2 * (attempt + 1))
continue
except ConnectionAbortedError as e:
last_error = e
logger.warning(f"🔌 Connection aborted (Attempt {attempt + 1}/{max_retries}): {e}")
# Connection bị abort → clear pool và thử lại với fresh connections
if attempt < max_retries - 1:
await StarRocksConnection.clear_pool()
await asyncio.sleep(0.3)
continue
except Exception as e:
last_error = e
logger.warning(f"⚠️ StarRocks DB Error (Attempt {attempt + 1}/{max_retries}): {e}")
if "Memory of process exceed limit" in str(e):
# Nếu StarRocks OOM, đợi một chút rồi thử lại
await asyncio.sleep(0.5 * (attempt + 1))
continue
if "Disconnected" in str(e) or "Lost connection" in str(e) or "aborted" in str(e).lower():
# Nếu mất kết nối, clear pool và thử lại
if attempt < max_retries - 1:
await StarRocksConnection.clear_pool()
await asyncio.sleep(0.3)
continue
# Các lỗi khác (cú pháp,...) thì raise luôn
raise
logger.error(f"❌ Failed after {max_retries} attempts: {last_error}")
raise last_error
def close(self):
"""Explicitly close if needed (e.g. app shutdown)"""
if StarRocksConnection._shared_conn and StarRocksConnection._shared_conn.open:
StarRocksConnection._shared_conn.close()
StarRocksConnection._shared_conn = None
self.conn = None
...@@ -189,10 +189,6 @@ class StarRocksConnection: ...@@ -189,10 +189,6 @@ class StarRocksConnection:
conn = await asyncio.wait_for(pool.acquire(), timeout=90) conn = await asyncio.wait_for(pool.acquire(), timeout=90)
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
# Ping kiểm tra sức khỏe connection
await conn.ping()
# Chạy query
await cursor.execute(query, params) await cursor.execute(query, params)
results = await cursor.fetchall() results = await cursor.fetchall()
return [dict(row) for row in results] return [dict(row) for row in results]
......
...@@ -58,7 +58,7 @@ async def startup_event(): ...@@ -58,7 +58,7 @@ async def startup_event():
middleware_manager.setup( middleware_manager.setup(
app, app,
enable_auth=True, # 👈 Bật lại Auth để test logic Guest/User enable_auth=True, # 👈 Bật lại Auth để test logic Guest/User
enable_rate_limit=False, # 👈 Tắt slowapi vì đã có business rate limit enable_rate_limit=True, # 👈 Bật lại SlowAPI theo yêu cầu
enable_cors=True, # 👈 Bật CORS enable_cors=True, # 👈 Bật CORS
cors_origins=["*"], # 👈 Trong production nên limit origins cors_origins=["*"], # 👈 Trong production nên limit origins
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment