use crate::models::{
    error::{AgentError, BadRequestErrorMessage},
    integrations::{
        anthropic::{Anthropic, AnthropicConfig, AnthropicInput, AnthropicModel},
        gemini::{Gemini, GeminiConfig, GeminiInput, GeminiModel},
        openai::{OpenAI, OpenAIConfig, OpenAIInput, OpenAIModel},
    },
};
use serde::{Deserialize, Serialize};
use std::fmt::Display;

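/// Model selector across the supported providers. Identifiers that are not
/// recognized by `From<String>` fall back to `Custom`, which is routed through
/// the OpenAI-compatible client.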
#[derive(Clone, Debug, Serialize)]
pub enum LLMModel {
    Anthropic(AnthropicModel),
    Gemini(GeminiModel),
    OpenAI(OpenAIModel),
    Custom(String),
}

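/// Optional per-provider configuration. A request that targets a provider
/// whose config is `None` is rejected with a `BadRequest` error.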
#[derive(Debug)]
pub struct LLMProviderConfig {
    pub anthropic_config: Option<AnthropicConfig>,
    pub gemini_config: Option<GeminiConfig>,
    pub openai_config: Option<OpenAIConfig>,
}

impl From<String> for LLMModel {
    fn from(value: String) -> Self {
        match value.as_str() {
            "claude-haiku-4-5" => LLMModel::Anthropic(AnthropicModel::Claude45Haiku),
            "claude-sonnet-4-5" => LLMModel::Anthropic(AnthropicModel::Claude45Sonnet),
            "claude-opus-4-5" => LLMModel::Anthropic(AnthropicModel::Claude45Opus),
            "gemini-2.0-flash" => LLMModel::Gemini(GeminiModel::Gemini20Flash),
            "gemini-2.0-flash-lite" => LLMModel::Gemini(GeminiModel::Gemini20FlashLite),
            "gemini-2.5-flash" => LLMModel::Gemini(GeminiModel::Gemini25Flash),
            "gemini-2.5-flash-lite" => LLMModel::Gemini(GeminiModel::Gemini25FlashLite),
            "gemini-2.5-pro" => LLMModel::Gemini(GeminiModel::Gemini25Pro),
            "gemini-3-pro-preview" => LLMModel::Gemini(GeminiModel::Gemini3Pro),
            "gpt-5" => LLMModel::OpenAI(OpenAIModel::GPT5),
            "gpt-5-mini" => LLMModel::OpenAI(OpenAIModel::GPT5Mini),
            _ => LLMModel::Custom(value),
        }
    }
}

impl Display for LLMModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LLMModel::Anthropic(model) => write!(f, "{}", model),
            LLMModel::Gemini(model) => write!(f, "{}", model),
            LLMModel::OpenAI(model) => write!(f, "{}", model),
            LLMModel::Custom(model) => write!(f, "{}", model),
        }
    }
}

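/// Provider-agnostic input for a non-streaming chat completion.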
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
}

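/// Provider-agnostic input for a streaming chat completion. Incremental
/// `GenerationDelta`s are sent over `stream_channel_tx` while the request is
/// in flight.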
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
        }
    }
}

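/// Dispatches a chat completion request to the provider selected by
/// `input.model`; `Custom` models are forwarded to the OpenAI-compatible
/// client. Returns a `BadRequest` error when the matching provider config is
/// missing.
///
/// A minimal usage sketch (how the provider config is actually constructed is
/// an assumption here; build it the way your deployment does):
///
/// ```ignore
/// let config = LLMProviderConfig {
///     anthropic_config: Some(anthropic_config), // assumed to be built elsewhere
///     gemini_config: None,
///     openai_config: None,
/// };
/// let input = LLMInput {
///     model: "claude-sonnet-4-5".to_string().into(),
///     messages: vec![LLMMessage {
///         role: "user".to_string(),
///         content: "Hello!".to_string().into(),
///     }],
///     max_tokens: 1024,
///     tools: None,
/// };
/// let response = chat(&config, input).await?;
/// ```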
pub async fn chat(
    config: &LLMProviderConfig,
    input: LLMInput,
) -> Result<LLMCompletionResponse, AgentError> {
    match input.model {
        LLMModel::Anthropic(model) => {
            if let Some(anthropic_config) = &config.anthropic_config {
                let anthropic_input = AnthropicInput {
                    model,
                    messages: input.messages,
                    grammar: None,
                    max_tokens: input.max_tokens,
                    stop_sequences: None,
                    tools: input.tools,
                    thinking: Default::default(),
                };
                Anthropic::chat(anthropic_config, anthropic_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Anthropic config not found".to_string(),
                    ),
                ))
            }
        }

        LLMModel::Gemini(model) => {
            if let Some(gemini_config) = &config.gemini_config {
                let gemini_input = GeminiInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    tools: input.tools,
                };
                Gemini::chat(gemini_config, gemini_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Gemini config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::OpenAI(model) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat(openai_config, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::Custom(model_name) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model: OpenAIModel::Custom(model_name),
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat(openai_config, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
    }
}

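/// Streaming variant of [`chat`]: deltas are emitted on `stream_channel_tx`
/// as generation progresses and the aggregated completion is returned once
/// the stream ends.
///
/// A minimal sketch of wiring up the delta channel (config and message
/// construction are assumed, as in [`chat`]):
///
/// ```ignore
/// let (tx, mut rx) = tokio::sync::mpsc::channel::<GenerationDelta>(64);
/// tokio::spawn(async move {
///     while let Some(delta) = rx.recv().await {
///         if let GenerationDelta::Content { content } = delta {
///             print!("{content}");
///         }
///     }
/// });
/// let input = LLMStreamInput {
///     model: "gpt-5-mini".to_string().into(),
///     messages,
///     max_tokens: 1024,
///     stream_channel_tx: tx,
///     tools: None,
/// };
/// let response = chat_stream(&config, input).await?;
/// ```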
pub async fn chat_stream(
    config: &LLMProviderConfig,
    input: LLMStreamInput,
) -> Result<LLMCompletionResponse, AgentError> {
    match input.model {
        LLMModel::Anthropic(model) => {
            if let Some(anthropic_config) = &config.anthropic_config {
                let anthropic_input = AnthropicInput {
                    model,
                    messages: input.messages,
                    grammar: None,
                    max_tokens: input.max_tokens,
                    stop_sequences: None,
                    tools: input.tools,
                    thinking: Default::default(),
                };
                Anthropic::chat_stream(anthropic_config, input.stream_channel_tx, anthropic_input)
                    .await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Anthropic config not found".to_string(),
                    ),
                ))
            }
        }

        LLMModel::Gemini(model) => {
            if let Some(gemini_config) = &config.gemini_config {
                let gemini_input = GeminiInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    tools: input.tools,
                };
                Gemini::chat_stream(gemini_config, input.stream_channel_tx, gemini_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "Gemini config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::OpenAI(model) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model,
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat_stream(openai_config, input.stream_channel_tx, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
        LLMModel::Custom(model_name) => {
            if let Some(openai_config) = &config.openai_config {
                let openai_input = OpenAIInput {
                    model: OpenAIModel::Custom(model_name),
                    messages: input.messages,
                    max_tokens: input.max_tokens,
                    json: None,
                    tools: input.tools,
                    reasoning_effort: None,
                };
                OpenAI::chat_stream(openai_config, input.stream_channel_tx, openai_input).await
            } else {
                Err(AgentError::BadRequest(
                    BadRequestErrorMessage::InvalidAgentInput(
                        "OpenAI config not found".to_string(),
                    ),
                ))
            }
        }
    }
}

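/// A chat message in provider-agnostic form: a free-form `role` string plus
/// either plain text or a list of typed content blocks.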
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    #[serde(rename = "role")]
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

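/// Message content, serialized untagged: either a bare string or a list of
/// typed content blocks.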
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}

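/// A typed content block, tagged by `type` on the wire: text, tool use, tool
/// result, or image.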
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        #[serde(alias = "input")]
        args: serde_json::Value,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

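/// OpenAI-style completion response shape returned by every provider adapter.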
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

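/// Token accounting for a completion, with an optional per-category breakdown
/// in `prompt_tokens_details`.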
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

impl PromptTokensDetails {
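    /// Yields each token category together with its count, treating missing
    /// values as zero.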
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}

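/// Incremental event emitted over the streaming channel, tagged by `type`:
/// content text, thinking text, tool-use fragments, usage, or provider
/// metadata.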
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
}
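
#[cfg(test)]
mod tests {
    // Small sanity checks over the pure, in-file conversions; the model
    // identifiers used here mirror the `From<String>` table above.
    use super::*;

    #[test]
    fn unknown_model_names_fall_back_to_custom() {
        assert!(matches!(
            LLMModel::from("claude-sonnet-4-5".to_string()),
            LLMModel::Anthropic(_)
        ));
        assert!(matches!(
            LLMModel::from("some-local-model".to_string()),
            LLMModel::Custom(_)
        ));
    }

    #[test]
    fn message_content_flattens_to_text() {
        let content = LLMMessageContent::List(vec![
            LLMMessageTypedContent::Text {
                text: "hello".to_string(),
            },
            LLMMessageTypedContent::ToolResult {
                tool_use_id: "call_1".to_string(),
                content: "42".to_string(),
            },
        ]);
        assert_eq!(content.to_string(), "hello\n42");
    }

    #[test]
    fn prompt_token_details_add_treats_missing_as_zero() {
        let mut total = PromptTokensDetails {
            input_tokens: Some(10),
            output_tokens: Some(2),
            cache_read_input_tokens: None,
            cache_write_input_tokens: None,
        };
        total += PromptTokensDetails {
            input_tokens: Some(5),
            output_tokens: None,
            cache_read_input_tokens: Some(100),
            cache_write_input_tokens: None,
        };
        assert_eq!(total.input_tokens, Some(15));
        assert_eq!(total.output_tokens, Some(2));
        assert_eq!(total.cache_read_input_tokens, Some(100));
        assert_eq!(total.cache_write_input_tokens, Some(0));
    }
}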