252 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			252 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import json
 | |
| import sys
 | |
| from typing import Union, Any # Import Any
 | |
| from pydantic import SecretStr # Re-import SecretStr
 | |
| import re # Import re for regex operations
 | |
| 
 | |
| from langchain_core.prompts import (
 | |
|     ChatPromptTemplate,
 | |
|     PromptTemplate,
 | |
|     SystemMessagePromptTemplate,
 | |
|     HumanMessagePromptTemplate,
 | |
| )
 | |
| from langchain_core.output_parsers import JsonOutputParser
 | |
| from langchain_core.runnables import RunnablePassthrough, Runnable
 | |
| from langchain_ollama import OllamaLLM
 | |
| from langchain_openai import ChatOpenAI
 | |
| from langchain_google_genai import ChatGoogleGenerativeAI # New import for Gemini
 | |
| from loguru import logger
 | |
| 
 | |
| from llm.models import AnalysisFlags
 | |
| from config import settings
 | |
| 
 | |
class LLMInitializationError(Exception):
    """Raised when the configured LLM backend cannot be set up.

    Attributes:
        details: Optional diagnostic payload supplied by the raiser (in this module,
            a dict describing the model / URL / underlying error type); ``None`` when
            the caller passes nothing.
    """

    def __init__(self, message, details=None):
        # Store diagnostics first; the message itself is handled by Exception.
        self.details = details
        super().__init__(message)
 | |
| 
 | |
import os

# Initialize LLM.
# Exactly one backend is constructed, chosen by settings.llm.mode ('openai', 'ollama' or
# 'gemini'). The ollama and gemini branches validate their settings, optionally probe the
# service, and wrap any failure in LLMInitializationError.
llm: Union[ChatOpenAI, OllamaLLM, ChatGoogleGenerativeAI, None] = None

# Live connection probes are skipped when IS_TEST_ENV == "true" so test runs can import
# this module without a reachable model service.
_RUN_CONNECTION_TEST = os.getenv("IS_TEST_ENV") != "true"

if settings.llm.mode == 'openai':
    logger.info(f"Initializing ChatOpenAI with model: {settings.llm.openai_model}")
    llm = ChatOpenAI(
        model=settings.llm.openai_model if settings.llm.openai_model else "",  # Ensure model is str
        temperature=0.7,
        api_key=settings.llm.openai_api_key,  # type: ignore # Suppress Pylance error due to SecretStr type mismatch
        base_url=settings.llm.openai_api_base_url,
    )
elif settings.llm.mode == 'ollama':
    logger.info(f"Initializing OllamaLLM with model: {settings.llm.ollama_model} at {settings.llm.ollama_base_url}")
    try:
        # Verify connection parameters before building the client.
        if not settings.llm.ollama_base_url:
            raise ValueError("Ollama base URL is not configured")
        if not settings.llm.ollama_model:
            raise ValueError("Ollama model is not specified")

        logger.debug(f"Attempting to connect to Ollama at {settings.llm.ollama_base_url}")

        # Normalize the base URL by dropping any trailing slash; no API path is appended.
        base_url = settings.llm.ollama_base_url.rstrip('/')

        llm = OllamaLLM(
            model=settings.llm.ollama_model,
            base_url=base_url,
            num_ctx=32000,  # Ollama context-window option (num_ctx)
            # streaming / timeout / max_retries are intentionally not passed to OllamaLLM.
        )

        if _RUN_CONNECTION_TEST:
            logger.debug("Testing Ollama connection...")
            llm.invoke("test /no_think ")  # Simple test request
            logger.info("Ollama connection established successfully")
        else:
            logger.info("Skipping Ollama connection test in test environment.")

    except Exception as e:
        error_msg = f"Failed to initialize Ollama: {str(e)}"
        details = {
            'model': settings.llm.ollama_model,
            'url': settings.llm.ollama_base_url,
            'error_type': type(e).__name__
        }
        logger.error(error_msg)
        logger.debug(f"Connection details: {details}")
        raise LLMInitializationError(
            "Failed to connect to Ollama service. Please check:"
            "\n1. Ollama is installed and running"
            "\n2. The base URL is correct"
            "\n3. The model is available",
            details=details
        ) from e
elif settings.llm.mode == 'gemini':
    logger.info(f"Initializing ChatGoogleGenerativeAI with model: {settings.llm.gemini_model}")
    try:
        if not settings.llm.gemini_api_key:
            raise ValueError("Gemini API key is not configured")
        if not settings.llm.gemini_model:
            raise ValueError("Gemini model is not specified")

        llm = ChatGoogleGenerativeAI(
            model=settings.llm.gemini_model,
            temperature=0.7,
            google_api_key=settings.llm.gemini_api_key
        )

        if _RUN_CONNECTION_TEST:
            logger.debug("Testing Gemini connection...")
            llm.invoke("test /no_think ")  # Simple test request
            logger.info("Gemini connection established successfully")
        else:
            logger.info("Skipping Gemini connection test in test environment.")

    except Exception as e:
        error_msg = f"Failed to initialize Gemini: {str(e)}"
        details = {
            'model': settings.llm.gemini_model,
            'error_type': type(e).__name__
        }
        logger.error(error_msg)
        logger.debug(f"Connection details: {details}")
        raise LLMInitializationError(
            "Failed to connect to Gemini service. Please check:"
            "\n1. GEMINI_API_KEY is correct"
            "\n2. GEMINI_MODEL is correct and accessible"
            "\n3. Network connectivity to Gemini API",
            details=details
        ) from e

if llm is None:
    # Every recognized branch above either assigns `llm` or raises, so reaching this
    # point means settings.llm.mode matched none of them; name the bad value.
    logger.error(
        f"Unsupported LLM mode {settings.llm.mode!r}; expected 'openai', 'ollama' or 'gemini'."
    )
    logger.error("LLM could not be initialized. Exiting.")
    print("\nERROR: Unable to initialize LLM. Check logs for details.", file=sys.stderr)
    sys.exit(1)

# Expose the model under the generic Runnable interface so it composes with `|` in the
# LCEL chains below. No runtime cast happens; the ignore only silences the static checker
# about the Union-typed `llm`.
llm_runnable: Runnable = llm  # type: ignore

# Set up Output Parser for structured JSON (final stage of the analysis chain).
parser = JsonOutputParser()
 | |
| 
 | |
# Load prompt template from file
def load_prompt_template(version="v1.3.0"):
    """Build the analysis chat prompt from ``llm/jira_analysis_<version>.txt``.

    The path is relative to the current working directory. The expected layout is a
    ``SYSTEM:`` header line, the system instructions, a blank line, a ``USER:`` line,
    then the user-message template. Everything before the first blank-line + ``USER:``
    separator becomes the system message (leading ``SYSTEM:`` header removed), and
    everything after it becomes the human message. A file without the separator is
    treated as the legacy format: the whole file is the system message and a one-line
    user template referencing ``{issueKey}`` is used.

    Args:
        version: Template version suffix used to build the file name.

    Returns:
        A ``ChatPromptTemplate`` with one system and one human message.

    Raises:
        Any error from reading the file or building the templates is logged and
        re-raised so callers can fall back to another prompt.
    """
    separator = "\n\nUSER:\n"
    try:
        with open(f"llm/jira_analysis_{version}.txt", "r", encoding="utf-8") as f:
            template_content = f.read()

        if separator in template_content:
            # partition splits on the FIRST separator only; a second "USER:" block
            # stays inside the user template instead of breaking a 2-way unpack.
            system_part, _, user_template = template_content.partition(separator)
            # Remove only the leading header, not "SYSTEM:" text inside the prompt body.
            system_template = system_part.lstrip().removeprefix("SYSTEM:\n").strip()
        else:
            # Handle legacy format
            system_template = template_content
            user_template = "Analyze this Jira ticket: {issueKey}"

        return ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template(user_template)
        ])
    except Exception as e:
        logger.error(f"Failed to load prompt template: {str(e)}")
        raise
 | |
| 
 | |
# Minimal prompt used when the versioned template file cannot be loaded.
# It only references the `issueKey` and `summary` input fields.
_FALLBACK_SYSTEM_TEXT = "Analyze Jira tickets and output JSON with hasMultipleEscalations"
_FALLBACK_HUMAN_TEXT = "Issue Key: {issueKey}\nSummary: {summary}"

FALLBACK_PROMPT = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(_FALLBACK_SYSTEM_TEXT),
        HumanMessagePromptTemplate.from_template(_FALLBACK_HUMAN_TEXT),
    ]
)
 | |
| 
 | |
# Helper function to extract JSON from a string that might contain other text

# Presumably for reasoning models (this module sends a "/no_think" hint in its connection
# probes): <think>...</think> sections may hold draft JSON or stray braces, so they are
# removed before searching for the answer object.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)


def _llm_output_to_text(output: Any) -> str:
    """Return the text payload of an LLM output.

    Plain strings are returned as-is. Anything else is treated as a chat message: its
    ``content`` attribute is used when present. List content (a sequence of parts) is
    flattened by concatenating string parts and the ``"text"`` entry of dict parts.
    """
    if isinstance(output, str):
        return output
    content = getattr(output, "content", output)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)
    return str(content)


def extract_json(text: Any) -> str:
    """
    Extracts the first complete JSON object from an LLM response.

    ``text`` may be a plain string or a chat-model message (its ``content`` is used),
    so this step can follow any of the supported LLMs in the chain.

    Strategy:
      1. Drop ``<think>...</think>`` sections (unless the only braces are inside them).
      2. Try each ``{`` left to right and return the first span that decodes as a
         complete JSON object (``json.JSONDecoder.raw_decode``).
      3. If none decodes, return the span from the first ``{`` to the last ``}`` so the
         downstream parser can still attempt a lenient parse.

    When there is no ``{``, or no ``}`` after it, a warning is logged and the full
    response text is returned unchanged.
    """
    raw = _llm_output_to_text(text)
    cleaned = _THINK_BLOCK_RE.sub("", raw)
    # Fall back to the unstripped text if the answer only appeared inside a think block.
    source = cleaned if "{" in cleaned else raw

    start_index = source.find('{')
    if start_index == -1:
        logger.warning(f"No opening curly brace found in LLM response: {raw}")
        return raw  # Return original text if no JSON start is found

    end_index = source.rfind('}')
    if end_index < start_index:  # also covers end_index == -1
        logger.warning(f"No closing curly brace found or invalid JSON structure in LLM response: {raw}")
        return raw  # Return original text if no JSON end is found or it's before the start

    decoder = json.JSONDecoder()
    position = start_index
    while position != -1 and position < end_index:
        try:
            _, object_end = decoder.raw_decode(source, position)
        except json.JSONDecodeError:
            position = source.find('{', position + 1)
            continue
        json_string = source[position:object_end]
        logger.debug(f"Extracted JSON string: {json_string}")
        return json_string

    # Nothing decoded cleanly; hand the widest brace-delimited span to the parser.
    json_string = source[start_index : end_index + 1]
    logger.debug(f"Extracted JSON string: {json_string}")
    return json_string
 | |
| 
 | |
# Create chain with fallback mechanism
def create_analysis_chain():
    """Build the Jira-ticket analysis chain.

    Primary path: copy the ticket fields from the input dict into the versioned prompt
    returned by ``load_prompt_template()``, call the LLM, isolate the JSON object with
    ``extract_json`` and parse it with ``parser``.

    Fallback path (the template cannot be loaded): ``FALLBACK_PROMPT``, which reads only
    ``issueKey`` and ``summary``, followed by the same extraction and parsing stages, so
    both paths produce the same kind of output.
    """
    from operator import itemgetter

    try:
        prompt_template = load_prompt_template()
    except Exception as e:
        logger.warning(f"Using fallback prompt due to error: {str(e)}")
        return FALLBACK_PROMPT | llm_runnable | extract_json | parser

    # Each prompt variable is read from the same-named key of the chain input.
    # itemgetter avoids the late-binding pitfall of lambdas created in a loop.
    input_mapping = {
        key: itemgetter(key)
        for key in (
            "issueKey",
            "summary",
            "description",
            "status",
            "labels",
            "assignee",
            "updated",
            "comment",
        )
    }
    input_mapping["format_instructions"] = lambda _: parser.get_format_instructions()

    return (
        input_mapping
        | prompt_template
        | llm_runnable  # Use the explicitly typed runnable
        | extract_json
        | parser
    )


# Initialize analysis chain
analysis_chain = create_analysis_chain()
 | |
| 
 | |
# Enhanced response validation function
def validate_response(response: Union[dict, str], issue_key: str = "N/A") -> bool:
    """Validate the JSON response structure and content.

    Args:
        response: Chain output, either an already-parsed dict or a raw JSON string.
        issue_key: Identifier used only to prefix log messages.

    Returns:
        ``True`` when ``response`` is (or parses to) a dict. The ``AnalysisFlags``
        schema check is advisory: a mismatch is logged as a warning and still
        returns ``True``.
        ``False`` when a string is not valid JSON, the value is not a dict, or an
        unexpected error occurs.
    """
    try:
        # If response is a string, attempt to parse it as JSON
        if isinstance(response, str):
            logger.debug(f"[{issue_key}] Raw LLM response (string): {response}")
            try:
                response = json.loads(response)
            except json.JSONDecodeError as e:
                logger.error(f"[{issue_key}] JSONDecodeError: {e}. Raw response: {response}")
                return False

        # Ensure response is a dictionary
        if not isinstance(response, dict):
            logger.error(f"[{issue_key}] Response is not a dictionary: {type(response)}")
            return False

        # default=str keeps this debug line from raising on non-JSON-native values,
        # which would otherwise turn a valid dict into a False result via the outer except.
        logger.debug(f"[{issue_key}] Parsed LLM response (JSON): {json.dumps(response, default=str)}")

        # Validate against schema using AnalysisFlags model (advisory only)
        try:
            AnalysisFlags.model_validate(response)
        except Exception as e:
            logger.warning(f"[{issue_key}] Pydantic validation failed: {e}. Continuing with raw response: {response}")
        return True  # Allow processing even if validation fails
    except Exception as e:
        logger.error(f"[{issue_key}] Unexpected error during response validation: {e}. Response: {response}")
        return False
