169 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			169 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import Union
 | |
| from langchain_ollama import OllamaLLM
 | |
| from langchain_openai import ChatOpenAI
 | |
| from langchain_core.prompts import PromptTemplate
 | |
| from langchain_core.output_parsers import JsonOutputParser
 | |
| from loguru import logger
 | |
| import sys
 | |
| 
 | |
| from config import settings
 | |
| from .models import AnalysisFlags
 | |
| 
 | |
class LLMInitializationError(Exception):
    """Raised when the configured LLM backend cannot be set up.

    Attributes:
        details: Optional extra context supplied by the raiser
            (``None`` when not given).
    """

    def __init__(self, message, details=None):
        # Attach the diagnostic payload first; the base class only needs the
        # human-readable message.
        self.details = details
        Exception.__init__(self, message)
 | |
| 
 | |
# Initialize LLM
# This whole section runs at import time: it selects a backend from
# settings.llm.mode, builds the client, and terminates the process if no
# client could be created.
llm = None
if settings.llm.mode == 'openai':
    # NOTE(review): the OpenAI options are read from the top level of
    # `settings` (settings.openai_model, ...), whereas the Ollama options
    # below live under `settings.llm` -- confirm this matches the config module.
    logger.info(f"Initializing ChatOpenAI with model: {settings.openai_model}")
    llm = ChatOpenAI(
        model=settings.openai_model,
        temperature=0.7,
        max_tokens=2000,
        api_key=settings.openai_api_key,
        base_url=settings.openai_api_base_url
    )
elif settings.llm.mode == 'ollama':
    logger.info(f"Initializing OllamaLLM with model: {settings.llm.ollama_model} at {settings.llm.ollama_base_url}")
    try:
        # Verify connection parameters
        # These ValueErrors are raised inside the try block, so they are also
        # converted into LLMInitializationError by the handler below.
        if not settings.llm.ollama_base_url:
            raise ValueError("Ollama base URL is not configured")
        if not settings.llm.ollama_model:
            raise ValueError("Ollama model is not specified")

        logger.debug(f"Attempting to connect to Ollama at {settings.llm.ollama_base_url}")

        # Normalize the base URL by stripping trailing slashes. No path
        # (such as /api/chat) is appended here.
        base_url = f"{settings.llm.ollama_base_url.rstrip('/')}"

        llm = OllamaLLM(
            model=settings.llm.ollama_model,
            base_url=base_url,
            streaming=False,
            timeout=30,  # 30 second timeout
            max_retries=3,  # Retry up to 3 times
            temperature=0.1,
            top_p=0.2
        )

        # Test connection
        # Sends a real request to the model during import; any failure is
        # handled by the except clause below.
        logger.debug("Testing Ollama connection...")
        llm.invoke("test")  # Simple test request
        logger.info("Ollama connection established successfully")

    except Exception as e:
        # Wrap every failure (configuration or connection) in the module's
        # own exception type, keeping the original as __cause__.
        error_msg = f"Failed to initialize Ollama: {str(e)}"
        details = {
            'model': settings.llm.ollama_model,
            'url': settings.llm.ollama_base_url,
            'error_type': type(e).__name__
        }
        logger.error(error_msg)
        logger.debug(f"Connection details: {details}")
        raise LLMInitializationError(
            "Failed to connect to Ollama service. Please check:"
            "\n1. Ollama is installed and running"
            "\n2. The base URL is correct"
            "\n3. The model is available",
            details=details
        ) from e

# Reached with llm still None only when settings.llm.mode is neither 'openai'
# nor 'ollama' (the Ollama branch either assigns llm or raises). Importing the
# module then exits the process with status 1.
if llm is None:
    logger.error("LLM could not be initialized. Exiting.")
    print("\nERROR: Unable to initialize LLM. Check logs for details.", file=sys.stderr)
    sys.exit(1)

# Set up Output Parser for structured JSON
# The AnalysisFlags model is passed as the parser's pydantic_object.
parser = JsonOutputParser(pydantic_object=AnalysisFlags)
 | |
| 
 | |
| # Load prompt template from file
 | |
def load_prompt_template(version="v1.1.0"):
    """Load the versioned Jira analysis prompt and wrap it in a PromptTemplate.

    The template text is read from ``llm/prompts/jira_analysis_{version}.txt``.
    The path is relative, so it is resolved against the current working
    directory of the process.

    Args:
        version: Version suffix used to select the prompt file.

    Returns:
        A ``PromptTemplate`` built from the file contents, with the ticket
        fields as input variables and ``format_instructions`` pre-filled from
        the module-level ``parser``.

    Raises:
        Exception: Any error raised while reading the file or constructing
            the template is logged and re-raised unchanged.
    """
    try:
        # Decode explicitly as UTF-8 so the prompt text does not depend on the
        # platform's default locale encoding.
        with open(
            f"llm/prompts/jira_analysis_{version}.txt", "r", encoding="utf-8"
        ) as f:
            template = f.read()
        return PromptTemplate(
            template=template,
            input_variables=[
                "issueKey", "summary", "description", "status", "labels",
                "assignee", "updated", "comment"
            ],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )
    except Exception as e:
        logger.error(f"Failed to load prompt template: {str(e)}")
        raise
 | |
| 
 | |
# Fallback prompt template, used by create_analysis_chain() when the primary
# chain cannot be built. The template interpolates the two ticket fields it
# declares and carries the parser's format instructions, because the fallback
# chain still ends in the JSON output parser.
FALLBACK_PROMPT = PromptTemplate(
    template=(
        "Please analyze this Jira ticket and provide a basic summary.\n\n"
        "Issue key: {issueKey}\n"
        "Summary: {summary}\n\n"
        "{format_instructions}"
    ),
    input_variables=["issueKey", "summary"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)
 | |
| 
 | |
| # Create chain with fallback mechanism
 | |
def create_analysis_chain():
    """Build the prompt -> LLM -> parser chain, falling back on failure.

    The primary chain uses the prompt returned by ``load_prompt_template()``.
    If anything raises while building it, a warning is logged and the chain
    is rebuilt around ``FALLBACK_PROMPT`` instead. In both cases the langfuse
    callback handler is attached when ``settings.langfuse.enabled`` is truthy.

    Returns:
        The configured chain.
    """

    def _with_tracing(pipeline):
        # Attach the langfuse callback only when tracing is switched on.
        if not settings.langfuse.enabled:
            return pipeline
        return pipeline.with_config(
            callbacks=[settings.langfuse_handler]
        )

    try:
        return _with_tracing(load_prompt_template() | llm | parser)
    except Exception as err:
        logger.warning(f"Using fallback prompt due to error: {str(err)}")
        return _with_tracing(FALLBACK_PROMPT | llm | parser)
 | |
| 
 | |
# Initialize analysis chain
# Built once at import time; create_analysis_chain() switches to
# FALLBACK_PROMPT if building the primary chain raises.
analysis_chain = create_analysis_chain()
 | |
| 
 | |
| # Enhanced response validation function
 | |
def validate_response(response: Union[dict, str]) -> bool:
    """Check that an LLM response has the expected structure and content.

    A string response is first parsed as JSON. The resulting value must be a
    dict containing ``hasMultipleEscalations`` (a bool) and
    ``customerSentiment`` (a str or ``None``), and must pass
    ``AnalysisFlags.model_validate``.

    Args:
        response: The parsed response dict, or its raw JSON text.

    Returns:
        ``True`` when every check passes, ``False`` otherwise. Validation
        failures are reported through the return value, not by raising.
    """
    # Imported locally because the module's import block does not bring in
    # `json`; without it every string response hit a NameError that was
    # silently turned into False.
    import json

    if isinstance(response, str):
        try:
            response = json.loads(response)
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError; RecursionError covers
            # pathologically nested input.
            return False

    # Anything that is not a JSON object (list, number, None, ...) is invalid.
    if not isinstance(response, dict):
        return False

    required_fields = ("hasMultipleEscalations", "customerSentiment")
    if any(field not in response for field in required_fields):
        return False

    if not isinstance(response["hasMultipleEscalations"], bool):
        return False

    sentiment = response["customerSentiment"]
    if sentiment is not None and not isinstance(sentiment, str):
        return False

    # Final check against the full schema; any failure here means the payload
    # is not acceptable.
    try:
        AnalysisFlags.model_validate(response)
    except Exception:
        return False
    return True