stakpak_shared/models/llm.rs

use crate::models::{
    error::{AgentError, BadRequestErrorMessage},
    integrations::{
        anthropic::{Anthropic, AnthropicConfig, AnthropicInput, AnthropicModel},
        gemini::{Gemini, GeminiConfig, GeminiInput, GeminiModel},
        openai::{OpenAI, OpenAIConfig, OpenAIInput, OpenAIModel},
    },
};
use serde::{Deserialize, Serialize};
use std::fmt::Display;

#[derive(Clone, Debug, Serialize)]
pub enum LLMModel {
    Anthropic(AnthropicModel),
    Gemini(GeminiModel),
    OpenAI(OpenAIModel),
    Custom(String),
}

#[derive(Debug)]
pub struct LLMProviderConfig {
    pub anthropic_config: Option<AnthropicConfig>,
    pub gemini_config: Option<GeminiConfig>,
    pub openai_config: Option<OpenAIConfig>,
}

impl From<String> for LLMModel {
    fn from(value: String) -> Self {
        match value.as_str() {
            "claude-haiku-4-5" => LLMModel::Anthropic(AnthropicModel::Claude45Haiku),
            "claude-sonnet-4-5" => LLMModel::Anthropic(AnthropicModel::Claude45Sonnet),
            "claude-opus-4-5" => LLMModel::Anthropic(AnthropicModel::Claude45Opus),
            "gemini-2.0-flash" => LLMModel::Gemini(GeminiModel::Gemini20Flash),
            "gemini-2.0-flash-lite" => LLMModel::Gemini(GeminiModel::Gemini20FlashLite),
            "gemini-2.5-flash" => LLMModel::Gemini(GeminiModel::Gemini25Flash),
            "gemini-2.5-flash-lite" => LLMModel::Gemini(GeminiModel::Gemini25FlashLite),
            "gemini-2.5-pro" => LLMModel::Gemini(GeminiModel::Gemini25Pro),
            "gemini-3-pro-preview" => LLMModel::Gemini(GeminiModel::Gemini3Pro),
            "gpt-5" => LLMModel::OpenAI(OpenAIModel::GPT5),
            "gpt-5-mini" => LLMModel::OpenAI(OpenAIModel::GPT5Mini),
            _ => LLMModel::Custom(value),
        }
    }
}
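
// Illustrative sketch (not part of the original module): how model identifiers map to
// provider variants, including the `Custom` fallback for names the match above does not
// recognize. The model strings below are just examples.
#[cfg(test)]
mod llm_model_parsing_examples {
    use super::*;

    #[test]
    fn parses_known_and_unknown_model_names() {
        assert!(matches!(
            LLMModel::from("claude-sonnet-4-5".to_string()),
            LLMModel::Anthropic(_)
        ));
        assert!(matches!(
            LLMModel::from("gpt-5-mini".to_string()),
            LLMModel::OpenAI(_)
        ));
        // Anything unrecognized falls through to `Custom`, which `chat` later routes
        // through the OpenAI-compatible path.
        assert!(matches!(
            LLMModel::from("my-self-hosted-model".to_string()),
            LLMModel::Custom(_)
        ));
    }
}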

impl Display for LLMModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LLMModel::Anthropic(model) => write!(f, "{}", model),
            LLMModel::Gemini(model) => write!(f, "{}", model),
            LLMModel::OpenAI(model) => write!(f, "{}", model),
            LLMModel::Custom(model) => write!(f, "{}", model),
        }
    }
}

#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
}

#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
        }
    }
}

pub async fn chat(
    config: &LLMProviderConfig,
    input: LLMInput,
) -> Result<LLMCompletionResponse, AgentError> {
    match input.model {
        LLMModel::Anthropic(model) => {
            if let Some(anthropic_config) = &config.anthropic_config {
                let anthropic_input = AnthropicInput {
                    model,
                    messages: input.messages,
                    grammar: None,
                    max_tokens: input.max_tokens,
                    stop_sequences: None,
                    tools: input.tools,
                    thinking: Default::default(),
                };
                Anthropic::chat(anthropic_config, anthropic_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Anthropic config not found".to_string(),
                    ),
                ))
            }
        }

        LLMModel::Gemini(model) => {
            if let Some(gemini_config) = &config.gemini_config {
                let gemini_input = GeminiInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    tools: input.tools,
                };
                Gemini::chat(gemini_config, gemini_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Gemini config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::OpenAI(model) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat(openai_config, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::Custom(model_name) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model: OpenAIModel::Custom(model_name),
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat(openai_config, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
    }
}
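
// Illustrative sketch (not part of the original module): a minimal single-turn call
// through `chat`. It assumes the caller already holds a populated `LLMProviderConfig`;
// the model name and prompt text are placeholder examples.
#[allow(dead_code)]
async fn example_chat(config: &LLMProviderConfig) -> Result<String, AgentError> {
    let input = LLMInput {
        model: LLMModel::from("claude-sonnet-4-5".to_string()),
        messages: vec![LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello!".to_string()),
        }],
        max_tokens: 1024,
        tools: None,
    };
    let response = chat(config, input).await?;
    // Flatten the first choice's message content into plain text.
    Ok(response
        .choices
        .first()
        .map(|choice| choice.message.content.to_string())
        .unwrap_or_default())
}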

pub async fn chat_stream(
    config: &LLMProviderConfig,
    input: LLMStreamInput,
) -> Result<LLMCompletionResponse, AgentError> {
    match input.model {
        LLMModel::Anthropic(model) => {
            if let Some(anthropic_config) = &config.anthropic_config {
                let anthropic_input = AnthropicInput {
                    model,
                    messages: input.messages,
                    grammar: None,
                    max_tokens: input.max_tokens,
                    stop_sequences: None,
                    tools: input.tools,
                    thinking: Default::default(),
                };
                Anthropic::chat_stream(anthropic_config, input.stream_channel_tx, anthropic_input)
                    .await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Anthropic config not found".to_string(),
                    ),
                ))
            }
        }

        LLMModel::Gemini(model) => {
            if let Some(gemini_config) = &config.gemini_config {
                let gemini_input = GeminiInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    tools: input.tools,
                };
                Gemini::chat_stream(gemini_config, input.stream_channel_tx, gemini_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Gemini config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::OpenAI(model) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat_stream(openai_config, input.stream_channel_tx, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::Custom(model_name) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model: OpenAIModel::Custom(model_name),
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat_stream(openai_config, input.stream_channel_tx, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
    }
}
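
// Illustrative sketch (not part of the original module): stream a completion and print
// content deltas as they arrive. Assumes a Tokio runtime with task spawning is available
// to the calling crate; the channel capacity of 64 and the prompt are arbitrary examples.
#[allow(dead_code)]
async fn example_chat_stream(config: &LLMProviderConfig) -> Result<(), AgentError> {
    let (tx, mut rx) = tokio::sync::mpsc::channel::<GenerationDelta>(64);
    let input = LLMStreamInput {
        model: LLMModel::from("gpt-5-mini".to_string()),
        messages: vec![LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Summarize this repository.".to_string()),
        }],
        max_tokens: 512,
        stream_channel_tx: tx,
        tools: None,
    };
    // Drain deltas concurrently with the request; only content chunks are printed here,
    // but `Thinking`, `ToolUse`, `Usage`, and `Metadata` variants arrive on the same channel.
    let reader = tokio::spawn(async move {
        while let Some(delta) = rx.recv().await {
            if let GenerationDelta::Content { content } = delta {
                print!("{}", content);
            }
        }
    });
    let _response = chat_stream(config, input).await?;
    let _ = reader.await;
    Ok(())
}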

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    #[serde(rename = "role")]
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}
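
// Illustrative sketch (not part of the original module): because `LLMMessageContent`
// is `#[serde(untagged)]`, a plain string and a list of typed content blocks both
// deserialize into the same field. The JSON payloads below are made-up examples.
#[cfg(test)]
mod llm_message_content_examples {
    use super::*;

    #[test]
    fn deserializes_string_and_typed_list_content() {
        let plain: LLMMessage = serde_json::from_str(r#"{"role":"user","content":"hi"}"#)
            .expect("plain string form");
        assert_eq!(plain.content.to_string(), "hi");

        let typed: LLMMessage = serde_json::from_str(
            r#"{"role":"assistant","content":[{"type":"text","text":"hello"}]}"#,
        )
        .expect("typed list form");
        assert_eq!(typed.content.to_string(), "hello");
    }
}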

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        #[serde(alias = "input")]
        args: serde_json::Value,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}
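
// Illustrative sketch (not part of the original module): construct an `LLMTool` whose
// `input_schema` is a minimal JSON Schema object. The tool name and schema are made-up
// examples, not tools defined by this crate.
#[allow(dead_code)]
fn example_tool() -> LLMTool {
    LLMTool {
        name: "get_weather".to_string(),
        description: "Look up the current weather for a city".to_string(),
        input_schema: serde_json::json!({
            "type": "object",
            "properties": { "city": { "type": "string" } },
            "required": ["city"]
        }),
    }
}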

#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

impl PromptTokensDetails {
    /// Returns an iterator over the token types and their values
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}
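
// Illustrative sketch (not part of the original module): fold per-call prompt token
// details into a running total with `+=`, then read the totals back via `iter()`.
// The token counts are made-up example values.
#[cfg(test)]
mod prompt_tokens_details_examples {
    use super::*;

    #[test]
    fn accumulates_and_iterates_token_details() {
        let mut total = PromptTokensDetails {
            input_tokens: Some(100),
            output_tokens: Some(20),
            cache_read_input_tokens: None,
            cache_write_input_tokens: None,
        };
        total += PromptTokensDetails {
            input_tokens: Some(50),
            output_tokens: None,
            cache_read_input_tokens: Some(30),
            cache_write_input_tokens: None,
        };
        // Missing counts are treated as zero, so every field ends up `Some` after adding.
        assert_eq!(total.input_tokens, Some(150));
        assert_eq!(total.cache_read_input_tokens, Some(30));
        let by_type: Vec<(TokenType, u32)> = total.iter().collect();
        assert_eq!(by_type.len(), 4);
    }
}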

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
}
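
// Illustrative sketch (assumption, not original code): tool-use deltas carry optional
// fragments keyed by `index`, so a consumer typically accumulates them per index before
// parsing the JSON arguments. `collect_tool_calls` is a hypothetical helper name, and the
// fragment-concatenation behavior shown here is an assumption about how providers stream
// these fields.
#[allow(dead_code)]
fn collect_tool_calls(deltas: &[GenerationDelta]) -> Vec<(String, String, String)> {
    use std::collections::BTreeMap;

    // (id, name, raw JSON input) accumulated per tool-call index.
    let mut calls: BTreeMap<usize, (String, String, String)> = BTreeMap::new();
    for delta in deltas {
        if let GenerationDelta::ToolUse { tool_use } = delta {
            let entry = calls.entry(tool_use.index).or_default();
            if let Some(id) = &tool_use.id {
                entry.0.push_str(id);
            }
            if let Some(name) = &tool_use.name {
                entry.1.push_str(name);
            }
            if let Some(input) = &tool_use.input {
                entry.2.push_str(input);
            }
        }
    }
    calls.into_values().collect()
}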