//! stakpak_shared/models/llm.rs
//!
//! Shared LLM model, request, response, and token-usage types used across
//! the Anthropic, Gemini, and OpenAI integrations.

use crate::models::{
    integrations::{
        anthropic::{AnthropicConfig, AnthropicModel},
        gemini::{GeminiConfig, GeminiModel},
        openai::{OpenAIConfig, OpenAIModel},
    },
    model_pricing::{ContextAware, ModelContextInfo},
};
use serde::{Deserialize, Serialize};
use std::fmt::Display;

/// An LLM model, grouped by provider, with a `Custom` escape hatch for
/// arbitrary model identifiers.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub enum LLMModel {
    Anthropic(AnthropicModel),
    Gemini(GeminiModel),
    OpenAI(OpenAIModel),
    Custom(String),
}

impl ContextAware for LLMModel {
    fn context_info(&self) -> ModelContextInfo {
        match self {
            LLMModel::Anthropic(model) => model.context_info(),
            LLMModel::Gemini(model) => model.context_info(),
            LLMModel::OpenAI(model) => model.context_info(),
            LLMModel::Custom(_) => ModelContextInfo::default(),
        }
    }

    fn model_name(&self) -> String {
        match self {
            LLMModel::Anthropic(model) => model.model_name(),
            LLMModel::Gemini(model) => model.model_name(),
            LLMModel::OpenAI(model) => model.model_name(),
            LLMModel::Custom(model_name) => model_name.clone(),
        }
    }
}

/// Per-provider configuration; any subset of providers may be configured.
#[derive(Debug)]
pub struct LLMProviderConfig {
    pub anthropic_config: Option<AnthropicConfig>,
    pub gemini_config: Option<GeminiConfig>,
    pub openai_config: Option<OpenAIConfig>,
}

impl From<String> for LLMModel {
    fn from(value: String) -> Self {
        // Order matters: more specific prefixes must be checked before their
        // shorter counterparts (e.g. "gpt-5-mini" before "gpt-5", and the
        // exact "gemini-2.5-flash-lite" match before the "gemini-2.5-flash"
        // prefix).
        if value.starts_with("claude-haiku-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        } else if value.starts_with("claude-sonnet-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
        } else if value.starts_with("claude-opus-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
        } else if value == "gemini-2.5-flash-lite" {
            LLMModel::Gemini(GeminiModel::Gemini25FlashLite)
        } else if value.starts_with("gemini-2.5-flash") {
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        } else if value.starts_with("gemini-2.5-pro") {
            LLMModel::Gemini(GeminiModel::Gemini25Pro)
        } else if value.starts_with("gemini-3-pro-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Pro)
        } else if value.starts_with("gemini-3-flash-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Flash)
        } else if value.starts_with("gpt-5-mini") {
            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
        } else if value.starts_with("gpt-5") {
            LLMModel::OpenAI(OpenAIModel::GPT5)
        } else {
            LLMModel::Custom(value)
        }
    }
}
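
// A minimal test sketch (not part of the original file) illustrating how the
// prefix matching above resolves identifiers. It assumes the referenced enum
// variants implement `PartialEq`/`Debug`, as the derives on `LLMModel` imply.
#[cfg(test)]
mod llm_model_from_string_sketch {
    use super::*;

    #[test]
    fn resolves_prefixes_and_falls_back_to_custom() {
        // Dated identifiers still match via `starts_with`.
        assert_eq!(
            LLMModel::from("claude-sonnet-4-5-20250929".to_string()),
            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
        );
        // The "gpt-5-mini" arm must run before the broader "gpt-5" one.
        assert_eq!(
            LLMModel::from("gpt-5-mini-2025".to_string()),
            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
        );
        // Unknown identifiers are preserved verbatim.
        assert_eq!(
            LLMModel::from("my-local-model".to_string()),
            LLMModel::Custom("my-local-model".to_string())
        );
    }
}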

impl Display for LLMModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LLMModel::Anthropic(model) => write!(f, "{}", model),
            LLMModel::Gemini(model) => write!(f, "{}", model),
            LLMModel::OpenAI(model) => write!(f, "{}", model),
            LLMModel::Custom(model) => write!(f, "{}", model),
        }
    }
}

/// Provider-specific options for LLM requests
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    /// Anthropic-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    /// OpenAI-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    /// Google/Gemini-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}

/// Anthropic-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    /// Extended thinking configuration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}

/// Thinking/reasoning options
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    /// Budget tokens for thinking (must be >= 1024)
    pub budget_tokens: u32,
}

impl LLMThinkingOptions {
    /// Builds thinking options, clamping the budget up to the 1024-token
    /// minimum documented on `budget_tokens`.
    pub fn new(budget_tokens: u32) -> Self {
        Self {
            budget_tokens: budget_tokens.max(1024),
        }
    }
}
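
// Sketch (illustrative only): `LLMThinkingOptions::new` floors the budget at
// the 1024-token minimum instead of rejecting smaller values.
#[cfg(test)]
mod thinking_options_sketch {
    use super::*;

    #[test]
    fn budget_is_floored_at_1024() {
        assert_eq!(LLMThinkingOptions::new(512).budget_tokens, 1024);
        assert_eq!(LLMThinkingOptions::new(4096).budget_tokens, 4096);
    }
}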

/// OpenAI-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning effort for o1/o3/o4 models ("low", "medium", "high")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}

/// Google/Gemini-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    /// Thinking budget in tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}
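
// Sketch (illustrative only): assembling `LLMProviderOptions` with an
// Anthropic thinking budget. `None` provider slots are dropped from the JSON
// by `skip_serializing_if`, keeping request payloads minimal.
#[cfg(test)]
mod provider_options_sketch {
    use super::*;

    #[test]
    fn none_providers_are_omitted_from_json() {
        let options = LLMProviderOptions {
            anthropic: Some(LLMAnthropicOptions {
                thinking: Some(LLMThinkingOptions::new(2048)),
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&options).expect("options serialize");
        assert_eq!(json, r#"{"anthropic":{"thinking":{"budget_tokens":2048}}}"#);
    }
}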

/// A complete (non-streaming) generation request.
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
}

/// A streaming generation request; deltas are delivered over
/// `stream_channel_tx`.
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
            provider_options: value.provider_options.clone(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}
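
// Sketch (illustrative only): flattening typed content to plain text. Tool
// calls and images map to empty strings, so they show up as blank lines in
// the joined output.
#[cfg(test)]
mod message_content_sketch {
    use super::*;

    #[test]
    fn list_content_joins_text_and_tool_results() {
        let content = LLMMessageContent::List(vec![
            LLMMessageTypedContent::Text {
                text: "hello".to_string(),
            },
            LLMMessageTypedContent::ToolResult {
                tool_use_id: "toolu_1".to_string(),
                content: "ok".to_string(),
            },
        ]);
        assert_eq!(content.to_string(), "hello\nok");
    }
}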

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}

/// Typed message content blocks, tagged by a "type" field.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        #[serde(alias = "input")]
        args: serde_json::Value,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}
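
// Sketch (illustrative only): the `tag = "type"` representation, plus the
// `alias = "input"` on `args`, which accepts tool-call payloads that use
// either key name.
#[cfg(test)]
mod typed_content_serde_sketch {
    use super::*;

    #[test]
    fn tool_call_deserializes_from_input_alias() {
        let json = r#"{"type":"tool_use","id":"toolu_1","name":"search","input":{"q":"rust"}}"#;
        match serde_json::from_str::<LLMMessageTypedContent>(json).expect("valid tool_use payload") {
            LLMMessageTypedContent::ToolCall { id, name, args } => {
                assert_eq!(id, "toolu_1");
                assert_eq!(name, "search");
                assert_eq!(args["q"], "rust");
            }
            other => panic!("expected ToolCall, got {:?}", other),
        }
    }
}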

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}
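
// Sketch (illustrative only): a minimal streaming chunk. `message`, `usage`,
// and `citations` are optional, so a bare content delta parses cleanly.
#[cfg(test)]
mod stream_response_sketch {
    use super::*;

    #[test]
    fn parses_minimal_chunk() {
        let json = r#"{
            "model": "gpt-5",
            "object": "chat.completion.chunk",
            "choices": [
                {
                    "finish_reason": null,
                    "index": 0,
                    "message": null,
                    "delta": { "content": "Hel" }
                }
            ],
            "created": 1700000000,
            "id": "chunk-1"
        }"#;
        let chunk: LLMCompletionStreamResponse =
            serde_json::from_str(json).expect("valid chunk");
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hel"));
        assert!(chunk.usage.is_none() && chunk.citations.is_none());
    }
}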

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

impl PromptTokensDetails {
    /// Returns an iterator over the token types and their values, treating
    /// missing counts as zero.
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}
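
// Sketch (illustrative only): accumulating per-request details with the
// `Add`/`AddAssign` impls above. Missing counts are treated as zero, so the
// sum comes back fully populated.
#[cfg(test)]
mod token_details_sketch {
    use super::*;

    #[test]
    fn add_treats_missing_counts_as_zero() {
        let first = PromptTokensDetails {
            input_tokens: Some(100),
            output_tokens: None,
            cache_read_input_tokens: Some(40),
            cache_write_input_tokens: None,
        };
        let second = PromptTokensDetails {
            input_tokens: Some(25),
            output_tokens: Some(10),
            cache_read_input_tokens: None,
            cache_write_input_tokens: None,
        };
        let sum = first + second;
        assert_eq!(sum.input_tokens, Some(125));
        assert_eq!(sum.output_tokens, Some(10));
        assert_eq!(sum.cache_read_input_tokens, Some(40));
        assert_eq!(sum.cache_write_input_tokens, Some(0));
    }
}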

/// Incremental generation events delivered over the stream channel.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
}