Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
C
chatbot canifa
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
1
Merge Requests
1
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Vũ Hoàng Anh
chatbot canifa
Commits
49f43a45
Commit
49f43a45
authored
Jan 27, 2026
by
Vũ Hoàng Anh
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix mock API routing and retriever alias
parent
566ee233
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
899 additions
and
899 deletions
+899
-899
mock_api_route.py
backend/api/mock_api_route.py
+201
-202
middleware.py
backend/common/middleware.py
+335
-334
rate_limit.py
backend/common/rate_limit.py
+238
-238
server.py
backend/server.py
+125
-125
No files found.
backend/api/mock_api_route.py
View file @
49f43a45
import
asyncio
import
asyncio
import
json
import
json
import
logging
import
logging
import
time
import
time
from
fastapi
import
APIRouter
,
BackgroundTasks
,
HTTPException
from
fastapi
import
APIRouter
,
BackgroundTasks
,
HTTPException
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
agent.tools.data_retrieval_tool
import
SearchItem
,
data_retrieval_tool
from
agent.tools.data_retrieval_tool
import
SearchItem
,
data_retrieval_tool
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router on which every mock/testing endpoint in this module is registered.
router = APIRouter()
# --- HELPERS ---
# --- HELPERS ---
async def retry_with_backoff(coro_fn, max_retries=3, backoff_factor=2):
    """Await ``coro_fn()``, retrying with exponential backoff when it raises.

    Args:
        coro_fn: Zero-argument callable that returns a fresh awaitable on each
            call (it is invoked again for every retry).
        max_retries: Total number of attempts, including the first one.
            Must be at least 1.
        backoff_factor: Base of the exponential delay. After failed attempt
            number ``k`` (0-based) the function sleeps
            ``backoff_factor ** k`` seconds before trying again.

    Returns:
        The value produced by the first successful ``await coro_fn()``.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1. Previously the loop
            simply did not run and ``None`` was returned without ever calling
            ``coro_fn``, which callers would mistake for a real result.
        Exception: Whatever the final attempt raised, re-raised unchanged.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries!r}")

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            return await coro_fn()
        except Exception as e:
            # Out of attempts: let the caller see the original exception.
            if attempt == final_attempt:
                raise
            wait_time = backoff_factor**attempt
            logger.warning(f"⚠️ Attempt {attempt + 1} failed: {e!s}, retrying in {wait_time}s...")
            await asyncio.sleep(wait_time)
# --- MODELS ---
# --- MODELS ---
class MockQueryRequest(BaseModel):
    """Request body accepted by the mock agent chat endpoint."""

    # Free-text query typed by the user (required).
    user_query: str
    # Caller identity; falls back to a shared test user when omitted.
    user_id: str | None = "test_user"
    # Optional session identifier supplied by the client.
    session_id: str | None = None
class MockDBRequest(BaseModel):
    """Request body accepted by the mock DB search endpoint."""

    # Free-text search query; optional.
    query: str | None = None
    # Product reference code; optional.
    magento_ref_code: str | None = None
    # Optional lower / upper price bounds.
    price_min: float | None = None
    price_max: float | None = None
    # NOTE(review): no visible code in this module reads top_k - confirm
    # whether it should be forwarded to the search tool.
    top_k: int = 10
class MockRetrieverRequest(BaseModel):
    """Request body accepted by the mock retriever DB endpoint."""

    # Free-text query typed by the user (required).
    user_query: str
    # Optional lower / upper price bounds.
    price_min: float | None = None
    price_max: float | None = None
    # Product reference code; optional.
    magento_ref_code: str | None = None
    # Caller identity; falls back to a shared test user when omitted.
    user_id: str | None = "test_user"
    # Optional session identifier supplied by the client.
    session_id: str | None = None
# --- MOCK LLM RESPONSES (no OpenAI call) ---
# Canned Vietnamese assistant replies. Nothing in the visible part of this
# module reads the list; presumably a mock controller picks from it - confirm.
MOCK_AI_RESPONSES = [
    "Dựa trên tìm kiếm của bạn, tôi tìm thấy các sản phẩm phù hợp với nhu cầu của bạn. Những mặt hàng này có chất lượng tốt và giá cả phải chăng.",
    "Tôi gợi ý cho bạn những sản phẩm sau. Chúng đều là những lựa chọn phổ biến và nhận được đánh giá cao từ khách hàng.",
    "Dựa trên tiêu chí tìm kiếm của bạn, đây là những sản phẩm tốt nhất mà tôi có thể giới thiệu.",
    "Những sản phẩm này hoàn toàn phù hợp với yêu cầu của bạn. Hãy xem chi tiết để chọn sản phẩm yêu thích nhất.",
    "Tôi đã tìm được các mặt hàng tuyệt vời cho bạn. Hãy kiểm tra chúng để tìm ra lựa chọn tốt nhất.",
]
# --- ENDPOINTS ---
# --- ENDPOINTS ---
from
agent.mock_controller
import
mock_chat_controller
from
agent.mock_controller
import
mock_chat_controller
@router.post("/api/mock/agent/chat", summary="Mock Agent Chat (Real Tools + Fake LLM)")
async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks):
    """
    Mock Agent Chat, delegated to ``mock_chat_controller``.

    Per the original author's notes (the controller itself is not visible here):
    - Real embedding + vector search (the real data_retrieval_tool)
    - Real products from StarRocks
    - Fake LLM response (no OpenAI cost)
    - Intended for stress testing + end-to-end testing

    Args:
        req: Parsed request body.
        background_tasks: FastAPI background task queue, handed to the controller.

    Returns:
        A dict with ``status``, ``user_query``, ``user_id`` and ``session_id``,
        merged with the controller's result dict. Controller keys win on
        collision because they are unpacked last.

    Raises:
        HTTPException: Re-raised as-is when the controller raises one;
            otherwise status 500 wrapping the unexpected error.
    """
    try:
        logger.info(f"🚀 [Mock Agent Chat] Starting with query: {req.user_query}")

        result = await mock_chat_controller(
            query=req.user_query,
            user_id=req.user_id or "test_user",
            background_tasks=background_tasks,
        )

        return {
            "status": "success",
            "user_query": req.user_query,
            "user_id": req.user_id,
            "session_id": req.session_id,
            **result,  # Include status, ai_response, product_ids, etc.
        }
    except HTTPException:
        # Keep deliberate HTTP errors (status code + detail) intact instead of
        # flattening them into a generic 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in mock agent chat: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Mock Agent Chat Error: {e!s}") from e
@router.post("/api/mock/db/search", summary="Real Data Retrieval Tool (Agent Tool)")
async def mock_db_search(req: MockDBRequest):
    """
    Run the agent's real ``data_retrieval_tool`` for a single search.

    Per the original author's notes (the tool's internals are not visible here):
    - With magento_ref_code -> code search (no embedding needed)
    - With query -> HyDE semantic search (embedding + vector search)
    - Filters by price when price_min / price_max are given
    - Returns real products from StarRocks

    The input mirrors the agent tool's ``SearchItem``. When ``req.query`` is
    empty the literal query "sản phẩm" is used instead.
    NOTE(review): ``req.top_k`` is accepted by the model but not forwarded.

    Returns:
        A dict with ``status``, ``search_params``, ``total_results``,
        ``products``, ``processing_time_ms`` and ``raw_result``.

    Raises:
        HTTPException: Status 500 wrapping any unexpected error.
    """
    try:
        logger.info("📍 Data Retrieval Tool called")
        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        start_time = time.perf_counter()

        # Build the SearchItem from the request.
        search_item = SearchItem(
            query=req.query or "sản phẩm",
            magento_ref_code=req.magento_ref_code,
            price_min=req.price_min,
            price_max=req.price_max,
            action="search",
        )
        search_params = search_item.dict(exclude_none=True)
        logger.info(f"🔧 Search params: {search_params}")

        # Call the real data_retrieval_tool, with retry.
        result_json = await retry_with_backoff(
            lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}),
            max_retries=3,
        )
        result = json.loads(result_json)

        elapsed_time = time.perf_counter() - start_time
        logger.info(f"✅ Data Retrieval completed in {elapsed_time:.3f}s")

        # "results" may be missing or an empty list; both mean "no products"
        # rather than an IndexError surfacing as a 500.
        first_result = (result.get("results") or [{}])[0]
        products = first_result.get("products", [])

        return {
            "status": result.get("status", "success"),
            "search_params": search_params,
            "total_results": len(products),
            "products": products,
            "processing_time_ms": round(elapsed_time * 1000, 2),
            "raw_result": result,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Error in DB search: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"DB Search Error: {e!s}") from e
@router.post("/api/mock/retrieverdb", summary="Real Embedding + Real DB Vector Search")
@router.post("/api/mock/retriverdb", summary="Real Embedding + Real DB Vector Search (Legacy)")
async def mock_retriever_db(req: MockRetrieverRequest):
    """
    Test endpoint for Retriever + DB search (uses the agent tool).

    Registered under two paths: the correctly spelled ``retrieverdb`` and the
    legacy misspelling ``retriverdb``.

    Per the original author's notes (the tool's internals are not visible here):
    - Takes the user's query
    - Real embedding (the tool calls OpenAI embeddings)
    - Real vector search in StarRocks
    - Returns real product results (the LLM step is skipped)

    Meant for measuring embedding + vector search performance in isolation.

    Returns:
        A dict with ``status``, ``user_query``, ``user_id``, ``session_id``,
        ``search_params``, ``total_results``, ``products`` and
        ``processing_time_ms``.

    Raises:
        HTTPException: Status 500 wrapping any unexpected error.
    """
    try:
        logger.info(f"📍 Retriever DB started: {req.user_query}")
        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        start_time = time.perf_counter()

        # Build the SearchItem from the request.
        search_item = SearchItem(
            query=req.user_query,
            magento_ref_code=req.magento_ref_code,
            price_min=req.price_min,
            price_max=req.price_max,
            action="search",
        )
        search_params = search_item.dict(exclude_none=True)
        logger.info(f"🔧 Retriever params: {search_params}")

        # Call the real data_retrieval_tool (embedding + vector search), with retry.
        result_json = await retry_with_backoff(
            lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}),
            max_retries=3,
        )
        result = json.loads(result_json)

        elapsed_time = time.perf_counter() - start_time
        logger.info(f"✅ Retriever completed in {elapsed_time:.3f}s")

        # "results" may be missing or an empty list; both mean "no products"
        # rather than an IndexError surfacing as a 500.
        search_results = (result.get("results") or [{}])[0]
        products = search_results.get("products", [])

        return {
            "status": result.get("status", "success"),
            "user_query": req.user_query,
            "user_id": req.user_id,
            "session_id": req.session_id,
            "search_params": search_params,
            "total_results": len(products),
            "products": products,
            "processing_time_ms": round(elapsed_time * 1000, 2),
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Error in retriever DB: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Retriever DB Error: {e!s}") from e
backend/common/middleware.py
View file @
49f43a45
"""
"""
Middleware Module - Gom Auth + Rate Limit
Middleware Module - Gom Auth + Rate Limit
Singleton Pattern cho cả 2 services
Singleton Pattern cho cả 2 services
"""
"""
from
__future__
import
annotations
from
__future__
import
annotations
import
json
import
json
import
logging
import
logging
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
from
fastapi
import
HTTPException
,
Request
,
status
from
fastapi
import
HTTPException
,
Request
,
status
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.middleware.base
import
BaseHTTPMiddleware
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
fastapi
import
FastAPI
from
fastapi
import
FastAPI
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# =============================================================================
# CONFIGURATION
# =============================================================================

# Public endpoints - no auth needed (compared against the exact request path).
PUBLIC_PATHS = {"/", "/health", "/docs", "/openapi.json", "/redoc"}

# Public path prefixes - a request path starting with any of these skips auth.
PUBLIC_PATH_PREFIXES = [
    "/static",
    "/mock",
    "/api/mock",
]

# =============================================================================
# AUTH + RATE LIMIT MIDDLEWARE CLASS
# =============================================================================

# Paths that need the rate limit check (compared against the exact request path).
RATE_LIMITED_PATHS = ["/api/agent/chat"]
class CanifaAuthMiddleware(BaseHTTPMiddleware):
    """
    Canifa Authentication + Rate Limit Middleware

    Flow:
    1. Frontend sends the request with Authorization: Bearer <canifa_token>
    2. Middleware verifies the token via the Canifa API -> extracts the user id
    3. Checks the message rate limit (original notes say Guest: 10, User: 100;
       the actual limits live in message_limit_service)
    4. Attaches user info to request.state
    5. Routes read it directly from request.state

    Any failure during authentication downgrades the request to guest mode
    instead of rejecting it; any failure during the rate limit check lets the
    request through.
    """

    @staticmethod
    def _set_guest_state(request: Request, device_id: str) -> None:
        """Mark the request as an unauthenticated guest identified by device_id."""
        request.state.user = None
        request.state.user_id = None
        request.state.is_authenticated = False
        request.state.device_id = device_id

    async def dispatch(self, request: Request, call_next: Callable):
        """Authenticate the request, apply the message rate limit, then forward it.

        Returns the downstream response, or a 429 JSONResponse when the
        message limit for the caller is exhausted.
        """
        path = request.url.path
        method = request.method

        # Allow OPTIONS requests (CORS preflight).
        if method == "OPTIONS":
            return await call_next(request)

        # Skip public endpoints.
        if path in PUBLIC_PATHS:
            return await call_next(request)

        # Skip public path prefixes.
        if any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES):
            return await call_next(request)

        # =====================================================================
        # STEP 1: AUTHENTICATION (Canifa API)
        # =====================================================================
        dev_mode = False
        try:
            auth_header = request.headers.get("Authorization")

            # --- Device ID from Body ---
            device_id = ""
            if method in ["POST", "PUT", "PATCH"]:
                try:
                    body_bytes = await request.body()

                    # The body stream has been consumed above; replay the
                    # cached bytes so downstream handlers can still read it.
                    async def receive_wrapper():
                        return {"type": "http.request", "body": body_bytes}

                    request._receive = receive_wrapper

                    if body_bytes:
                        try:
                            body_json = json.loads(body_bytes)
                            # Only JSON objects can carry a device_id; arrays
                            # and scalars are valid JSON but have no .get().
                            if isinstance(body_json, dict):
                                device_id = body_json.get("device_id", "")
                        except json.JSONDecodeError:
                            # Non-JSON bodies simply have no device_id.
                            pass
                except Exception as e:
                    logger.warning(f"Error reading device_id from body: {e}")

            # Fallback: not in the body -> look in the header -> client IP.
            if not device_id:
                device_id = request.headers.get("device_id", "")

            if not device_id:
                device_id = f"unknown_{request.client.host}" if request.client else "unknown"

            # ========== DEV MODE: Bypass auth ==========
            # SECURITY NOTE(review): this header is honoured unconditionally -
            # no environment or config gate is visible here, so any client can
            # impersonate any user id. Confirm it is disabled or stripped
            # upstream in production.
            dev_user_id = request.headers.get("X-Dev-User-Id")
            if dev_user_id:
                logger.warning(f"⚠️ DEV MODE: Using X-Dev-User-Id={dev_user_id}")
                request.state.user = {"customer_id": dev_user_id}
                request.state.user_id = dev_user_id
                request.state.is_authenticated = True
                request.state.device_id = device_id or dev_user_id
                dev_mode = True

            # --- CASE 1: NO TOKEN -> GUEST ---
            elif not auth_header or not auth_header.startswith("Bearer "):
                self._set_guest_state(request, device_id)

            # --- CASE 2: TOKEN PRESENT -> VERIFY WITH CANIFA ---
            else:
                # Strip only the leading scheme; str.replace would also delete
                # any later "Bearer " occurrence inside the token itself.
                token = auth_header[len("Bearer "):]

                from common.canifa_api import verify_canifa_token, extract_user_id_from_canifa_response

                try:
                    user_data = await verify_canifa_token(token)
                    user_id = await extract_user_id_from_canifa_response(user_data)

                    if user_id:
                        request.state.user = user_data
                        request.state.user_id = user_id
                        request.state.token = token
                        request.state.is_authenticated = True
                        request.state.device_id = device_id
                        logger.debug(f"✅ Canifa Auth Success: User {user_id}")
                    else:
                        logger.warning("⚠️ Invalid Canifa Token -> Guest Mode")
                        self._set_guest_state(request, device_id)

                except Exception as e:
                    logger.error(f"❌ Canifa Auth Error: {e} -> Guest Mode")
                    self._set_guest_state(request, device_id)

        except Exception as e:
            logger.error(f"❌ Middleware Auth Error: {e}")
            # device_id may not have been resolved yet, so fall back to "".
            self._set_guest_state(request, "")

        # Dev mode skips the rate limit check. The downstream call happens
        # outside the try above so that an exception raised by the application
        # is not mistaken for an auth error and the request dispatched twice.
        if dev_mode:
            return await call_next(request)

        # =====================================================================
        # STEP 2: RATE LIMIT CHECK (only for paths that need it)
        # =====================================================================
        if path in RATE_LIMITED_PATHS:
            try:
                from common.message_limit import message_limit_service
                from fastapi.responses import JSONResponse

                # Pick the identity used as rate limit key:
                # authenticated user -> user_id, guest -> device_id.
                is_authenticated = request.state.is_authenticated
                if is_authenticated and request.state.user_id:
                    rate_limit_key = request.state.user_id
                else:
                    rate_limit_key = request.state.device_id

                if rate_limit_key:
                    can_send, limit_info = await message_limit_service.check_limit(
                        identity_key=rate_limit_key,
                        is_authenticated=is_authenticated,
                    )

                    # Store limit_info on request.state so routes can use it.
                    request.state.limit_info = limit_info

                    if not can_send:
                        logger.warning(
                            f"⚠️ Rate Limit Exceeded: {rate_limit_key} | "
                            f"used={limit_info['used']}/{limit_info['limit']}"
                        )
                        return JSONResponse(
                            status_code=429,
                            content={
                                "status": "error",
                                "error_code": limit_info.get("error_code") or "MESSAGE_LIMIT_EXCEEDED",
                                "message": limit_info["message"],
                                "require_login": limit_info["require_login"],
                                "limit_info": {
                                    "limit": limit_info["limit"],
                                    "used": limit_info["used"],
                                    "remaining": limit_info["remaining"],
                                    "reset_seconds": limit_info["reset_seconds"],
                                },
                            },
                        )
                else:
                    logger.warning("⚠️ No identity_key for rate limiting")

            except Exception as e:
                logger.error(f"❌ Rate Limit Check Error: {e}")
                # Let the request continue if the rate limit check itself fails.

        return await call_next(request)
# =============================================================================
# MIDDLEWARE MANAGER - Singleton to manage all middlewares
# =============================================================================
# =============================================================================
# MIDDLEWARE MANAGER - Singleton to manage all middlewares
# =============================================================================
class MiddlewareManager:
    """
    Middleware Manager - Singleton Pattern
    Manages and sets up all middlewares for the FastAPI app.

    Usage:
        from common.middleware import middleware_manager

        app = FastAPI()
        middleware_manager.setup(app, enable_auth=True, enable_rate_limit=True)
    """

    # The single shared instance, created on first construction.
    _instance: MiddlewareManager | None = None
    # Guards __init__ so repeated construction does not reset the flags.
    _initialized: bool = False

    def __new__(cls) -> MiddlewareManager:
        """Return the shared instance, creating it on the first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Initialise the enabled-flags once; later constructions are no-ops."""
        if MiddlewareManager._initialized:
            return

        self._auth_enabled = False
        self._rate_limit_enabled = False
        MiddlewareManager._initialized = True
        logger.info("✅ MiddlewareManager initialized")

    def setup(
        self,
        app: FastAPI,
        *,
        enable_auth: bool = True,
        enable_rate_limit: bool = True,
        enable_cors: bool = True,
        cors_origins: list[str] | None = None,
    ) -> None:
        """
        Set up all middlewares for the FastAPI app.

        Args:
            app: FastAPI application
            enable_auth: Enable the Canifa authentication middleware
            enable_rate_limit: Enable rate limiting
            enable_cors: Enable the CORS middleware
            cors_origins: Allowed CORS origins (default: ["*"])

        Note:
            Middleware order matters: a middleware added LATER runs FIRST.
            Intended request order: CORS -> Auth -> RateLimit -> SlowAPI,
            so registration happens in the reverse of that order. CORS must be
            outermost so that responses produced by the auth middleware itself
            (e.g. the 429 rate limit reply) also receive CORS headers.
        """
        # 1. Rate limit (added first -> runs last, closest to the routes).
        if enable_rate_limit:
            self._setup_rate_limit(app)

        # 2. Auth middleware.
        if enable_auth:
            self._setup_auth(app)

        # 3. CORS middleware (added last -> runs first).
        if enable_cors:
            self._setup_cors(app, cors_origins or ["*"])

        logger.info(
            f"✅ Middlewares configured: "
            f"CORS={enable_cors}, Auth={enable_auth}, RateLimit={enable_rate_limit}"
        )

    def _setup_cors(self, app: FastAPI, origins: list[str]) -> None:
        """Setup CORS middleware."""
        from fastapi.middleware.cors import CORSMiddleware

        # NOTE(review): credentials are allowed together with the default
        # wildcard origin; confirm explicit origins are passed in production.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        logger.info(f"✅ CORS middleware enabled (origins: {origins})")

    def _setup_auth(self, app: FastAPI) -> None:
        """Setup Canifa auth middleware."""
        app.add_middleware(CanifaAuthMiddleware)
        self._auth_enabled = True
        logger.info("✅ Canifa Auth middleware enabled")

    def _setup_rate_limit(self, app: FastAPI) -> None:
        """Setup rate limiting."""
        # Imported lazily, at call time rather than module import time.
        from common.rate_limit import rate_limit_service

        rate_limit_service.setup(app)
        self._rate_limit_enabled = True
        logger.info("✅ Rate Limit middleware enabled")

    @property
    def is_auth_enabled(self) -> bool:
        """Whether the auth middleware has been registered via setup()."""
        return self._auth_enabled

    @property
    def is_rate_limit_enabled(self) -> bool:
        """Whether rate limiting has been registered via setup()."""
        return self._rate_limit_enabled
# =============================================================================
# SINGLETON INSTANCE
# =============================================================================
# =============================================================================
# SINGLETON INSTANCE
# =============================================================================
middleware_manager
=
MiddlewareManager
()
middleware_manager
=
MiddlewareManager
()
backend/common/rate_limit.py
View file @
49f43a45
"""
"""
Rate Limiting Service - Singleton Pattern
Rate Limiting Service - Singleton Pattern
Sử dụng SlowAPI với Redis backend (production) hoặc Memory (dev)
Sử dụng SlowAPI với Redis backend (production) hoặc Memory (dev)
"""
"""
from
__future__
import
annotations
from
__future__
import
annotations
import
logging
import
logging
import
os
import
os
from
datetime
import
datetime
,
timedelta
from
datetime
import
datetime
,
timedelta
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
from
fastapi
import
Request
from
fastapi
import
Request
from
fastapi.responses
import
JSONResponse
from
fastapi.responses
import
JSONResponse
from
slowapi
import
Limiter
from
slowapi
import
Limiter
from
slowapi.errors
import
RateLimitExceeded
from
slowapi.errors
import
RateLimitExceeded
from
slowapi.middleware
import
SlowAPIMiddleware
from
slowapi.middleware
import
SlowAPIMiddleware
from
slowapi.util
import
get_remote_address
from
slowapi.util
import
get_remote_address
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
fastapi
import
FastAPI
from
fastapi
import
FastAPI
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
class
RateLimitService
:
class
RateLimitService
:
"""
"""
Rate Limiting Service - Singleton Pattern
Rate Limiting Service - Singleton Pattern
Usage:
Usage:
# Trong server.py
# Trong server.py
from common.rate_limit import RateLimitService
from common.rate_limit import RateLimitService
rate_limiter = RateLimitService()
rate_limiter = RateLimitService()
rate_limiter.setup(app)
rate_limiter.setup(app)
# Trong route
# Trong route
from common.rate_limit import RateLimitService
from common.rate_limit import RateLimitService
@router.post("/chat")
@router.post("/chat")
@RateLimitService().limiter.limit("10/minute")
@RateLimitService().limiter.limit("10/minute")
async def chat(request: Request):
async def chat(request: Request):
...
...
"""
"""
_instance
:
RateLimitService
|
None
=
None
_instance
:
RateLimitService
|
None
=
None
_initialized
:
bool
=
False
_initialized
:
bool
=
False
# =========================================================================
# =========================================================================
# SINGLETON PATTERN
# SINGLETON PATTERN
# =========================================================================
# =========================================================================
def
__new__
(
cls
)
->
RateLimitService
:
def
__new__
(
cls
)
->
RateLimitService
:
if
cls
.
_instance
is
None
:
if
cls
.
_instance
is
None
:
cls
.
_instance
=
super
()
.
__new__
(
cls
)
cls
.
_instance
=
super
()
.
__new__
(
cls
)
return
cls
.
_instance
return
cls
.
_instance
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
# Chỉ init một lần
# Chỉ init một lần
if
RateLimitService
.
_initialized
:
if
RateLimitService
.
_initialized
:
return
return
# Configuration
# Configuration
self
.
storage_uri
=
os
.
getenv
(
"RATE_STORAGE_URI"
,
"memory://"
)
self
.
storage_uri
=
os
.
getenv
(
"RATE_STORAGE_URI"
,
"memory://"
)
self
.
default_limits
=
[
"100/hour"
,
"30/minute"
]
self
.
default_limits
=
[
"100/hour"
,
"30/minute"
]
self
.
block_duration_minutes
=
int
(
os
.
getenv
(
"RATE_LIMIT_BLOCK_MINUTES"
,
"5"
))
self
.
block_duration_minutes
=
int
(
os
.
getenv
(
"RATE_LIMIT_BLOCK_MINUTES"
,
"5"
))
# Paths không áp dụng rate limit
# Paths không áp dụng rate limit
self
.
exempt_paths
=
{
self
.
exempt_paths
=
{
"/"
,
"/"
,
"/health"
,
"/health"
,
"/docs"
,
"/docs"
,
"/openapi.json"
,
"/openapi.json"
,
"/redoc"
,
"/redoc"
,
}
}
self
.
exempt_prefixes
=
[
"/static"
,
"/mock"
]
self
.
exempt_prefixes
=
[
"/static"
,
"/mock"
,
"/api/mock"
]
# In-memory blocklist (có thể chuyển sang Redis)
# In-memory blocklist (có thể chuyển sang Redis)
self
.
_blocklist
:
dict
[
str
,
datetime
]
=
{}
self
.
_blocklist
:
dict
[
str
,
datetime
]
=
{}
# Create limiter instance
# Create limiter instance
self
.
limiter
=
Limiter
(
self
.
limiter
=
Limiter
(
key_func
=
self
.
_get_client_identifier
,
key_func
=
self
.
_get_client_identifier
,
storage_uri
=
self
.
storage_uri
,
storage_uri
=
self
.
storage_uri
,
default_limits
=
self
.
default_limits
,
default_limits
=
self
.
default_limits
,
)
)
RateLimitService
.
_initialized
=
True
RateLimitService
.
_initialized
=
True
logger
.
info
(
f
"✅ RateLimitService initialized (storage: {self.storage_uri})"
)
logger
.
info
(
f
"✅ RateLimitService initialized (storage: {self.storage_uri})"
)
# =========================================================================
# =========================================================================
# CLIENT IDENTIFIER
# CLIENT IDENTIFIER
# =========================================================================
# =========================================================================
@
staticmethod
@
staticmethod
def
_get_client_identifier
(
request
:
Request
)
->
str
:
def
_get_client_identifier
(
request
:
Request
)
->
str
:
"""
"""
Lấy client identifier cho rate limiting.
Lấy client identifier cho rate limiting.
Ưu tiên: user_id (authenticated) > device_id > IP address
Ưu tiên: user_id (authenticated) > device_id > IP address
"""
"""
# 1. Nếu đã authenticated → dùng user_id
# 1. Nếu đã authenticated → dùng user_id
if
hasattr
(
request
.
state
,
"user_id"
)
and
request
.
state
.
user_id
:
if
hasattr
(
request
.
state
,
"user_id"
)
and
request
.
state
.
user_id
:
return
f
"user:{request.state.user_id}"
return
f
"user:{request.state.user_id}"
# 2. Nếu có device_id trong header → dùng device_id
# 2. Nếu có device_id trong header → dùng device_id
device_id
=
request
.
headers
.
get
(
"device_id"
)
device_id
=
request
.
headers
.
get
(
"device_id"
)
if
device_id
:
if
device_id
:
return
f
"device:{device_id}"
return
f
"device:{device_id}"
# 3. Fallback → IP address
# 3. Fallback → IP address
try
:
try
:
return
f
"ip:{get_remote_address(request)}"
return
f
"ip:{get_remote_address(request)}"
except
Exception
:
except
Exception
:
if
request
.
client
:
if
request
.
client
:
return
f
"ip:{request.client.host}"
return
f
"ip:{request.client.host}"
return
"unknown"
return
"unknown"
# =========================================================================
# =========================================================================
# BLOCKLIST MANAGEMENT
# BLOCKLIST MANAGEMENT
# =========================================================================
# =========================================================================
def
is_blocked
(
self
,
key
:
str
)
->
tuple
[
bool
,
int
]:
def
is_blocked
(
self
,
key
:
str
)
->
tuple
[
bool
,
int
]:
"""
"""
Check if client is blocked.
Check if client is blocked.
Returns: (is_blocked, retry_after_seconds)
Returns: (is_blocked, retry_after_seconds)
"""
"""
now
=
datetime
.
utcnow
()
now
=
datetime
.
utcnow
()
blocked_until
=
self
.
_blocklist
.
get
(
key
)
blocked_until
=
self
.
_blocklist
.
get
(
key
)
if
blocked_until
:
if
blocked_until
:
if
blocked_until
>
now
:
if
blocked_until
>
now
:
retry_after
=
int
((
blocked_until
-
now
)
.
total_seconds
())
retry_after
=
int
((
blocked_until
-
now
)
.
total_seconds
())
return
True
,
retry_after
return
True
,
retry_after
else
:
else
:
# Block expired
# Block expired
self
.
_blocklist
.
pop
(
key
,
None
)
self
.
_blocklist
.
pop
(
key
,
None
)
return
False
,
0
return
False
,
0
def
block_client
(
self
,
key
:
str
)
->
int
:
def
block_client
(
self
,
key
:
str
)
->
int
:
"""
"""
Block client for configured duration.
Block client for configured duration.
Returns: retry_after_seconds
Returns: retry_after_seconds
"""
"""
self
.
_blocklist
[
key
]
=
datetime
.
utcnow
()
+
timedelta
(
minutes
=
self
.
block_duration_minutes
)
self
.
_blocklist
[
key
]
=
datetime
.
utcnow
()
+
timedelta
(
minutes
=
self
.
block_duration_minutes
)
return
self
.
block_duration_minutes
*
60
return
self
.
block_duration_minutes
*
60
def
unblock_client
(
self
,
key
:
str
)
->
None
:
def
unblock_client
(
self
,
key
:
str
)
->
None
:
"""Unblock client manually."""
"""Unblock client manually."""
self
.
_blocklist
.
pop
(
key
,
None
)
self
.
_blocklist
.
pop
(
key
,
None
)
# =========================================================================
# =========================================================================
# PATH CHECKING
# PATH CHECKING
# =========================================================================
# =========================================================================
def
is_exempt
(
self
,
path
:
str
)
->
bool
:
def
is_exempt
(
self
,
path
:
str
)
->
bool
:
"""Check if path is exempt from rate limiting."""
"""Check if path is exempt from rate limiting."""
if
path
in
self
.
exempt_paths
:
if
path
in
self
.
exempt_paths
:
return
True
return
True
return
any
(
path
.
startswith
(
prefix
)
for
prefix
in
self
.
exempt_prefixes
)
return
any
(
path
.
startswith
(
prefix
)
for
prefix
in
self
.
exempt_prefixes
)
# =========================================================================
# =========================================================================
# SETUP FOR FASTAPI APP
# SETUP FOR FASTAPI APP
# =========================================================================
# =========================================================================
def
setup
(
self
,
app
:
FastAPI
)
->
None
:
def
setup
(
self
,
app
:
FastAPI
)
->
None
:
"""
"""
Setup rate limiting cho FastAPI app.
Setup rate limiting cho FastAPI app.
Gọi trong server.py sau khi tạo app.
Gọi trong server.py sau khi tạo app.
"""
"""
# Attach limiter to app state
# Attach limiter to app state
app
.
state
.
limiter
=
self
.
limiter
app
.
state
.
limiter
=
self
.
limiter
app
.
state
.
rate_limit_service
=
self
app
.
state
.
rate_limit_service
=
self
# Register middleware
# Register middleware
self
.
_register_block_middleware
(
app
)
self
.
_register_block_middleware
(
app
)
self
.
_register_exception_handler
(
app
)
self
.
_register_exception_handler
(
app
)
# Add SlowAPI middleware (PHẢI thêm SAU custom middlewares)
# Add SlowAPI middleware (PHẢI thêm SAU custom middlewares)
app
.
add_middleware
(
SlowAPIMiddleware
)
app
.
add_middleware
(
SlowAPIMiddleware
)
logger
.
info
(
"✅ Rate limiting middleware registered"
)
logger
.
info
(
"✅ Rate limiting middleware registered"
)
def
_register_block_middleware
(
self
,
app
:
FastAPI
)
->
None
:
def
_register_block_middleware
(
self
,
app
:
FastAPI
)
->
None
:
"""Register middleware to check blocklist."""
"""Register middleware to check blocklist."""
@
app
.
middleware
(
"http"
)
@
app
.
middleware
(
"http"
)
async
def
rate_limit_block_middleware
(
request
:
Request
,
call_next
):
async
def
rate_limit_block_middleware
(
request
:
Request
,
call_next
):
path
=
request
.
url
.
path
path
=
request
.
url
.
path
# Skip exempt paths
# Skip exempt paths
if
self
.
is_exempt
(
path
):
if
self
.
is_exempt
(
path
):
return
await
call_next
(
request
)
return
await
call_next
(
request
)
# Bypass header cho testing
# Bypass header cho testing
if
request
.
headers
.
get
(
"X-Bypass-RateLimit"
)
==
"1"
:
if
request
.
headers
.
get
(
"X-Bypass-RateLimit"
)
==
"1"
:
return
await
call_next
(
request
)
return
await
call_next
(
request
)
# Check blocklist
# Check blocklist
key
=
self
.
_get_client_identifier
(
request
)
key
=
self
.
_get_client_identifier
(
request
)
is_blocked
,
retry_after
=
self
.
is_blocked
(
key
)
is_blocked
,
retry_after
=
self
.
is_blocked
(
key
)
if
is_blocked
:
if
is_blocked
:
return
JSONResponse
(
return
JSONResponse
(
status_code
=
429
,
status_code
=
429
,
content
=
{
content
=
{
"detail"
:
"Quá số lượt cho phép. Vui lòng thử lại sau."
,
"detail"
:
"Quá số lượt cho phép. Vui lòng thử lại sau."
,
"retry_after_seconds"
:
retry_after
,
"retry_after_seconds"
:
retry_after
,
},
},
headers
=
{
"Retry-After"
:
str
(
retry_after
)},
headers
=
{
"Retry-After"
:
str
(
retry_after
)},
)
)
return
await
call_next
(
request
)
return
await
call_next
(
request
)
def
_register_exception_handler
(
self
,
app
:
FastAPI
)
->
None
:
def
_register_exception_handler
(
self
,
app
:
FastAPI
)
->
None
:
"""Register exception handler for rate limit exceeded."""
"""Register exception handler for rate limit exceeded."""
@
app
.
exception_handler
(
RateLimitExceeded
)
@
app
.
exception_handler
(
RateLimitExceeded
)
async
def
rate_limit_exceeded_handler
(
request
:
Request
,
exc
:
RateLimitExceeded
):
async
def
rate_limit_exceeded_handler
(
request
:
Request
,
exc
:
RateLimitExceeded
):
key
=
self
.
_get_client_identifier
(
request
)
key
=
self
.
_get_client_identifier
(
request
)
retry_after
=
self
.
block_client
(
key
)
retry_after
=
self
.
block_client
(
key
)
logger
.
warning
(
f
"⚠️ Rate limit exceeded for {key}, blocked for {self.block_duration_minutes} minutes"
)
logger
.
warning
(
f
"⚠️ Rate limit exceeded for {key}, blocked for {self.block_duration_minutes} minutes"
)
return
JSONResponse
(
return
JSONResponse
(
status_code
=
429
,
status_code
=
429
,
content
=
{
content
=
{
"detail"
:
"Quá số lượt cho phép. Vui lòng thử lại sau."
,
"detail"
:
"Quá số lượt cho phép. Vui lòng thử lại sau."
,
"retry_after_seconds"
:
retry_after
,
"retry_after_seconds"
:
retry_after
,
},
},
headers
=
{
"Retry-After"
:
str
(
retry_after
)},
headers
=
{
"Retry-After"
:
str
(
retry_after
)},
)
)
# =============================================================================
# =============================================================================
# SINGLETON INSTANCE - Import trực tiếp để dùng
# SINGLETON INSTANCE - Import trực tiếp để dùng
# =============================================================================
# =============================================================================
rate_limit_service
=
RateLimitService
()
rate_limit_service
=
RateLimitService
()
backend/server.py
View file @
49f43a45
import
asyncio
import
asyncio
import
os
import
os
import
platform
import
platform
if
platform
.
system
()
==
"Windows"
:
if
platform
.
system
()
==
"Windows"
:
print
(
"🔧 Windows detected: Applying SelectorEventLoopPolicy globally..."
)
print
(
"🔧 Windows detected: Applying SelectorEventLoopPolicy globally..."
)
asyncio
.
set_event_loop_policy
(
asyncio
.
WindowsSelectorEventLoopPolicy
())
asyncio
.
set_event_loop_policy
(
asyncio
.
WindowsSelectorEventLoopPolicy
())
import
logging
import
logging
import
uvicorn
import
uvicorn
from
fastapi
import
FastAPI
from
fastapi
import
FastAPI
from
fastapi.staticfiles
import
StaticFiles
from
fastapi.staticfiles
import
StaticFiles
from
api.chatbot_route
import
router
as
chatbot_router
from
api.chatbot_route
import
router
as
chatbot_router
from
api.conservation_route
import
router
as
conservation_router
from
api.conservation_route
import
router
as
conservation_router
from
api.prompt_route
import
router
as
prompt_router
from
api.prompt_route
import
router
as
prompt_router
from
common.cache
import
redis_cache
from
common.cache
import
redis_cache
from
common.langfuse_client
import
get_langfuse_client
from
common.langfuse_client
import
get_langfuse_client
from
common.middleware
import
middleware_manager
from
common.middleware
import
middleware_manager
from
config
import
PORT
from
config
import
PORT
# Configure Logging
# Configure Logging
logging
.
basicConfig
(
logging
.
basicConfig
(
level
=
logging
.
INFO
,
level
=
logging
.
INFO
,
format
=
"
%(asctime)
s [
%(levelname)
s]
%(name)
s:
%(message)
s"
,
format
=
"
%(asctime)
s [
%(levelname)
s]
%(name)
s:
%(message)
s"
,
handlers
=
[
logging
.
StreamHandler
()],
handlers
=
[
logging
.
StreamHandler
()],
)
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
langfuse_client
=
get_langfuse_client
()
langfuse_client
=
get_langfuse_client
()
if
langfuse_client
:
if
langfuse_client
:
logger
.
info
(
"✅ Langfuse client ready (lazy loading)"
)
logger
.
info
(
"✅ Langfuse client ready (lazy loading)"
)
else
:
else
:
logger
.
warning
(
"⚠️ Langfuse client not available (missing keys or disabled)"
)
logger
.
warning
(
"⚠️ Langfuse client not available (missing keys or disabled)"
)
app
=
FastAPI
(
app
=
FastAPI
(
title
=
"Contract AI Service"
,
title
=
"Contract AI Service"
,
description
=
"API for Contract AI Service"
,
description
=
"API for Contract AI Service"
,
version
=
"1.0.0"
,
version
=
"1.0.0"
,
)
)
# =============================================================================
# =============================================================================
# STARTUP EVENT - Initialize Redis Cache
# STARTUP EVENT - Initialize Redis Cache
# =============================================================================
# =============================================================================
@
app
.
on_event
(
"startup"
)
@
app
.
on_event
(
"startup"
)
async
def
startup_event
():
async
def
startup_event
():
"""Initialize Redis cache on startup."""
"""Initialize Redis cache on startup."""
await
redis_cache
.
initialize
()
await
redis_cache
.
initialize
()
logger
.
info
(
"✅ Redis cache initialized for message limit"
)
logger
.
info
(
"✅ Redis cache initialized for message limit"
)
# =============================================================================
# =============================================================================
# MIDDLEWARE SETUP - Gom Auth + RateLimit + CORS vào một chỗ
# MIDDLEWARE SETUP - Gom Auth + RateLimit + CORS vào một chỗ
# =============================================================================
# =============================================================================
middleware_manager
.
setup
(
middleware_manager
.
setup
(
app
,
app
,
enable_auth
=
True
,
# 👈 Bật lại Auth để test logic Guest/User
enable_auth
=
True
,
# 👈 Bật lại Auth để test logic Guest/User
enable_rate_limit
=
True
,
# 👈 Bật lại SlowAPI theo yêu cầu
enable_rate_limit
=
True
,
# 👈 Bật lại SlowAPI theo yêu cầu
enable_cors
=
True
,
# 👈 Bật CORS
enable_cors
=
True
,
# 👈 Bật CORS
cors_origins
=
[
"*"
],
# 👈 Trong production nên limit origins
cors_origins
=
[
"*"
],
# 👈 Trong production nên limit origins
)
)
app
.
include_router
(
conservation_router
)
app
.
include_router
(
conservation_router
)
app
.
include_router
(
chatbot_router
)
app
.
include_router
(
chatbot_router
)
app
.
include_router
(
prompt_router
)
app
.
include_router
(
prompt_router
)
# --- MOCK API FOR LOAD TESTING ---
# --- MOCK API FOR LOAD TESTING ---
try
:
try
:
from
api.mock_api_route
import
router
as
mock_router
from
api.mock_api_route
import
router
as
mock_router
app
.
include_router
(
mock_router
)
app
.
include_router
(
mock_router
)
print
(
"✅ Mock API Router mounted at /mock"
)
print
(
"✅ Mock API Router mounted at /
api/
mock"
)
except
ImportError
:
except
ImportError
:
print
(
"⚠️ Mock Router not found, skipping..."
)
print
(
"⚠️ Mock Router not found, skipping..."
)
# ==========================================
# ==========================================
# 🟢 ĐOẠN MOUNT STATIC HTML CỦA BRO ĐÂY 🟢
# 🟢 ĐOẠN MOUNT STATIC HTML CỦA BRO ĐÂY 🟢
# ==========================================
# ==========================================
try
:
try
:
static_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"static"
)
static_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"static"
)
if
not
os
.
path
.
exists
(
static_dir
):
if
not
os
.
path
.
exists
(
static_dir
):
os
.
makedirs
(
static_dir
)
os
.
makedirs
(
static_dir
)
# Mount thư mục static để chạy file index.html
# Mount thư mục static để chạy file index.html
app
.
mount
(
"/static"
,
StaticFiles
(
directory
=
static_dir
,
html
=
True
),
name
=
"static"
)
app
.
mount
(
"/static"
,
StaticFiles
(
directory
=
static_dir
,
html
=
True
),
name
=
"static"
)
print
(
f
"✅ Static files mounted at /static (Dir: {static_dir})"
)
print
(
f
"✅ Static files mounted at /static (Dir: {static_dir})"
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"⚠️ Failed to mount static files: {e}"
)
print
(
f
"⚠️ Failed to mount static files: {e}"
)
from
fastapi.responses
import
RedirectResponse
from
fastapi.responses
import
RedirectResponse
@
app
.
get
(
"/"
)
@
app
.
get
(
"/"
)
async
def
root
():
async
def
root
():
return
RedirectResponse
(
url
=
"/static/index.html"
)
return
RedirectResponse
(
url
=
"/static/index.html"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
print
(
"="
*
60
)
print
(
"="
*
60
)
print
(
"🚀 Contract AI Service Starting..."
)
print
(
"🚀 Contract AI Service Starting..."
)
print
(
"="
*
60
)
print
(
"="
*
60
)
print
(
f
"📡 REST API: http://localhost:{PORT}"
)
print
(
f
"📡 REST API: http://localhost:{PORT}"
)
print
(
f
"📡 Test Chatbot: http://localhost:{PORT}/static/index.html"
)
print
(
f
"📡 Test Chatbot: http://localhost:{PORT}/static/index.html"
)
print
(
f
"📚 API Docs: http://localhost:{PORT}/docs"
)
print
(
f
"📚 API Docs: http://localhost:{PORT}/docs"
)
print
(
"="
*
60
)
print
(
"="
*
60
)
ENABLE_RELOAD
=
False
ENABLE_RELOAD
=
False
print
(
f
"⚠️ Hot reload: {ENABLE_RELOAD}"
)
print
(
f
"⚠️ Hot reload: {ENABLE_RELOAD}"
)
reload_dirs
=
[
"common"
,
"api"
,
"agent"
]
reload_dirs
=
[
"common"
,
"api"
,
"agent"
]
if
ENABLE_RELOAD
:
if
ENABLE_RELOAD
:
os
.
environ
[
"PYTHONUNBUFFERED"
]
=
"1"
os
.
environ
[
"PYTHONUNBUFFERED"
]
=
"1"
uvicorn
.
run
(
uvicorn
.
run
(
"server:app"
,
"server:app"
,
host
=
"0.0.0.0"
,
host
=
"0.0.0.0"
,
port
=
PORT
,
port
=
PORT
,
reload
=
ENABLE_RELOAD
,
reload
=
ENABLE_RELOAD
,
reload_dirs
=
reload_dirs
,
reload_dirs
=
reload_dirs
,
log_level
=
"info"
,
log_level
=
"info"
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment