use crate::models::{
    integrations::{
        anthropic::{AnthropicConfig, AnthropicModel},
        gemini::{GeminiConfig, GeminiModel},
        openai::{OpenAIConfig, OpenAIModel},
    },
    model_pricing::{ContextAware, ModelContextInfo},
};
use serde::{Deserialize, Serialize};
use std::fmt::Display;

/// A model identifier: one of the known provider models, or a free-form
/// custom name.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub enum LLMModel {
    Anthropic(AnthropicModel),
    Gemini(GeminiModel),
    OpenAI(OpenAIModel),
    Custom(String),
}

impl ContextAware for LLMModel {
    fn context_info(&self) -> ModelContextInfo {
        match self {
            LLMModel::Anthropic(model) => model.context_info(),
            LLMModel::Gemini(model) => model.context_info(),
            LLMModel::OpenAI(model) => model.context_info(),
            LLMModel::Custom(_) => ModelContextInfo::default(),
        }
    }

    fn model_name(&self) -> String {
        match self {
            LLMModel::Anthropic(model) => model.model_name(),
            LLMModel::Gemini(model) => model.model_name(),
            LLMModel::OpenAI(model) => model.model_name(),
            LLMModel::Custom(model_name) => model_name.clone(),
        }
    }
}

/// API configuration for each supported provider; a `None` entry means that
/// provider is not configured.
#[derive(Debug)]
pub struct LLMProviderConfig {
    pub anthropic_config: Option<AnthropicConfig>,
    pub gemini_config: Option<GeminiConfig>,
    pub openai_config: Option<OpenAIConfig>,
}

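/// Maps a raw model-identifier string onto a known provider variant by
/// prefix matching, falling back to [`LLMModel::Custom`] for anything
/// unrecognized. Ordering matters: more specific prefixes are checked
/// first (e.g. `gpt-5-mini` before `gpt-5`).
///
/// A sketch of the expected behavior (model strings are illustrative):
///
/// ```rust,ignore
/// let model = LLMModel::from("gpt-5-mini-2025-01".to_string());
/// assert!(matches!(model, LLMModel::OpenAI(OpenAIModel::GPT5Mini)));
///
/// let unknown = LLMModel::from("my-local-model".to_string());
/// assert!(matches!(unknown, LLMModel::Custom(_)));
/// ```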
impl From<String> for LLMModel {
    fn from(value: String) -> Self {
        if value.starts_with("claude-haiku-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        } else if value.starts_with("claude-sonnet-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
        } else if value.starts_with("claude-opus-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
        } else if value == "gemini-2.5-flash-lite" {
            LLMModel::Gemini(GeminiModel::Gemini25FlashLite)
        } else if value.starts_with("gemini-2.5-flash") {
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        } else if value.starts_with("gemini-2.5-pro") {
            LLMModel::Gemini(GeminiModel::Gemini25Pro)
        } else if value.starts_with("gemini-3-pro-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Pro)
        } else if value.starts_with("gemini-3-flash-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Flash)
        } else if value.starts_with("gpt-5-mini") {
            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
        } else if value.starts_with("gpt-5") {
            LLMModel::OpenAI(OpenAIModel::GPT5)
        } else {
            LLMModel::Custom(value)
        }
    }
}

impl Display for LLMModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LLMModel::Anthropic(model) => write!(f, "{}", model),
            LLMModel::Gemini(model) => write!(f, "{}", model),
            LLMModel::OpenAI(model) => write!(f, "{}", model),
            LLMModel::Custom(model) => write!(f, "{}", model),
        }
    }
}

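/// Per-provider request options. Every field is skipped when `None`, so a
/// `Default` value serializes to `{}` and only configured providers appear
/// on the wire. A sketch of the serialized shape:
///
/// ```rust,ignore
/// let opts = LLMProviderOptions {
///     anthropic: Some(LLMAnthropicOptions {
///         thinking: Some(LLMThinkingOptions::new(2048)),
///     }),
///     ..Default::default()
/// };
/// assert_eq!(
///     serde_json::to_string(&opts).unwrap(),
///     r#"{"anthropic":{"thinking":{"budget_tokens":2048}}}"#
/// );
/// ```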
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}

/// Extended-thinking settings; `budget_tokens` caps the tokens spent on
/// thinking.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    pub budget_tokens: u32,
}

impl LLMThinkingOptions {
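    /// Builds thinking options, clamping the budget to a floor of 1024
    /// tokens. (1024 appears to be the provider-side minimum for extended
    /// thinking budgets; the clamp enforces it regardless.)
    ///
    /// A sketch of the clamping behavior:
    ///
    /// ```rust,ignore
    /// assert_eq!(LLMThinkingOptions::new(16).budget_tokens, 1024);
    /// assert_eq!(LLMThinkingOptions::new(4096).budget_tokens, 4096);
    /// ```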
    pub fn new(budget_tokens: u32) -> Self {
        Self {
            budget_tokens: budget_tokens.max(1024),
        }
    }
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning effort passed through to the API (e.g. "low", "medium",
    /// "high").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    /// Gemini thinking budget, in tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}

/// A non-streaming completion request.
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
}

/// A streaming completion request; deltas are pushed through
/// `stream_channel_tx` as they arrive.
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
            provider_options: value.provider_options.clone(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    #[serde(rename = "role")]
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

/// Message content: either a bare string or a list of typed content blocks;
/// `untagged` lets serde accept both shapes.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

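/// Flattens content to plain text: string content is returned as-is; typed
/// lists keep text and tool-result bodies (joined with newlines), while
/// tool calls and images contribute empty segments. A sketch:
///
/// ```rust,ignore
/// let content = LLMMessageContent::List(vec![
///     LLMMessageTypedContent::Text { text: "a".into() },
///     LLMMessageTypedContent::Text { text: "b".into() },
/// ]);
/// assert_eq!(content.to_string(), "a\nb");
/// ```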
#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}

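/// A typed content block, tagged by `"type"` on the wire (mirroring the
/// Anthropic content-block shape). `ToolCall.args` also accepts the key
/// `input` on deserialization via the serde alias. A wire-format sketch:
///
/// ```rust,ignore
/// let text: LLMMessageTypedContent =
///     serde_json::from_str(r#"{"type":"text","text":"hi"}"#).unwrap();
/// let call: LLMMessageTypedContent = serde_json::from_str(
///     r#"{"type":"tool_use","id":"t1","name":"search","input":{"q":"rust"}}"#,
/// ).unwrap();
/// ```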
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        #[serde(alias = "input")]
        args: serde_json::Value,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

/// An image source; `type` is typically `"base64"`, with a matching
/// `media_type` and the encoded payload in `data`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

/// An OpenAI-compatible chat completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

/// A single chunk of an OpenAI-compatible streaming completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}

/// A tool definition exposed to the model; `input_schema` is a JSON Schema
/// object describing the expected arguments.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

/// Per-category token counts; `None` means the provider did not report that
/// category.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

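/// A sketch of how the counters combine; `iter` yields `(TokenType, count)`
/// pairs with missing counters treated as zero:
///
/// ```rust,ignore
/// let d = PromptTokensDetails {
///     input_tokens: Some(10),
///     output_tokens: None,
///     cache_read_input_tokens: Some(5),
///     cache_write_input_tokens: None,
/// };
/// assert_eq!(d.iter().map(|(_, n)| n).sum::<u32>(), 15);
/// ```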
impl PromptTokensDetails {
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

/// Field-wise sum; `None` counts as zero, so every field of the result is
/// `Some`.
impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

/// In-place counterpart of `Add`, with the same `None`-as-zero semantics.
impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}

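/// An incremental generation event pushed over the stream channel. With
/// `#[serde(tag = "type")]` and no rename, the tag is the variant name
/// verbatim; a sketch of the serialized shape:
///
/// ```rust,ignore
/// let delta = GenerationDelta::Content { content: "Hello".into() };
/// assert_eq!(
///     serde_json::to_string(&delta).unwrap(),
///     r#"{"type":"Content","content":"Hello"}"#
/// );
/// ```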
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

/// A (possibly partial) tool-use payload within a streaming delta; fields
/// arrive incrementally, so each is optional.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
}