Commit db05a5fe authored by Vũ Hoàng Anh

modify : llm factory file

parent f1a80925
""" """
LLM Factory - Centralized LLM creation for OpenAI & Gemini. LLM Factory - OpenAI LLM creation with caching.
Manages initialization and caching of LLM models with automatic provider detection. Manages initialization and caching of OpenAI models.
""" """
import contextlib import contextlib
import logging import logging
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from config import GOOGLE_API_KEY, OPENAI_API_KEY from config import OPENAI_API_KEY
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LLMFactory: class LLMFactory:
"""Singleton factory for managing LLM instances with caching and provider auto-detection.""" """Singleton factory for managing OpenAI LLM instances with caching."""
COMMON_MODELS: list[str] = [ COMMON_MODELS: list[str] = [
"gpt-4o-mini", "gpt-4o-mini",
"gemini-2.0-flash-lite-preview-02-05", "gpt-4o",
"gpt-5-nano",
"gpt-5-mini",
] ]
def __init__(self): def __init__(self):
...@@ -64,7 +65,7 @@ class LLMFactory: ...@@ -64,7 +65,7 @@ class LLMFactory:
api_key: str | None = None, api_key: str | None = None,
) -> BaseChatModel: ) -> BaseChatModel:
""" """
Create and cache a new LLM instance based on model name. Create and cache a new OpenAI LLM instance.
Args: Args:
model_name: Clean model identifier model_name: Clean model identifier
...@@ -76,12 +77,9 @@ class LLMFactory: ...@@ -76,12 +77,9 @@ class LLMFactory:
Configured LLM instance Configured LLM instance
Raises: Raises:
ValueError: If required API keys are missing ValueError: If API key is missing
""" """
try: try:
if self._is_gemini_model(model_name):
llm = self._create_gemini(model_name, streaming, api_key)
else:
llm = self._create_openai(model_name, streaming, api_key) llm = self._create_openai(model_name, streaming, api_key)
if json_mode: if json_mode:
...@@ -95,34 +93,12 @@ class LLMFactory: ...@@ -95,34 +93,12 @@ class LLMFactory:
logger.error(f"❌ Failed to create model {model_name}: {e}") logger.error(f"❌ Failed to create model {model_name}: {e}")
raise raise
def _is_gemini_model(self, model_name: str) -> bool:
"""Check if model name is a Gemini model."""
return "gemini" in model_name.lower()
def _create_gemini(self, model_name: str, streaming: bool, api_key: str | None) -> BaseChatModel:
"""Create Gemini model instance."""
key = api_key or GOOGLE_API_KEY
if not key:
raise ValueError("GOOGLE_API_KEY is required for Gemini models")
llm = ChatGoogleGenerativeAI(
model=model_name,
streaming=streaming,
google_api_key=key,
temperature=0,
)
logger.info(f"✨ Created Gemini: {model_name}")
return llm
def _create_openai(self, model_name: str, streaming: bool, api_key: str | None) -> BaseChatModel: def _create_openai(self, model_name: str, streaming: bool, api_key: str | None) -> BaseChatModel:
"""Create OpenAI model instance with fallback to Gemini if needed.""" """Create OpenAI model instance."""
key = api_key or OPENAI_API_KEY key = api_key or OPENAI_API_KEY
if not key: if not key:
logger.warning("⚠️ No OpenAI key, attempting Gemini fallback") raise ValueError("OPENAI_API_KEY is required")
if GOOGLE_API_KEY:
return self._create_gemini("gemini-1.5-flash", streaming, GOOGLE_API_KEY)
raise ValueError("Neither OPENAI_API_KEY nor GOOGLE_API_KEY is available")
llm = ChatOpenAI( llm = ChatOpenAI(
model=model_name, model=model_name,
......
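In plain terms, the trimmed factory now builds only `ChatOpenAI` instances and caches them. A minimal self-contained sketch of that caching pattern, assuming a `(model, streaming)` cache key and a `get_llm` entry point (the real method name and cache layout are collapsed out of the hunks above):

```python
# Illustrative stand-in for LLMFactory's caching: one ChatOpenAI instance
# per (model, streaming) configuration. Names here are assumptions.
import logging

from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI

from config import OPENAI_API_KEY  # same config module as in the diff

logger = logging.getLogger(__name__)


class CachingLLMFactory:
    def __init__(self) -> None:
        self._cache: dict[tuple[str, bool], BaseChatModel] = {}

    def get_llm(self, model_name: str, streaming: bool = False) -> BaseChatModel:
        key = (model_name, streaming)
        if key not in self._cache:
            if not OPENAI_API_KEY:
                raise ValueError("OPENAI_API_KEY is required")
            self._cache[key] = ChatOpenAI(
                model=model_name,
                streaming=streaming,
                api_key=OPENAI_API_KEY,
                temperature=0,
            )
            logger.info("✨ Created OpenAI model: %s", model_name)
        return self._cache[key]
```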
@@ -5,11 +5,11 @@ services:
     container_name: canifa_backend
     env_file: .env
     ports:
-      - "8000:8000"
+      - "5000:5000"
     volumes:
       - .:/app
     environment:
-      - PORT=8000
+      - PORT=5000
     restart: unless-stopped
     logging:
       driver: "json-file"
...
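The two compose changes only stay consistent if the container listens on the same port the host maps. A sketch of how `PORT` plausibly flows from the compose `environment` into the app, assuming `config.PORT` simply reads the environment (this commit shows only `from config import PORT`):

```python
# Hypothetical config.py excerpt -- the real config module is not shown
# in this commit. Default matches the new compose mapping "5000:5000".
import os

PORT = int(os.getenv("PORT", "5000"))
```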
@@ -36,7 +36,7 @@ http://localhost:8000
 ```json
 {
   "user_id": "user_12345",
-  "user_query": "Cho em xem áo sơ mi nam dưới 500k"
+  "user_query": "Cho em xem áo sơ mi nam dưới 5a00k"
 }
 ```
...
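For a quick smoke test of the request body shown in the README, a hedged client-side example: the route below is a guess built from the `/api/agent` prefix added in `main.py` (the README truncates the actual path), the port follows the compose change, and the query uses the pre-change `500k` string:

```python
# Hypothetical smoke test; only the JSON shape and the new port come from
# the diffs above.
import requests

resp = requests.post(
    "http://localhost:5000/api/agent/chat",  # assumed route under the /api/agent prefix
    json={
        "user_id": "user_12345",
        "user_query": "Cho em xem áo sơ mi nam dưới 500k",  # "Show me men's shirts under 500k"
    },
    timeout=30,
)
print(resp.status_code, resp.json())
```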
@@ -11,25 +11,15 @@ Modules:
 - conversation: Conversation history management
 """
-import os
 import logging
+import os

 # Configure Logging
 logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
-    handlers=[logging.StreamHandler()]
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", handlers=[logging.StreamHandler()]
 )
 logger = logging.getLogger(__name__)

-from config import LANGSMITH_API_KEY, LANGSMITH_ENDPOINT, LANGSMITH_PROJECT, LANGSMITH_TRACING
-
-# Ensure LangSmith Env Vars are set for the process
-os.environ["LANGSMITH_TRACING"] = LANGSMITH_TRACING
-os.environ["LANGSMITH_ENDPOINT"] = LANGSMITH_ENDPOINT
-os.environ["LANGSMITH_API_KEY"] = LANGSMITH_API_KEY
-os.environ["LANGSMITH_PROJECT"] = LANGSMITH_PROJECT
-
 import uvicorn
 from fastapi import FastAPI
@@ -39,8 +29,6 @@ from fastapi.staticfiles import StaticFiles  # Import StaticFiles
 # Updated APIs
 from api.chatbot_route import router as chatbot_router
 from api.conservation_route import router as conservation_router
-from common.middleware import ClerkAuthMiddleware
 from config import PORT

 app = FastAPI(
@@ -78,7 +66,6 @@ app.include_router(conservation_router)
 # Chatbot Agent (New)
 app.include_router(chatbot_router, prefix="/api/agent")

-# Mount Static Files
 # Mount this LAST to avoid conflicts with API routes
 try:
     static_dir = os.path.join(os.path.dirname(__file__), "static")
@@ -93,6 +80,7 @@ except Exception as e:
 async def root():
     return {"message": "Contract AI Service is running!", "status": "healthy"}

 if __name__ == "__main__":
     print("=" * 60)
     print("🚀 Contract AI Service Starting...")
@@ -104,15 +92,13 @@ if __name__ == "__main__":
     print("=" * 60)

     # ENABLE_RELOAD = os.getenv("ENABLE_RELOAD", "false").lower() in ("true", "1", "yes")
-    ENABLE_RELOAD = True
-    print("⚠️ Hot reload: FORCED ON (Dev Mode)")
+    ENABLE_RELOAD = False  # Temporarily disable reload to check stability
+    print(f"⚠️ Hot reload: {ENABLE_RELOAD}")

     reload_dirs = [
-        "ai_contract",
-        "conversation",
         "common",
         "api",  # Watch api folder
-        "agent"  # Watch agent folder
+        "agent",  # Watch agent folder
     ]

     uvicorn.run(
...
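The `uvicorn.run(...)` call itself is truncated above. A plausible completion under the new settings, where everything except `ENABLE_RELOAD`, `reload_dirs`, and `PORT` is an assumption:

```python
# Sketch only. uvicorn ignores reload_dirs when reload=False, so keeping the
# list around is harmless while reload is temporarily disabled.
uvicorn.run(
    "main:app",  # assumed import string; the diff does not show it
    host="0.0.0.0",
    port=PORT,
    reload=ENABLE_RELOAD,
    reload_dirs=reload_dirs if ENABLE_RELOAD else None,
)
```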