import json
import sys
from typing import Union, Any
from pydantic import SecretStr
import re

from langchain_core.prompts import (
    ChatPromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnablePassthrough, Runnable
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from loguru import logger

from llm.models import AnalysisFlags
from config import settings

class LLMInitializationError(Exception):
    """Custom exception for LLM initialization errors"""
    def __init__(self, message, details=None):
        super().__init__(message)
        self.details = details

# Initialize LLM
llm: Union[ChatOpenAI, OllamaLLM, ChatGoogleGenerativeAI, None] = None

if settings.llm.mode == 'openai':
    logger.info(f"Initializing ChatOpenAI with model: {settings.llm.openai_model}")
    llm = ChatOpenAI(
        model=settings.llm.openai_model if settings.llm.openai_model else "",  # Ensure model is str
        temperature=0.7,
        api_key=settings.llm.openai_api_key,  # type: ignore  # Pylance reports a SecretStr type mismatch here
        base_url=settings.llm.openai_api_base_url
    )
elif settings.llm.mode == 'ollama':
    logger.info(f"Initializing OllamaLLM with model: {settings.llm.ollama_model} at {settings.llm.ollama_base_url}")
    try:
        # Verify connection parameters
        if not settings.llm.ollama_base_url:
            raise ValueError("Ollama base URL is not configured")
        if not settings.llm.ollama_model:
            raise ValueError("Ollama model is not specified")

        logger.debug(f"Attempting to connect to Ollama at {settings.llm.ollama_base_url}")

        # Normalize the base URL (strip any trailing slash)
        base_url = f"{settings.llm.ollama_base_url.rstrip('/')}"

        llm = OllamaLLM(
            model=settings.llm.ollama_model,
            base_url=base_url,
            num_ctx=32000
            # Note: streaming, timeout and max_retries are not valid OllamaLLM constructor parameters
        )

        # Test connection only if not in a test environment
        import os
        if os.getenv("IS_TEST_ENV") != "true":
            logger.debug("Testing Ollama connection...")
            llm.invoke("test /no_think ")  # Simple test request
            logger.info("Ollama connection established successfully")
        else:
            logger.info("Skipping Ollama connection test in test environment.")

    except Exception as e:
        error_msg = f"Failed to initialize Ollama: {str(e)}"
        details = {
            'model': settings.llm.ollama_model,
            'url': settings.llm.ollama_base_url,
            'error_type': type(e).__name__
        }
        logger.error(error_msg)
        logger.debug(f"Connection details: {details}")
        raise LLMInitializationError(
            "Failed to connect to Ollama service. Please check:"
            "\n1. Ollama is installed and running"
            "\n2. The base URL is correct"
            "\n3. The model is available",
            details=details
        ) from e
elif settings.llm.mode == 'gemini':
    logger.info(f"Initializing ChatGoogleGenerativeAI with model: {settings.llm.gemini_model}")
    try:
        if not settings.llm.gemini_api_key:
            raise ValueError("Gemini API key is not configured")
        if not settings.llm.gemini_model:
            raise ValueError("Gemini model is not specified")

        llm = ChatGoogleGenerativeAI(
            model=settings.llm.gemini_model,
            temperature=0.7,
            google_api_key=settings.llm.gemini_api_key
        )

        # Test connection only if not in a test environment
        import os
        if os.getenv("IS_TEST_ENV") != "true":
            logger.debug("Testing Gemini connection...")
            llm.invoke("test /no_think ")  # Simple test request
            logger.info("Gemini connection established successfully")
        else:
            logger.info("Skipping Gemini connection test in test environment.")

    except Exception as e:
        error_msg = f"Failed to initialize Gemini: {str(e)}"
        details = {
            'model': settings.llm.gemini_model,
            'error_type': type(e).__name__
        }
        logger.error(error_msg)
        logger.debug(f"Connection details: {details}")
        raise LLMInitializationError(
            "Failed to connect to Gemini service. Please check:"
            "\n1. GEMINI_API_KEY is correct"
            "\n2. GEMINI_MODEL is correct and accessible"
            "\n3. Network connectivity to Gemini API",
            details=details
        ) from e

if llm is None:
    logger.error("LLM could not be initialized. Exiting.")
    print("\nERROR: Unable to initialize LLM. Check logs for details.", file=sys.stderr)
    sys.exit(1)

# Treat llm as a Runnable for chaining; the type: ignore silences the static
# checker, which cannot narrow the Union above to a Runnable here.
llm_runnable: Runnable = llm  # type: ignore

# Set up output parser for structured JSON
parser = JsonOutputParser()

# Load prompt template from file
def load_prompt_template(version="v1.3.0"):
    try:
        with open(f"llm/jira_analysis_{version}.txt", "r") as f:
            template_content = f.read()

        # Split system and user parts
        if "\n\nUSER:\n" in template_content:
            system_template, user_template = template_content.split("\n\nUSER:\n")
            system_template = system_template.replace("SYSTEM:\n", "").strip()
        else:
            # Handle legacy format
            system_template = template_content
            user_template = "Analyze this Jira ticket: {issueKey}"

        return ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template(user_template)
        ])
    except Exception as e:
        logger.error(f"Failed to load prompt template: {str(e)}")
        raise
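
# Illustrative sketch (not taken from the original repo): based on the split
# logic in load_prompt_template, the template file is expected to look roughly like:
#
#   SYSTEM:
#   <system instructions; may reference {format_instructions} and ticket fields>
#
#   USER:
#   <human message template, e.g. referencing {issueKey}, {summary}, {description}, ...>
#
# A file without the "USER:" section falls back to the legacy single-prompt format.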

# Fallback prompt template
FALLBACK_PROMPT = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(
        "Analyze Jira tickets and output JSON with hasMultipleEscalations"
    ),
    HumanMessagePromptTemplate.from_template(
        "Issue Key: {issueKey}\nSummary: {summary}"
    )
])

# Helper function to extract JSON from a string that might contain other text
def extract_json(text: Any) -> str:
    """
    Extracts the first complete JSON object from a string.
    Assumes the JSON object is enclosed in curly braces {}.
    """
    # Chat models (ChatOpenAI, ChatGoogleGenerativeAI) return a message object;
    # use its text content so the same chain also works with OllamaLLM string output.
    if hasattr(text, "content"):
        text = text.content

    # Find the first occurrence of '{'
    start_index = text.find('{')
    if start_index == -1:
        logger.warning(f"No opening curly brace found in LLM response: {text}")
        return text  # Return original text if no JSON start is found

    # Find the last occurrence of '}'
    end_index = text.rfind('}')
    if end_index == -1 or end_index < start_index:
        logger.warning(f"No closing curly brace found or invalid JSON structure in LLM response: {text}")
        return text  # Return original text if no JSON end is found or it's before the start

    json_string = text[start_index : end_index + 1]
    logger.debug(f"Extracted JSON string: {json_string}")
    return json_string
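
# Illustrative example (not from the original source): extract_json strips any
# prose the model wraps around the JSON payload, e.g.
#   extract_json('Sure! {"hasMultipleEscalations": false} Hope that helps.')
#   -> '{"hasMultipleEscalations": false}'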

# Create chain with fallback mechanism
def create_analysis_chain():
    try:
        prompt_template = load_prompt_template()
        chain = (
            {
                "issueKey": lambda x: x["issueKey"],
                "summary": lambda x: x["summary"],
                "description": lambda x: x["description"],
                "status": lambda x: x["status"],
                "labels": lambda x: x["labels"],
                "assignee": lambda x: x["assignee"],
                "updated": lambda x: x["updated"],
                "comment": lambda x: x["comment"],
                "format_instructions": lambda _: parser.get_format_instructions()
            }
            | prompt_template
            | llm_runnable  # Use the explicitly typed runnable
            | extract_json  # Pull the JSON object out of the raw model output
            | parser
        )

        return chain
    except Exception as e:
        logger.warning(f"Using fallback prompt due to error: {str(e)}")
        chain = FALLBACK_PROMPT | llm_runnable  # Use the explicitly typed runnable
        return chain

# Initialize analysis chain
analysis_chain = create_analysis_chain()

# Enhanced response validation function
def validate_response(response: Union[dict, str], issue_key: str = "N/A") -> bool:
    """Validate the JSON response structure and content"""
    try:
        # If response is a string, attempt to parse it as JSON
        if isinstance(response, str):
            logger.debug(f"[{issue_key}] Raw LLM response (string): {response}")
            try:
                response = json.loads(response)
            except json.JSONDecodeError as e:
                logger.error(f"[{issue_key}] JSONDecodeError: {e}. Raw response: {response}")
                return False

        # Ensure response is a dictionary
        if not isinstance(response, dict):
            logger.error(f"[{issue_key}] Response is not a dictionary: {type(response)}")
            return False

        logger.debug(f"[{issue_key}] Parsed LLM response (JSON): {json.dumps(response)}")

        # Validate against schema using AnalysisFlags model
        try:
            AnalysisFlags.model_validate(response)
            return True
        except Exception as e:
            logger.warning(f"[{issue_key}] Pydantic validation failed: {e}. Continuing with raw response: {response}")
            return True  # Allow processing even if validation fails
    except Exception as e:
        logger.error(f"[{issue_key}] Unexpected error during response validation: {e}. Response: {response}")
        return False
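

# Minimal manual smoke test (illustrative sketch, not part of the original module).
# The ticket payload below is made up; its keys mirror the input mapping used in
# create_analysis_chain. Requires a configured and reachable LLM backend.
if __name__ == "__main__":
    sample_ticket = {
        "issueKey": "PROJ-123",
        "summary": "Customer escalated twice about login failures",
        "description": "Users report intermittent 500 errors on login.",
        "status": "In Progress",
        "labels": ["escalation", "auth"],
        "assignee": "jdoe",
        "updated": "2024-01-01T12:00:00Z",
        "comment": "Second escalation received from the account team.",
    }
    result = analysis_chain.invoke(sample_ticket)
    logger.info(f"Analysis result: {result}")
    logger.info(f"Response valid: {validate_response(result, sample_ticket['issueKey'])}")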