jira-webhook-llm/llm/chains.py

229 lines
9.0 KiB
Python

import json
import os
import sys
from operator import itemgetter
from pathlib import Path
from typing import Union, Any # Import Any

from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain_core.runnables import RunnablePassthrough, Runnable
from langchain_google_genai import ChatGoogleGenerativeAI # New import for Gemini
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI
from loguru import logger
from pydantic import SecretStr # Re-import SecretStr

from config import settings
from llm.models import AnalysisFlags
class LLMInitializationError(Exception):
    """Signal that the configured LLM backend could not be initialized.

    Args:
        message: Human-readable explanation; it becomes ``str(exc)``.
        details: Optional extra diagnostics, stored unchanged on
            ``self.details`` (this module passes a dict with keys such as
            ``'model'`` and ``'error_type'``). Defaults to ``None``.
    """

    def __init__(self, message, details: Any = None):
        self.details = details
        super().__init__(message)
# Initialize LLM
# Everything below runs at import time. It constructs the client selected by
# ``settings.llm.mode``. For Ollama and Gemini, unless IS_TEST_ENV == "true",
# it also sends one test prompt so that misconfiguration fails immediately.


def _is_test_env() -> bool:
    """Return True when the ``IS_TEST_ENV`` environment variable equals ``"true"``."""
    return os.getenv("IS_TEST_ENV") == "true"


def _verify_connection(client: Runnable, provider: str) -> None:
    """Send one trivial prompt through *client* unless running under tests.

    Args:
        client: The freshly constructed LLM client.
        provider: Display name used in the log messages (e.g. ``"Ollama"``).

    Raises:
        Whatever ``client.invoke`` raises; the callers below wrap it in
        ``LLMInitializationError``.
    """
    if _is_test_env():
        logger.info(f"Skipping {provider} connection test in test environment.")
        return
    logger.debug(f"Testing {provider} connection...")
    client.invoke("test")  # Simple test request
    logger.info(f"{provider} connection established successfully")


llm: Union[ChatOpenAI, OllamaLLM, ChatGoogleGenerativeAI, None] = None
if settings.llm.mode == 'openai':
    logger.info(f"Initializing ChatOpenAI with model: {settings.llm.openai_model}")
    llm = ChatOpenAI(
        model=settings.llm.openai_model or "",  # ChatOpenAI needs a str, never None
        temperature=0.7,
        max_tokens=2000,
        api_key=settings.llm.openai_api_key,  # type: ignore # Pylance: SecretStr type mismatch
        base_url=settings.llm.openai_api_base_url,
    )
elif settings.llm.mode == 'ollama':
    logger.info(f"Initializing OllamaLLM with model: {settings.llm.ollama_model} at {settings.llm.ollama_base_url}")
    try:
        # Validate configuration before constructing the client.
        if not settings.llm.ollama_base_url:
            raise ValueError("Ollama base URL is not configured")
        if not settings.llm.ollama_model:
            raise ValueError("Ollama model is not specified")
        logger.debug(f"Attempting to connect to Ollama at {settings.llm.ollama_base_url}")
        # Only a trailing '/' is stripped here; no API path is appended.
        llm = OllamaLLM(
            model=settings.llm.ollama_model,
            base_url=settings.llm.ollama_base_url.rstrip('/'),
        )
        _verify_connection(llm, "Ollama")
    except Exception as e:
        error_msg = f"Failed to initialize Ollama: {str(e)}"
        details = {
            'model': settings.llm.ollama_model,
            'url': settings.llm.ollama_base_url,
            'error_type': type(e).__name__,
        }
        logger.error(error_msg)
        logger.debug(f"Connection details: {details}")
        raise LLMInitializationError(
            "Failed to connect to Ollama service. Please check:"
            "\n1. Ollama is installed and running"
            "\n2. The base URL is correct"
            "\n3. The model is available",
            details=details,
        ) from e
elif settings.llm.mode == 'gemini':
    logger.info(f"Initializing ChatGoogleGenerativeAI with model: {settings.llm.gemini_model}")
    try:
        # Validate configuration before constructing the client.
        if not settings.llm.gemini_api_key:
            raise ValueError("Gemini API key is not configured")
        if not settings.llm.gemini_model:
            raise ValueError("Gemini model is not specified")
        llm = ChatGoogleGenerativeAI(
            model=settings.llm.gemini_model,
            temperature=0.7,
            max_tokens=2000,
            google_api_key=settings.llm.gemini_api_key,
        )
        _verify_connection(llm, "Gemini")
    except Exception as e:
        error_msg = f"Failed to initialize Gemini: {str(e)}"
        # The API key is deliberately not included in the logged details.
        details = {
            'model': settings.llm.gemini_model,
            'error_type': type(e).__name__,
        }
        logger.error(error_msg)
        logger.debug(f"Connection details: {details}")
        raise LLMInitializationError(
            "Failed to connect to Gemini service. Please check:"
            "\n1. GEMINI_API_KEY is correct"
            "\n2. GEMINI_MODEL is correct and accessible"
            "\n3. Network connectivity to Gemini API",
            details=details,
        ) from e
else:
    # Previously an unknown mode silently fell through to the generic exit below.
    logger.error(
        f"Unsupported LLM mode: {settings.llm.mode!r} "
        "(expected 'openai', 'ollama' or 'gemini')"
    )

if llm is None:
    logger.error("LLM could not be initialized. Exiting.")
    print("\nERROR: Unable to initialize LLM. Check logs for details.", file=sys.stderr)
    sys.exit(1)

# All supported clients are LangChain Runnables. The explicit annotation, with
# ``type: ignore`` to silence the Union-vs-Runnable complaint, lets the ``|``
# chaining below type-check. No runtime cast happens.
llm_runnable: Runnable = llm  # type: ignore

# Parses the model output as JSON. ``AnalysisFlags`` is the schema behind
# ``parser.get_format_instructions()``.
parser = JsonOutputParser(pydantic_object=AnalysisFlags)
# Load prompt template from file
def load_prompt_template(version="v1.2.0"):
    """Load a versioned Jira-analysis prompt file and build a chat prompt.

    The file ``jira_analysis_<version>.txt`` is read from this module's own
    directory (``llm/``), so loading no longer depends on the working directory.

    Two file formats are accepted:

    * Sectioned: text before the first ``"\\n\\nUSER:\\n"`` is the system
      prompt, with a leading ``"SYSTEM:"`` header removed. Text after it is the
      user prompt. Later occurrences of the separator stay in the user prompt
      instead of raising ``ValueError`` as ``str.split`` unpacking did.
    * Legacy: the whole file is the system prompt, and a default user prompt
      referencing ``{issueKey}`` is used.

    Args:
        version: Version suffix of the prompt file name.

    Returns:
        A ``ChatPromptTemplate`` with one system and one human message.

    Raises:
        Re-raises any error from reading the file or parsing the templates,
        after logging it.
    """
    try:
        template_path = Path(__file__).resolve().parent / f"jira_analysis_{version}.txt"
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        template_content = template_path.read_text(encoding="utf-8")

        system_template, separator, user_template = template_content.partition("\n\nUSER:\n")
        if separator:
            system_template = system_template.strip()
            # Strip only the leading section header; "SYSTEM:" in the body is kept.
            if system_template.startswith("SYSTEM:"):
                system_template = system_template[len("SYSTEM:"):].strip()
        else:
            # Handle legacy format: whole file is the system prompt.
            user_template = "Analyze this Jira ticket: {issueKey}"

        return ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template(user_template),
        ])
    except Exception as e:
        logger.error(f"Failed to load prompt template: {str(e)}")
        raise
# Minimal prompt used when the versioned template file cannot be loaded.
# It only references the ``issueKey`` and ``summary`` input variables.
FALLBACK_PROMPT = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(
            "Analyze Jira tickets and output JSON with "
            "hasMultipleEscalations, customerSentiment"
        ),
        HumanMessagePromptTemplate.from_template("Issue Key: {issueKey}\nSummary: {summary}"),
    ]
)
# Create chain with fallback mechanism

# Fields copied from the caller's payload into the prompt variables. A missing
# key raises KeyError when the chain is invoked.
_ANALYSIS_INPUT_KEYS = (
    "issueKey",
    "summary",
    "description",
    "status",
    "labels",
    "assignee",
    "updated",
    "comment",
)


def create_analysis_chain():
    """Build the Jira analysis chain: input mapping -> prompt -> LLM -> JSON parser.

    If the versioned prompt file cannot be loaded, ``FALLBACK_PROMPT`` is used
    instead. Both variants end in ``parser``, so callers always receive parsed
    JSON rather than a raw model message. Previously the fallback returned the
    bare LLM output, e.g. an AIMessage for chat models.

    Returns:
        A LangChain Runnable that expects a dict containing the keys in
        ``_ANALYSIS_INPUT_KEYS``.
    """
    try:
        prompt_template = load_prompt_template()
    except Exception as e:
        logger.warning(f"Using fallback prompt due to error: {str(e)}")
        return FALLBACK_PROMPT | llm_runnable | parser

    # ``itemgetter(key)`` behaves like ``lambda x: x[key]`` inside an LCEL mapping.
    inputs = {key: itemgetter(key) for key in _ANALYSIS_INPUT_KEYS}
    inputs["format_instructions"] = lambda _: parser.get_format_instructions()
    return inputs | prompt_template | llm_runnable | parser


# Initialize analysis chain
analysis_chain = create_analysis_chain()
# Enhanced response validation function
def validate_response(response: Union[dict, str], issue_key: str = "N/A") -> bool:
    """Check that an LLM response is a JSON object accepted by ``AnalysisFlags``.

    A string response is first decoded with ``json.loads``. The decoded value
    must be a dict, which is then checked with ``AnalysisFlags.model_validate``.
    Every failure is logged with the ``issue_key`` prefix and reported as
    ``False``; this function never raises.

    Args:
        response: Parsed dict or raw JSON text produced by the analysis chain.
        issue_key: Identifier used to tag log lines.

    Returns:
        True if the response validates against ``AnalysisFlags``, else False.
    """
    payload = response
    try:
        if isinstance(payload, str):
            logger.debug(f"[{issue_key}] Raw LLM response (string): {payload}")
            try:
                payload = json.loads(payload)
            except json.JSONDecodeError as decode_error:
                logger.error(f"[{issue_key}] JSONDecodeError: {decode_error}. Raw response: {payload}")
                return False

        if not isinstance(payload, dict):
            logger.error(f"[{issue_key}] Response is not a dictionary: {type(payload)}")
            return False

        logger.debug(f"[{issue_key}] Parsed LLM response (JSON): {json.dumps(payload)}")

        # Schema check against the AnalysisFlags model.
        try:
            AnalysisFlags.model_validate(payload)
        except Exception as validation_error:
            logger.error(f"[{issue_key}] Pydantic validation error: {validation_error}. Invalid response: {payload}")
            return False
        else:
            return True
    except Exception as unexpected_error:
        logger.error(f"[{issue_key}] Unexpected error during response validation: {unexpected_error}. Response: {payload}")
        return False