1use async_stream::try_stream;
50use async_trait::async_trait;
51use serde::{Deserialize, Serialize};
52use serde_json::{Value, json};
53use std::pin::Pin;
54
55use crate::config::types::ReasoningEffortLevel;
56
/// A provider-agnostic chat-completion request, translated into each
/// backend's wire format by the individual [`LLMProvider`] implementations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    /// Conversation messages, in order.
    pub messages: Vec<Message>,
    /// System prompt carried separately from `messages`, if any.
    pub system_prompt: Option<String>,
    /// Tool definitions the model is allowed to call, if any.
    pub tools: Option<Vec<ToolDefinition>>,
    /// Provider-specific model identifier.
    pub model: String,
    /// Cap on generated tokens; `None` leaves the provider default.
    pub max_tokens: Option<u32>,
    /// Sampling temperature; `None` leaves the provider default.
    pub temperature: Option<f32>,
    /// Whether a streaming response is requested.
    pub stream: bool,

    /// How the model may select tools (see [`ToolChoice`]).
    pub tool_choice: Option<ToolChoice>,

    /// Whether multiple tool calls may be emitted in a single turn.
    pub parallel_tool_calls: Option<bool>,

    /// Fine-grained parallel tool-use settings (see [`ParallelToolConfig`]).
    pub parallel_tool_config: Option<ParallelToolConfig>,

    /// Requested reasoning effort, for models that support it.
    pub reasoning_effort: Option<ReasoningEffortLevel>,
}
82
/// Policy controlling whether and how the model may invoke tools.
///
/// Serialized `#[serde(untagged)]`: unit variants serialize as their payload-free
/// forms and `Specific` as the inner struct, so the JSON shape matches what
/// providers expect (see `to_provider_format` for per-provider translation).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Model decides on its own whether to call tools.
    Auto,

    /// Tools must not be used.
    None,

    /// The model must call at least one tool.
    Any,

    /// The model must call one specific, named tool.
    Specific(SpecificToolChoice),
}
105
/// Payload for [`ToolChoice::Specific`]: pins the model to one named function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificToolChoice {
    /// Tool kind; in practice always `"function"` (see `ToolChoice::function`).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function the model is required to call.
    pub function: SpecificFunctionChoice,
}
114
/// Name of the single function a [`SpecificToolChoice`] forces the model to call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificFunctionChoice {
    /// Function name, matching a [`ToolDefinition`]'s function name.
    pub name: String,
}
120
121impl ToolChoice {
122 pub fn auto() -> Self {
124 Self::Auto
125 }
126
127 pub fn none() -> Self {
129 Self::None
130 }
131
132 pub fn any() -> Self {
134 Self::Any
135 }
136
137 pub fn function(name: String) -> Self {
139 Self::Specific(SpecificToolChoice {
140 tool_type: "function".to_string(),
141 function: SpecificFunctionChoice { name },
142 })
143 }
144
145 pub fn allows_parallel_tools(&self) -> bool {
148 match self {
149 Self::Auto => true,
151 Self::Any => true,
153 Self::Specific(_) => false,
155 Self::None => false,
157 }
158 }
159
160 pub fn description(&self) -> &'static str {
162 match self {
163 Self::Auto => "Model decides when to use tools (allows parallel)",
164 Self::None => "No tools will be used",
165 Self::Any => "At least one tool must be used (allows parallel)",
166 Self::Specific(_) => "Specific tool must be used (no parallel)",
167 }
168 }
169
170 pub fn to_provider_format(&self, provider: &str) -> Value {
172 match (self, provider) {
173 (Self::Auto, "openai") | (Self::Auto, "deepseek") => json!("auto"),
174 (Self::None, "openai") | (Self::None, "deepseek") => json!("none"),
175 (Self::Any, "openai") | (Self::Any, "deepseek") => json!("required"),
176 (Self::Specific(choice), "openai") | (Self::Specific(choice), "deepseek") => {
177 json!(choice)
178 }
179
180 (Self::Auto, "anthropic") => json!({"type": "auto"}),
181 (Self::None, "anthropic") => json!({"type": "none"}),
182 (Self::Any, "anthropic") => json!({"type": "any"}),
183 (Self::Specific(choice), "anthropic") => {
184 json!({"type": "tool", "name": choice.function.name})
185 }
186
187 (Self::Auto, "gemini") => json!({"mode": "auto"}),
188 (Self::None, "gemini") => json!({"mode": "none"}),
189 (Self::Any, "gemini") => json!({"mode": "any"}),
190 (Self::Specific(choice), "gemini") => {
191 json!({"mode": "any", "allowed_function_names": [choice.function.name]})
192 }
193
194 _ => match self {
196 Self::Auto => json!("auto"),
197 Self::None => json!("none"),
198 Self::Any => json!("required"),
199 Self::Specific(choice) => json!(choice),
200 },
201 }
202 }
203}
204
/// Settings governing parallel tool execution within one model turn.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelToolConfig {
    /// Hard switch: when `true`, tools run strictly one at a time.
    pub disable_parallel_tool_use: bool,

    /// Upper bound on concurrently running tools; `None` means unbounded.
    pub max_parallel_tools: Option<usize>,

    /// Whether prompting should nudge the model toward parallel tool calls.
    pub encourage_parallel: bool,
}
220
221impl Default for ParallelToolConfig {
222 fn default() -> Self {
223 Self {
224 disable_parallel_tool_use: false,
225 max_parallel_tools: Some(5), encourage_parallel: true,
227 }
228 }
229}
230
231impl ParallelToolConfig {
232 pub fn anthropic_optimized() -> Self {
234 Self {
235 disable_parallel_tool_use: false,
236 max_parallel_tools: None, encourage_parallel: true,
238 }
239 }
240
241 pub fn sequential_only() -> Self {
243 Self {
244 disable_parallel_tool_use: true,
245 max_parallel_tools: Some(1),
246 encourage_parallel: false,
247 }
248 }
249}
250
/// A single conversation message, covering user/assistant/system turns as
/// well as tool invocations and their responses.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message.
    pub role: MessageRole,
    /// Message text (may be empty for pure tool-call messages).
    pub content: String,
    /// Tool calls made by the assistant; only valid on `Assistant` messages.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `Tool` messages: the id of the call this message responds to.
    pub tool_call_id: Option<String>,
}
259
260impl Message {
261 pub fn user(content: String) -> Self {
263 Self {
264 role: MessageRole::User,
265 content,
266 tool_calls: None,
267 tool_call_id: None,
268 }
269 }
270
271 pub fn assistant(content: String) -> Self {
273 Self {
274 role: MessageRole::Assistant,
275 content,
276 tool_calls: None,
277 tool_call_id: None,
278 }
279 }
280
281 pub fn assistant_with_tools(content: String, tool_calls: Vec<ToolCall>) -> Self {
284 Self {
285 role: MessageRole::Assistant,
286 content,
287 tool_calls: Some(tool_calls),
288 tool_call_id: None,
289 }
290 }
291
292 pub fn system(content: String) -> Self {
294 Self {
295 role: MessageRole::System,
296 content,
297 tool_calls: None,
298 tool_call_id: None,
299 }
300 }
301
302 pub fn tool_response(tool_call_id: String, content: String) -> Self {
312 Self {
313 role: MessageRole::Tool,
314 content,
315 tool_calls: None,
316 tool_call_id: Some(tool_call_id),
317 }
318 }
319
320 pub fn tool_response_with_name(
323 tool_call_id: String,
324 _function_name: String,
325 content: String,
326 ) -> Self {
327 Self::tool_response(tool_call_id, content)
329 }
330
331 pub fn validate_for_provider(&self, provider: &str) -> Result<(), String> {
334 self.role
336 .validate_for_provider(provider, self.tool_call_id.is_some())?;
337
338 if let Some(tool_calls) = &self.tool_calls {
340 if !self.role.can_make_tool_calls() {
341 return Err(format!("Role {:?} cannot make tool calls", self.role));
342 }
343
344 if tool_calls.is_empty() {
345 return Err("Tool calls array should not be empty".to_string());
346 }
347
348 for tool_call in tool_calls {
350 tool_call.validate()?;
351 }
352 }
353
354 match provider {
356 "openai" | "openrouter" => {
357 if self.role == MessageRole::Tool && self.tool_call_id.is_none() {
358 return Err(format!(
359 "{} requires tool_call_id for tool messages",
360 provider
361 ));
362 }
363 }
364 "gemini" => {
365 if self.role == MessageRole::Tool && self.tool_call_id.is_none() {
366 return Err(
367 "Gemini tool responses need tool_call_id for function name mapping"
368 .to_string(),
369 );
370 }
371 if self.role == MessageRole::System && !self.content.is_empty() {
373 }
375 }
376 "anthropic" => {
377 }
380 _ => {} }
382
383 Ok(())
384 }
385
386 pub fn has_tool_calls(&self) -> bool {
388 self.tool_calls
389 .as_ref()
390 .map_or(false, |calls| !calls.is_empty())
391 }
392
393 pub fn get_tool_calls(&self) -> Option<&[ToolCall]> {
395 self.tool_calls.as_deref()
396 }
397
398 pub fn is_tool_response(&self) -> bool {
400 self.role == MessageRole::Tool
401 }
402}
403
/// The author of a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
    /// Instructions to the model (system prompt).
    System,
    /// End-user input.
    User,
    /// Model output, possibly including tool calls.
    Assistant,
    /// Result of a tool invocation, fed back to the model.
    Tool,
}
411
412impl MessageRole {
413 pub fn as_gemini_str(&self) -> &'static str {
419 match self {
420 MessageRole::System => "system", MessageRole::User => "user",
422 MessageRole::Assistant => "model", MessageRole::Tool => "user", }
425 }
426
427 pub fn as_openai_str(&self) -> &'static str {
432 match self {
433 MessageRole::System => "system",
434 MessageRole::User => "user",
435 MessageRole::Assistant => "assistant",
436 MessageRole::Tool => "tool", }
438 }
439
440 pub fn as_anthropic_str(&self) -> &'static str {
446 match self {
447 MessageRole::System => "system", MessageRole::User => "user",
449 MessageRole::Assistant => "assistant",
450 MessageRole::Tool => "user", }
452 }
453
454 pub fn as_generic_str(&self) -> &'static str {
457 match self {
458 MessageRole::System => "system",
459 MessageRole::User => "user",
460 MessageRole::Assistant => "assistant",
461 MessageRole::Tool => "tool",
462 }
463 }
464
465 pub fn can_make_tool_calls(&self) -> bool {
468 matches!(self, MessageRole::Assistant)
469 }
470
471 pub fn is_tool_response(&self) -> bool {
473 matches!(self, MessageRole::Tool)
474 }
475
476 pub fn validate_for_provider(
479 &self,
480 provider: &str,
481 has_tool_call_id: bool,
482 ) -> Result<(), String> {
483 match (self, provider) {
484 (MessageRole::Tool, provider)
485 if matches!(provider, "openai" | "openrouter" | "xai" | "deepseek")
486 && !has_tool_call_id =>
487 {
488 Err(format!("{} tool messages must have tool_call_id", provider))
489 }
490 (MessageRole::Tool, "gemini") if !has_tool_call_id => {
491 Err("Gemini tool messages need tool_call_id for function mapping".to_string())
492 }
493 _ => Ok(()),
494 }
495 }
496}
497
/// A tool the model may call, in the OpenAI-style `{"type": "function", …}` shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool kind; only `"function"` is supported (enforced by `validate`).
    #[serde(rename = "type")]
    pub tool_type: String,

    /// The callable function's name, description, and parameter schema.
    pub function: FunctionDefinition,
}
509
/// Declaration of a callable function exposed to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name the model uses to invoke it.
    pub name: String,

    /// Natural-language description shown to the model.
    pub description: String,

    /// Parameter schema; must be a JSON object (enforced by `ToolDefinition::validate`).
    pub parameters: Value,
}
522
523impl ToolDefinition {
524 pub fn function(name: String, description: String, parameters: Value) -> Self {
526 Self {
527 tool_type: "function".to_string(),
528 function: FunctionDefinition {
529 name,
530 description,
531 parameters,
532 },
533 }
534 }
535
536 pub fn function_name(&self) -> &str {
538 &self.function.name
539 }
540
541 pub fn validate(&self) -> Result<(), String> {
543 if self.tool_type != "function" {
544 return Err(format!(
545 "Only 'function' type is supported, got: {}",
546 self.tool_type
547 ));
548 }
549
550 if self.function.name.is_empty() {
551 return Err("Function name cannot be empty".to_string());
552 }
553
554 if self.function.description.is_empty() {
555 return Err("Function description cannot be empty".to_string());
556 }
557
558 if !self.function.parameters.is_object() {
560 return Err("Function parameters must be a JSON object".to_string());
561 }
562
563 Ok(())
564 }
565}
566
/// A concrete tool invocation emitted by the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id; tool responses reference it via `tool_call_id`.
    pub id: String,

    /// Call kind; only `"function"` is supported (enforced by `validate`).
    #[serde(rename = "type")]
    pub call_type: String,

    /// The invoked function and its raw argument payload.
    pub function: FunctionCall,
}
581
/// The function half of a [`ToolCall`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function being called.
    pub name: String,

    /// Arguments as a raw JSON string (parse via `ToolCall::parsed_arguments`).
    pub arguments: String,
}
591
592impl ToolCall {
593 pub fn function(id: String, name: String, arguments: String) -> Self {
595 Self {
596 id,
597 call_type: "function".to_string(),
598 function: FunctionCall { name, arguments },
599 }
600 }
601
602 pub fn parsed_arguments(&self) -> Result<Value, serde_json::Error> {
604 serde_json::from_str(&self.function.arguments)
605 }
606
607 pub fn validate(&self) -> Result<(), String> {
609 if self.call_type != "function" {
610 return Err(format!(
611 "Only 'function' type is supported, got: {}",
612 self.call_type
613 ));
614 }
615
616 if self.id.is_empty() {
617 return Err("Tool call ID cannot be empty".to_string());
618 }
619
620 if self.function.name.is_empty() {
621 return Err("Function name cannot be empty".to_string());
622 }
623
624 if let Err(e) = self.parsed_arguments() {
626 return Err(format!("Invalid JSON in function arguments: {}", e));
627 }
628
629 Ok(())
630 }
631}
632
/// A completed, provider-normalized model response.
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// Generated text, if any.
    pub content: Option<String>,
    /// Tool calls requested by the model, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Token accounting, when the provider reports it.
    pub usage: Option<Usage>,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// Reasoning/thinking text, for models that expose it.
    pub reasoning: Option<String>,
}
642
/// Token usage reported by a provider for one request.
#[derive(Debug, Clone)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Total of prompt and completion tokens.
    pub total_tokens: u32,
    /// Prompt tokens served from cache, if reported.
    pub cached_prompt_tokens: Option<u32>,
    /// Tokens spent creating a cache entry, if reported.
    pub cache_creation_tokens: Option<u32>,
    /// Tokens read from a cache entry, if reported.
    pub cache_read_tokens: Option<u32>,
}
652
/// Why the model stopped generating.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// Natural end of the response.
    Stop,
    /// Token limit reached.
    Length,
    /// The model stopped to request tool calls.
    ToolCalls,
    /// Output was cut by the provider's content filter.
    ContentFilter,
    /// Provider-reported error, with its message.
    Error(String),
}
661
/// One event in a streaming response.
#[derive(Debug, Clone)]
pub enum LLMStreamEvent {
    /// Incremental chunk of output text.
    Token { delta: String },
    /// Incremental chunk of reasoning/thinking text.
    Reasoning { delta: String },
    /// Terminal event carrying the full assembled response.
    Completed { response: LLMResponse },
}
668
/// Boxed stream of [`LLMStreamEvent`]s produced by [`LLMProvider::stream`].
pub type LLMStream = Pin<Box<dyn futures::Stream<Item = Result<LLMStreamEvent, LLMError>> + Send>>;
670
/// Common interface implemented by every LLM backend.
///
/// Capability probes (`supports_*`) default to conservative answers so a new
/// provider only overrides what it actually supports.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Stable identifier for this provider.
    fn name(&self) -> &str;

    /// Whether true incremental streaming is available (default: no).
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Whether the given model can emit reasoning output (default: no).
    fn supports_reasoning(&self, _model: &str) -> bool {
        false
    }

    /// Whether the given model accepts a reasoning-effort setting (default: no).
    fn supports_reasoning_effort(&self, _model: &str) -> bool {
        false
    }

    /// Whether the given model supports tool calling (default: yes).
    fn supports_tools(&self, _model: &str) -> bool {
        true
    }

    /// Performs a non-streaming completion request.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError>;

    /// Streams a response. The default implementation falls back to
    /// `generate` and yields a single `Completed` event — no incremental
    /// tokens — so providers with real streaming should override this.
    async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
        let response = self.generate(request).await?;
        let stream = try_stream! {
            yield LLMStreamEvent::Completed { response };
        };
        Ok(Box::pin(stream))
    }

    /// Model identifiers this provider accepts.
    fn supported_models(&self) -> Vec<String>;

    /// Validates a request against this provider's constraints before sending.
    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError>;
}
716
/// Errors produced by [`LLMProvider`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum LLMError {
    /// Credentials were rejected or missing.
    #[error("Authentication failed: {0}")]
    Authentication(String),
    /// The provider throttled the request.
    #[error("Rate limit exceeded")]
    RateLimit,
    /// The request failed provider-side or local validation.
    #[error("Invalid request: {0}")]
    InvalidRequest(String),
    /// Transport-level failure reaching the provider.
    #[error("Network error: {0}")]
    Network(String),
    /// Any other error reported by the provider.
    #[error("Provider error: {0}")]
    Provider(String),
}
730
/// Bridges this module's error type into the crate-wide `llm::types::LLMError`.
///
/// Note the mapping is lossy: both `Authentication` and `Provider` collapse
/// into `ApiError`, keeping only the message.
impl From<LLMError> for crate::llm::types::LLMError {
    fn from(err: LLMError) -> crate::llm::types::LLMError {
        match err {
            LLMError::Authentication(msg) => crate::llm::types::LLMError::ApiError(msg),
            LLMError::RateLimit => crate::llm::types::LLMError::RateLimit,
            LLMError::InvalidRequest(msg) => crate::llm::types::LLMError::InvalidRequest(msg),
            LLMError::Network(msg) => crate::llm::types::LLMError::NetworkError(msg),
            LLMError::Provider(msg) => crate::llm::types::LLMError::ApiError(msg),
        }
    }
}