stakpak_shared/models/llm.rs

use crate::models::{
    error::{AgentError, BadRequestErrorMessage},
    integrations::{
        anthropic::{Anthropic, AnthropicConfig, AnthropicInput, AnthropicModel},
        gemini::{Gemini, GeminiConfig, GeminiInput, GeminiModel},
        openai::{OpenAI, OpenAIConfig, OpenAIInput, OpenAIModel},
    },
    model_pricing::{ContextAware, ModelContextInfo},
};
use serde::{Deserialize, Serialize};
use std::fmt::Display;

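/// Union of all supported chat models across providers; `Custom` carries an
/// arbitrary model name that is routed through the OpenAI-compatible client.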
#[derive(Clone, Debug, PartialEq, Serialize)]
pub enum LLMModel {
    Anthropic(AnthropicModel),
    Gemini(GeminiModel),
    OpenAI(OpenAIModel),
    Custom(String),
}

impl ContextAware for LLMModel {
    fn context_info(&self) -> ModelContextInfo {
        match self {
            LLMModel::Anthropic(model) => model.context_info(),
            LLMModel::Gemini(model) => model.context_info(),
            LLMModel::OpenAI(model) => model.context_info(),
            LLMModel::Custom(_) => ModelContextInfo::default(),
        }
    }

    fn model_name(&self) -> String {
        match self {
            LLMModel::Anthropic(model) => model.model_name(),
            LLMModel::Gemini(model) => model.model_name(),
            LLMModel::OpenAI(model) => model.model_name(),
            LLMModel::Custom(model_name) => model_name.clone(),
        }
    }
}

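/// Per-provider client configuration; a provider can only be used when its
/// corresponding config is `Some`.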
#[derive(Debug)]
pub struct LLMProviderConfig {
    pub anthropic_config: Option<AnthropicConfig>,
    pub gemini_config: Option<GeminiConfig>,
    pub openai_config: Option<OpenAIConfig>,
}

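/// Maps a model-name string to a known provider model, falling back to
/// `LLMModel::Custom` for unrecognized names.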
impl From<String> for LLMModel {
    fn from(value: String) -> Self {
        if value.starts_with("claude-haiku-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        } else if value.starts_with("claude-sonnet-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
        } else if value.starts_with("claude-opus-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
        } else if value == "gemini-2.5-flash-lite" {
            LLMModel::Gemini(GeminiModel::Gemini25FlashLite)
        } else if value.starts_with("gemini-2.5-flash") {
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        } else if value.starts_with("gemini-2.5-pro") {
            LLMModel::Gemini(GeminiModel::Gemini25Pro)
        } else if value.starts_with("gemini-3-pro-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Pro)
        } else if value.starts_with("gemini-3-flash-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Flash)
        } else if value.starts_with("gpt-5-mini") {
            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
        } else if value.starts_with("gpt-5") {
            LLMModel::OpenAI(OpenAIModel::GPT5)
        } else {
            LLMModel::Custom(value)
        }
    }
}

impl Display for LLMModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LLMModel::Anthropic(model) => write!(f, "{}", model),
            LLMModel::Gemini(model) => write!(f, "{}", model),
            LLMModel::OpenAI(model) => write!(f, "{}", model),
            LLMModel::Custom(model) => write!(f, "{}", model),
        }
    }
}

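/// A non-streaming chat completion request.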
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
}

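/// A streaming chat completion request; incremental deltas are sent through
/// `stream_channel_tx` as the provider produces them.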
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
        }
    }
}

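/// Dispatches a chat completion to the provider matching `input.model`.
/// `Custom` models are sent to the OpenAI-compatible client. Returns a
/// `BadRequest` error when the required provider config is missing.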
pub async fn chat(
    config: &LLMProviderConfig,
    input: LLMInput,
) -> Result<LLMCompletionResponse, AgentError> {
    match input.model {
        LLMModel::Anthropic(model) => {
            if let Some(anthropic_config) = &config.anthropic_config {
                let anthropic_input = AnthropicInput {
                    model,
                    messages: input.messages,
                    grammar: None,
                    max_tokens: input.max_tokens,
                    stop_sequences: None,
                    tools: input.tools,
                    thinking: Default::default(),
                };
                Anthropic::chat(anthropic_config, anthropic_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Anthropic config not found".to_string(),
                    ),
                ))
            }
        }

        LLMModel::Gemini(model) => {
            if let Some(gemini_config) = &config.gemini_config {
                let gemini_input = GeminiInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    tools: input.tools,
                };
                Gemini::chat(gemini_config, gemini_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Gemini config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::OpenAI(model) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat(openai_config, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::Custom(model_name) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model: OpenAIModel::Custom(model_name),
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat(openai_config, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
    }
}

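/// Streaming variant of [`chat`]: deltas are emitted on
/// `input.stream_channel_tx` while the final completion is returned.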
pub async fn chat_stream(
    config: &LLMProviderConfig,
    input: LLMStreamInput,
) -> Result<LLMCompletionResponse, AgentError> {
    match input.model {
        LLMModel::Anthropic(model) => {
            if let Some(anthropic_config) = &config.anthropic_config {
                let anthropic_input = AnthropicInput {
                    model,
                    messages: input.messages,
                    grammar: None,
                    max_tokens: input.max_tokens,
                    stop_sequences: None,
                    tools: input.tools,
                    thinking: Default::default(),
                };
                Anthropic::chat_stream(anthropic_config, input.stream_channel_tx, anthropic_input)
                    .await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Anthropic config not found".to_string(),
                    ),
                ))
            }
        }

        LLMModel::Gemini(model) => {
            if let Some(gemini_config) = &config.gemini_config {
                let gemini_input = GeminiInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    tools: input.tools,
                };
                Gemini::chat_stream(gemini_config, input.stream_channel_tx, gemini_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Gemini config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::OpenAI(model) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat_stream(openai_config, input.stream_channel_tx, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::Custom(model_name) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model: OpenAIModel::Custom(model_name),
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat_stream(openai_config, input.stream_channel_tx, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
    }
}

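/// A chat message consisting of a free-form role string (e.g. "user",
/// "assistant") and either plain-text or typed content.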
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

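/// A simplified message restricted to user/assistant roles with plain string content.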
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    #[serde(rename = "role")]
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

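/// Message content: either a plain string or a list of typed content blocks
/// (text, tool calls, tool results, images).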
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}

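/// A single typed content block, tagged by `type` during (de)serialization.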
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        #[serde(alias = "input")]
        args: serde_json::Value,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

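/// A complete (non-streaming) chat completion response.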
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

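/// A single streamed completion chunk with per-choice deltas and optional
/// usage and citations.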
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}

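/// A tool definition exposed to the model: a name, a description, and a JSON
/// schema describing its input.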
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

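/// Token usage reported for a completion, with an optional per-category
/// breakdown of prompt tokens.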
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

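/// Optional breakdown of prompt tokens by category, including prompt-cache
/// read and write counts.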
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

impl PromptTokensDetails {
    /// Returns an iterator over the token types and their values
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}

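/// An incremental streaming event: content text, thinking text, a tool-use
/// fragment, token usage, or provider metadata.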
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

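/// Partial tool-call data emitted during streaming; the fields are optional
/// because they may arrive across multiple deltas, with `index` identifying
/// the tool call a fragment belongs to.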
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
}