186 lines
6.5 KiB
Python
186 lines
6.5 KiB
Python
#!/usr/bin/env python3
|
|
import logging
|
|
import requests
|
|
from typing import Optional, Dict, List, Any
|
|
|
|
# Module-level logger used by the client code below for debug and error messages.
logger = logging.getLogger(__name__)
|
|
|
|
class OpenRouterError(Exception):
    """Custom exception for OpenRouter API errors.

    Attributes:
        status_code: HTTP status code associated with the failure, or ``None``
            when it is unknown (e.g. the request never got a response).
        response: Decoded response body associated with the failure, or
            ``None``. This is usually a dict, but it may be any JSON value
            when the API returned an unexpected payload.
    """

    def __init__(
        self,
        message: str,
        status_code: Optional[int] = None,
        response: Optional[Any] = None,
    ):
        super().__init__(message)
        self.status_code = status_code
        self.response = response
|
|
|
|
class OpenRouterResponse:
    """Wrapper for OpenRouter API responses.

    Normalises the decoded JSON body of a chat-completion response into:

    * ``choices`` -- one dict per choice with ``message``, ``finish_reason``
      and ``index`` keys (``message`` defaults to ``{}``);
    * ``usage`` -- ``prompt_tokens``, ``completion_tokens`` and
      ``total_tokens``, each defaulting to 0 when missing or null;
    * ``model`` -- the body's ``model`` field, or ``None``.

    The untouched body is kept in ``raw_response``.
    """

    def __init__(self, raw_response: Dict[str, Any]):
        self.raw_response = raw_response
        self.choices = self._parse_choices()
        self.usage = self._parse_usage()
        self.model = raw_response.get("model")

    def _parse_choices(self) -> List[Dict[str, Any]]:
        """Return one normalised dict per entry of the body's ``choices`` array."""
        # ``or []`` also covers an explicit ``"choices": null``; ``.get(key, [])``
        # alone would return None and the comprehension would fail to iterate it.
        choices = self.raw_response.get("choices") or []
        return [
            {
                "message": choice.get("message") or {},
                "finish_reason": choice.get("finish_reason"),
                "index": choice.get("index"),
            }
            for choice in choices
        ]

    def _parse_usage(self) -> Dict[str, int]:
        """Return the token counts, with missing or null values reported as 0."""
        # ``or {}`` guards against ``"usage": null`` in the body.
        usage = self.raw_response.get("usage") or {}
        return {
            key: usage.get(key) or 0
            for key in ("prompt_tokens", "completion_tokens", "total_tokens")
        }
|
|
|
|
class OpenRouterClient:
    """Client for interacting with the OpenRouter API.

    All calls share one ``requests.Session`` that carries the authorization and
    app headers. Every request is sent with ``self.timeout`` so that a stalled
    connection cannot block the caller indefinitely. Failures of any kind
    (transport, HTTP status, undecodable body) surface as ``OpenRouterError``.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str,
        timeout: Optional[float] = 300.0,
        app_url: str = "https://github.com/OpenRouterTeam/openrouter-examples",
        app_title: str = "CV Analysis Tool",
    ):
        """
        Args:
            api_key: OpenRouter API key, sent as a Bearer token
            model_name: Model identifier used for chat completions
            timeout: Seconds passed as ``timeout`` to every ``requests`` call;
                ``None`` waits indefinitely
            app_url: Value sent in the ``HTTP-Referer`` header
            app_title: Value sent in the ``X-Title`` header

        Raises:
            ValueError: If ``api_key`` or ``model_name`` is empty
        """
        if not api_key:
            raise ValueError("OpenRouter API key is required")
        if not model_name:
            raise ValueError("Model name is required")

        self.api_key = api_key
        self.model_name = model_name
        self.timeout = timeout
        self.base_url = "https://openrouter.ai/api/v1"
        self.session = requests.Session()
        self.session.headers.update({
            "Authorization": f"Bearer {api_key}",
            "HTTP-Referer": app_url,
            "X-Title": app_title,
            "Content-Type": "application/json",
        })

    def create_chat_completion(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None
    ) -> OpenRouterResponse:
        """
        Create a chat completion using the OpenRouter API.

        Args:
            messages: List of message dictionaries with 'role' and 'content' keys
            max_tokens: Maximum number of tokens to generate; omitted from the
                request when ``None``

        Returns:
            OpenRouterResponse object containing the API response

        Raises:
            OpenRouterError: If the request fails, returns an HTTP error status,
                or the body is not a JSON object
        """
        endpoint = f"{self.base_url}/chat/completions"
        payload: Dict[str, Any] = {
            "model": self.model_name,
            "messages": messages,
        }
        if max_tokens is not None:
            payload["max_tokens"] = max_tokens

        try:
            response = self.session.post(endpoint, json=payload, timeout=self.timeout)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            raise self._handle_request_error(e) from e

        data = self._decode_json(response)
        # OpenRouterResponse reads the body with dict.get(); reject anything else
        # here instead of letting an AttributeError escape.
        if not isinstance(data, dict):
            raise OpenRouterError(
                message="Invalid response format from OpenRouter API",
                status_code=response.status_code,
                response=data,
            )
        return OpenRouterResponse(data)

    def get_available_models(self) -> Dict[str, Any]:
        """
        Get the model catalogue from OpenRouter API.

        Returns:
            The decoded response body: a dict whose ``"data"`` key holds the
            list of model information dictionaries. The whole body is returned,
            not just the list.

        Raises:
            OpenRouterError: If the request fails, or the response is not a dict
                containing a ``"data"`` list
        """
        endpoint = f"{self.base_url}/models"
        logger.debug("Fetching available models from: %s", endpoint)

        try:
            response = self.session.get(endpoint, timeout=self.timeout)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            raise self._handle_request_error(e) from e

        data = self._decode_json(response)
        # Lazy %-formatting: the catalogue may be large, so it is only rendered
        # when debug logging is actually enabled.
        logger.debug("Raw API response: %s", data)

        if not isinstance(data, dict) or not isinstance(data.get("data"), list):
            raise OpenRouterError(
                message="Invalid response format from OpenRouter API",
                status_code=response.status_code,
                response=data,
            )
        return data

    def verify_model_availability(self) -> bool:
        """
        Verify if the configured model is available.

        This is a best-effort check: any failure while fetching the catalogue
        is logged and reported as ``False`` rather than raised.

        Returns:
            True if a model whose ``id`` equals ``self.model_name`` is listed,
            False otherwise
        """
        try:
            catalogue = self.get_available_models()
            # OpenRouter API returns the list of models in the format:
            # {"data": [{"id": "model_name", ...}, ...]}
            models = [m for m in catalogue.get("data") or [] if isinstance(m, dict)]
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("Available models: %s", [m.get("id") for m in models])
            return any(m.get("id") == self.model_name for m in models)
        except OpenRouterError as e:
            logger.error("Failed to verify model availability: %s", e)
            return False
        except Exception as e:
            # Unexpected failures are logged with their traceback, not raised.
            logger.exception("Unexpected error while verifying model availability: %s", e)
            return False

    @staticmethod
    def _decode_json(response: requests.Response) -> Any:
        """Return the JSON body of ``response``; raise OpenRouterError if it is not valid JSON."""
        try:
            return response.json()
        except ValueError as e:
            # json.JSONDecodeError and requests' own JSONDecodeError both subclass
            # ValueError, so this covers old and new requests versions alike.
            raise OpenRouterError(
                message="Invalid JSON in response from OpenRouter API",
                status_code=response.status_code,
            ) from e

    def _handle_request_error(self, error: requests.exceptions.RequestException) -> OpenRouterError:
        """
        Convert requests exceptions to OpenRouterError.

        The message is taken from an ``{"error": {"message": ...}}`` or
        ``{"error": "..."}`` body when present; otherwise the text of ``error``
        is used. The HTTP status code is kept whenever a response exists.
        """
        response = error.response
        if response is None:
            # Connection errors, timeouts, etc. carry no HTTP response.
            return OpenRouterError(str(error))

        status_code = response.status_code
        try:
            error_data = response.json()
        except ValueError:
            # Non-JSON error body (e.g. an HTML page from a proxy): keep the status.
            return OpenRouterError(str(error), status_code=status_code)

        message = str(error)
        if isinstance(error_data, dict):
            detail = error_data.get("error")
            if isinstance(detail, dict) and detail.get("message"):
                message = str(detail["message"])
            elif isinstance(detail, str) and detail:
                message = detail

        return OpenRouterError(
            message=message,
            status_code=status_code,
            response=error_data,
        )
|
|
|
|
def initialize_openrouter_client(api_key: str, model_name: str) -> OpenRouterClient:
    """
    Build an OpenRouterClient and confirm that its model is offered by the API.

    Args:
        api_key: OpenRouter API key
        model_name: Identifier of the model the client should use

    Returns:
        A client whose model was found in the OpenRouter model list

    Raises:
        ValueError: If the client constructor rejects the arguments or the
            model availability check does not succeed. Any exception raised
            here is logged before being re-raised unchanged.
    """
    try:
        client = OpenRouterClient(api_key=api_key, model_name=model_name)
        # verify_model_availability() signals every failure by returning False.
        available = client.verify_model_availability()
        if not available:
            raise ValueError(f"Model {model_name} not available")
    except Exception as exc:
        logger.error(f"Failed to initialize OpenRouter client: {exc}")
        raise

    logger.debug(f"Successfully initialized OpenRouter client with model: {model_name}")
    return client