use crate::models::{
    error::{AgentError, BadRequestErrorMessage},
    integrations::{
        anthropic::{Anthropic, AnthropicConfig, AnthropicInput, AnthropicModel},
        gemini::{Gemini, GeminiConfig, GeminiInput, GeminiModel},
        openai::{OpenAI, OpenAIConfig, OpenAIInput, OpenAIModel},
    },
    model_pricing::{ContextAware, ModelContextInfo},
};
use serde::{Deserialize, Serialize};
use std::fmt::Display;

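/// A model selector spanning the supported providers. Unknown model strings are
/// preserved as [`LLMModel::Custom`] and routed through the OpenAI-compatible client.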
#[derive(Clone, Debug, PartialEq, Serialize)]
pub enum LLMModel {
    Anthropic(AnthropicModel),
    Gemini(GeminiModel),
    OpenAI(OpenAIModel),
    Custom(String),
}

impl ContextAware for LLMModel {
    fn context_info(&self) -> ModelContextInfo {
        match self {
            LLMModel::Anthropic(model) => model.context_info(),
            LLMModel::Gemini(model) => model.context_info(),
            LLMModel::OpenAI(model) => model.context_info(),
            LLMModel::Custom(_) => ModelContextInfo::default(),
        }
    }

    fn model_name(&self) -> String {
        match self {
            LLMModel::Anthropic(model) => model.model_name(),
            LLMModel::Gemini(model) => model.model_name(),
            LLMModel::OpenAI(model) => model.model_name(),
            LLMModel::Custom(model_name) => model_name.clone(),
        }
    }
}

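/// Provider credentials/configuration. Each field is optional; requests that
/// target a provider whose config is missing fail with a `BadRequest` error.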
#[derive(Debug)]
pub struct LLMProviderConfig {
    pub anthropic_config: Option<AnthropicConfig>,
    pub gemini_config: Option<GeminiConfig>,
    pub openai_config: Option<OpenAIConfig>,
}

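/// Maps a raw model-name string onto a known provider model, checking the more
/// specific prefixes first (e.g. `gpt-5-mini` before `gpt-5`). Anything
/// unrecognized becomes [`LLMModel::Custom`].
///
/// A minimal sketch (marked `ignore`; assumes the types are in scope):
///
/// ```ignore
/// assert_eq!(
///     LLMModel::from("claude-sonnet-4-5".to_string()),
///     LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
/// );
/// assert_eq!(
///     LLMModel::from("my-private-model".to_string()),
///     LLMModel::Custom("my-private-model".to_string())
/// );
/// ```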
impl From<String> for LLMModel {
    fn from(value: String) -> Self {
        if value.starts_with("claude-haiku-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        } else if value.starts_with("claude-sonnet-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
        } else if value.starts_with("claude-opus-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
        } else if value == "gemini-2.5-flash-lite" {
            LLMModel::Gemini(GeminiModel::Gemini25FlashLite)
        } else if value.starts_with("gemini-2.5-flash") {
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        } else if value.starts_with("gemini-2.5-pro") {
            LLMModel::Gemini(GeminiModel::Gemini25Pro)
        } else if value.starts_with("gemini-3-pro-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Pro)
        } else if value.starts_with("gemini-3-flash-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Flash)
        } else if value.starts_with("gpt-5-mini") {
            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
        } else if value.starts_with("gpt-5") {
            LLMModel::OpenAI(OpenAIModel::GPT5)
        } else {
            LLMModel::Custom(value)
        }
    }
}

impl Display for LLMModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LLMModel::Anthropic(model) => write!(f, "{}", model),
            LLMModel::Gemini(model) => write!(f, "{}", model),
            LLMModel::OpenAI(model) => write!(f, "{}", model),
            LLMModel::Custom(model) => write!(f, "{}", model),
        }
    }
}

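/// Input for a non-streaming [`chat`] call.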
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
}

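/// Input for a streaming [`chat_stream`] call. Incremental [`GenerationDelta`]s
/// are sent over `stream_channel_tx` while the call is in flight.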
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
        }
    }
}

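/// Dispatches a single (non-streaming) chat completion to the provider that
/// matches `input.model`. `Custom` models are sent through the OpenAI client.
/// Returns a `BadRequest` error if the matching provider config is missing.
///
/// A minimal sketch (marked `ignore`; `provider_config` is assumed to be an
/// [`LLMProviderConfig`] built elsewhere with the relevant provider populated):
///
/// ```ignore
/// let input = LLMInput {
///     model: LLMModel::from("gpt-5-mini".to_string()),
///     messages: vec![LLMMessage {
///         role: "user".to_string(),
///         content: LLMMessageContent::String("Hello!".to_string()),
///     }],
///     max_tokens: 1024,
///     tools: None,
/// };
/// let response = chat(&provider_config, input).await?;
/// ```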
pub async fn chat(
    config: &LLMProviderConfig,
    input: LLMInput,
) -> Result<LLMCompletionResponse, AgentError> {
    match input.model {
        LLMModel::Anthropic(model) => {
            if let Some(anthropic_config) = &config.anthropic_config {
                let anthropic_input = AnthropicInput {
                    model,
                    messages: input.messages,
                    grammar: None,
                    max_tokens: input.max_tokens,
                    stop_sequences: None,
                    tools: input.tools,
                    thinking: Default::default(),
                };
                Anthropic::chat(anthropic_config, anthropic_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Anthropic config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::Gemini(model) => {
            if let Some(gemini_config) = &config.gemini_config {
                let gemini_input = GeminiInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    tools: input.tools,
                };
                Gemini::chat(gemini_config, gemini_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Gemini config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::OpenAI(model) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat(openai_config, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::Custom(model_name) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model: OpenAIModel::Custom(model_name),
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat(openai_config, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
    }
}

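/// Streaming counterpart of [`chat`]: provider deltas are forwarded to
/// `input.stream_channel_tx` as [`GenerationDelta`]s and the final completion
/// is returned once the stream ends.
///
/// A minimal sketch (marked `ignore`; `provider_config` is assumed to be built
/// elsewhere, and the receiver is drained in a spawned task):
///
/// ```ignore
/// let (tx, mut rx) = tokio::sync::mpsc::channel::<GenerationDelta>(32);
/// // Drain deltas concurrently so the bounded channel never backs up.
/// tokio::spawn(async move {
///     while let Some(delta) = rx.recv().await {
///         // handle GenerationDelta::Content, ::Thinking, ::ToolUse, ...
///     }
/// });
/// let input = LLMStreamInput {
///     model: LLMModel::from("claude-sonnet-4-5".to_string()),
///     messages: vec![LLMMessage {
///         role: "user".to_string(),
///         content: "Hello!".to_string().into(),
///     }],
///     max_tokens: 1024,
///     stream_channel_tx: tx,
///     tools: None,
/// };
/// let response = chat_stream(&provider_config, input).await?;
/// ```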
pub async fn chat_stream(
    config: &LLMProviderConfig,
    input: LLMStreamInput,
) -> Result<LLMCompletionResponse, AgentError> {
    match input.model {
        LLMModel::Anthropic(model) => {
            if let Some(anthropic_config) = &config.anthropic_config {
                let anthropic_input = AnthropicInput {
                    model,
                    messages: input.messages,
                    grammar: None,
                    max_tokens: input.max_tokens,
                    stop_sequences: None,
                    tools: input.tools,
                    thinking: Default::default(),
                };
                Anthropic::chat_stream(anthropic_config, input.stream_channel_tx, anthropic_input)
                    .await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Anthropic config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::Gemini(model) => {
            if let Some(gemini_config) = &config.gemini_config {
                let gemini_input = GeminiInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    tools: input.tools,
                };
                Gemini::chat_stream(gemini_config, input.stream_channel_tx, gemini_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Gemini config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::OpenAI(model) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat_stream(openai_config, input.stream_channel_tx, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::Custom(model_name) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model: OpenAIModel::Custom(model_name),
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat_stream(openai_config, input.stream_channel_tx, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
    }
}

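/// A single chat message: a free-form `role` string plus either plain-text or
/// typed content blocks (text, tool calls, tool results, images).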
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    #[serde(rename = "role")]
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

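/// Message content, serialized untagged: either a bare string or a list of
/// typed content blocks.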
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}

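/// A typed content block, tagged by `type`: plain text, a tool call, a tool
/// result, or an inline image.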
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        #[serde(alias = "input")]
        args: serde_json::Value,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

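/// A completed (non-streaming) chat completion, normalized across providers.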
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}

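/// A tool definition exposed to the model; `input_schema` is a JSON value
/// describing the tool's argument schema.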
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

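/// Aggregate token usage for a completion, with an optional per-category
/// breakdown of the prompt tokens.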
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

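/// Token categories used when itemizing usage: regular input/output plus
/// prompt-cache reads and writes.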
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

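/// Per-category token counts. Missing categories are treated as zero when
/// iterating or adding two breakdowns together.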
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

impl PromptTokensDetails {
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}

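/// An incremental event emitted while streaming: content or thinking text,
/// partial tool-use arguments, usage updates, or provider metadata.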
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

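/// A partial tool-call delta: optional id/name fragments, incremental `input`
/// argument text, and the tool call's `index` within the stream.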
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
}