1use async_stream::try_stream;
50use async_trait::async_trait;
51use serde::{Deserialize, Serialize};
52use serde_json::{Value, json};
53use std::pin::Pin;
54
55use crate::config::types::ReasoningEffortLevel;
56
/// Provider-agnostic request for an LLM completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    /// Conversation history, in order.
    pub messages: Vec<Message>,
    /// Optional system prompt carried separately from `messages`.
    pub system_prompt: Option<String>,
    /// Tool definitions the model may call, if any.
    pub tools: Option<Vec<ToolDefinition>>,
    /// Provider-specific model identifier.
    pub model: String,
    /// Maximum tokens to generate; provider default when `None`.
    pub max_tokens: Option<u32>,
    /// Sampling temperature; provider default when `None`.
    pub temperature: Option<f32>,
    /// Whether a streaming response is requested.
    pub stream: bool,

    /// How the model should choose tools (auto / none / any / specific).
    pub tool_choice: Option<ToolChoice>,

    /// Whether the model may emit multiple tool calls in one turn.
    pub parallel_tool_calls: Option<bool>,

    /// Fine-grained configuration for parallel tool use.
    pub parallel_tool_config: Option<ParallelToolConfig>,

    /// Requested reasoning effort for models that support it.
    pub reasoning_effort: Option<ReasoningEffortLevel>,
}
82
/// Tool-selection strategy for a request.
///
/// NOTE(review): `#[serde(untagged)]` serializes the unit variants
/// (`Auto`/`None`/`Any`) as JSON `null`, so this enum cannot round-trip
/// through serde on its own; providers appear to rely on
/// `to_provider_format` for the wire shape instead — confirm serde
/// (de)serialization of this enum is actually exercised before changing it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Model decides whether and which tools to use.
    Auto,

    /// No tools may be used.
    None,

    /// At least one tool must be used.
    Any,

    /// One specific named tool must be used.
    Specific(SpecificToolChoice),
}
105
/// Selects a single named tool (OpenAI-style
/// `{"type": "function", "function": {"name": …}}` shape).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificToolChoice {
    // Serialized as "type"; always "function" via `ToolChoice::function`.
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: SpecificFunctionChoice,
}
114
/// Name of the specific function the model is forced to call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificFunctionChoice {
    pub name: String,
}
120
121impl ToolChoice {
122 pub fn auto() -> Self {
124 Self::Auto
125 }
126
127 pub fn none() -> Self {
129 Self::None
130 }
131
132 pub fn any() -> Self {
134 Self::Any
135 }
136
137 pub fn function(name: String) -> Self {
139 Self::Specific(SpecificToolChoice {
140 tool_type: "function".to_string(),
141 function: SpecificFunctionChoice { name },
142 })
143 }
144
145 pub fn allows_parallel_tools(&self) -> bool {
148 match self {
149 Self::Auto => true,
151 Self::Any => true,
153 Self::Specific(_) => false,
155 Self::None => false,
157 }
158 }
159
160 pub fn description(&self) -> &'static str {
162 match self {
163 Self::Auto => "Model decides when to use tools (allows parallel)",
164 Self::None => "No tools will be used",
165 Self::Any => "At least one tool must be used (allows parallel)",
166 Self::Specific(_) => "Specific tool must be used (no parallel)",
167 }
168 }
169
170 pub fn to_provider_format(&self, provider: &str) -> Value {
172 match (self, provider) {
173 (Self::Auto, "openai") | (Self::Auto, "deepseek") => json!("auto"),
174 (Self::None, "openai") | (Self::None, "deepseek") => json!("none"),
175 (Self::Any, "openai") | (Self::Any, "deepseek") => json!("required"),
176 (Self::Specific(choice), "openai") | (Self::Specific(choice), "deepseek") => {
177 json!(choice)
178 }
179
180 (Self::Auto, "anthropic") => json!({"type": "auto"}),
181 (Self::None, "anthropic") => json!({"type": "none"}),
182 (Self::Any, "anthropic") => json!({"type": "any"}),
183 (Self::Specific(choice), "anthropic") => {
184 json!({"type": "tool", "name": choice.function.name})
185 }
186
187 (Self::Auto, "gemini") => json!({"mode": "auto"}),
188 (Self::None, "gemini") => json!({"mode": "none"}),
189 (Self::Any, "gemini") => json!({"mode": "any"}),
190 (Self::Specific(choice), "gemini") => {
191 json!({"mode": "any", "allowed_function_names": [choice.function.name]})
192 }
193
194 _ => match self {
196 Self::Auto => json!("auto"),
197 Self::None => json!("none"),
198 Self::Any => json!("required"),
199 Self::Specific(choice) => json!(choice),
200 },
201 }
202 }
203}
204
/// Configuration for parallel tool execution behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelToolConfig {
    /// When true, forces strictly sequential tool use.
    pub disable_parallel_tool_use: bool,

    /// Upper bound on concurrent tool calls; `None` means unlimited.
    pub max_parallel_tools: Option<usize>,

    /// When true, prompts/requests should encourage parallel tool use.
    pub encourage_parallel: bool,
}
220
221impl Default for ParallelToolConfig {
222 fn default() -> Self {
223 Self {
224 disable_parallel_tool_use: false,
225 max_parallel_tools: Some(5), encourage_parallel: true,
227 }
228 }
229}
230
231impl ParallelToolConfig {
232 pub fn anthropic_optimized() -> Self {
234 Self {
235 disable_parallel_tool_use: false,
236 max_parallel_tools: None, encourage_parallel: true,
238 }
239 }
240
241 pub fn sequential_only() -> Self {
243 Self {
244 disable_parallel_tool_use: true,
245 max_parallel_tools: Some(1),
246 encourage_parallel: false,
247 }
248 }
249}
250
/// A single turn in the conversation sent to a provider.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
    pub role: MessageRole,
    pub content: String,
    /// Tool calls emitted by an assistant turn, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `Tool`-role messages: id of the call being answered.
    pub tool_call_id: Option<String>,
}
259
260impl Message {
261 pub fn user(content: String) -> Self {
263 Self {
264 role: MessageRole::User,
265 content,
266 tool_calls: None,
267 tool_call_id: None,
268 }
269 }
270
271 pub fn assistant(content: String) -> Self {
273 Self {
274 role: MessageRole::Assistant,
275 content,
276 tool_calls: None,
277 tool_call_id: None,
278 }
279 }
280
281 pub fn assistant_with_tools(content: String, tool_calls: Vec<ToolCall>) -> Self {
284 Self {
285 role: MessageRole::Assistant,
286 content,
287 tool_calls: Some(tool_calls),
288 tool_call_id: None,
289 }
290 }
291
292 pub fn system(content: String) -> Self {
294 Self {
295 role: MessageRole::System,
296 content,
297 tool_calls: None,
298 tool_call_id: None,
299 }
300 }
301
302 pub fn tool_response(tool_call_id: String, content: String) -> Self {
312 Self {
313 role: MessageRole::Tool,
314 content,
315 tool_calls: None,
316 tool_call_id: Some(tool_call_id),
317 }
318 }
319
320 pub fn tool_response_with_name(
323 tool_call_id: String,
324 _function_name: String,
325 content: String,
326 ) -> Self {
327 Self::tool_response(tool_call_id, content)
329 }
330
331 pub fn validate_for_provider(&self, provider: &str) -> Result<(), String> {
334 self.role
336 .validate_for_provider(provider, self.tool_call_id.is_some())?;
337
338 if let Some(tool_calls) = &self.tool_calls {
340 if !self.role.can_make_tool_calls() {
341 return Err(format!("Role {:?} cannot make tool calls", self.role));
342 }
343
344 if tool_calls.is_empty() {
345 return Err("Tool calls array should not be empty".to_string());
346 }
347
348 for tool_call in tool_calls {
350 tool_call.validate()?;
351 }
352 }
353
354 match provider {
356 "openai" | "openrouter" | "zai" => {
357 if self.role == MessageRole::Tool && self.tool_call_id.is_none() {
358 return Err(format!(
359 "{} requires tool_call_id for tool messages",
360 provider
361 ));
362 }
363 }
364 "gemini" => {
365 if self.role == MessageRole::Tool && self.tool_call_id.is_none() {
366 return Err(
367 "Gemini tool responses need tool_call_id for function name mapping"
368 .to_string(),
369 );
370 }
371 if self.role == MessageRole::System && !self.content.is_empty() {
373 }
375 }
376 "anthropic" => {
377 }
380 _ => {} }
382
383 Ok(())
384 }
385
386 pub fn has_tool_calls(&self) -> bool {
388 self.tool_calls
389 .as_ref()
390 .map_or(false, |calls| !calls.is_empty())
391 }
392
393 pub fn get_tool_calls(&self) -> Option<&[ToolCall]> {
395 self.tool_calls.as_deref()
396 }
397
398 pub fn is_tool_response(&self) -> bool {
400 self.role == MessageRole::Tool
401 }
402}
403
/// Conversation role of a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
    System,
    User,
    Assistant,
    /// A tool-execution result being returned to the model.
    Tool,
}
411
412impl MessageRole {
413 pub fn as_gemini_str(&self) -> &'static str {
419 match self {
420 MessageRole::System => "system", MessageRole::User => "user",
422 MessageRole::Assistant => "model", MessageRole::Tool => "user", }
425 }
426
427 pub fn as_openai_str(&self) -> &'static str {
432 match self {
433 MessageRole::System => "system",
434 MessageRole::User => "user",
435 MessageRole::Assistant => "assistant",
436 MessageRole::Tool => "tool", }
438 }
439
440 pub fn as_anthropic_str(&self) -> &'static str {
446 match self {
447 MessageRole::System => "system", MessageRole::User => "user",
449 MessageRole::Assistant => "assistant",
450 MessageRole::Tool => "user", }
452 }
453
454 pub fn as_generic_str(&self) -> &'static str {
457 match self {
458 MessageRole::System => "system",
459 MessageRole::User => "user",
460 MessageRole::Assistant => "assistant",
461 MessageRole::Tool => "tool",
462 }
463 }
464
465 pub fn can_make_tool_calls(&self) -> bool {
468 matches!(self, MessageRole::Assistant)
469 }
470
471 pub fn is_tool_response(&self) -> bool {
473 matches!(self, MessageRole::Tool)
474 }
475
476 pub fn validate_for_provider(
479 &self,
480 provider: &str,
481 has_tool_call_id: bool,
482 ) -> Result<(), String> {
483 match (self, provider) {
484 (MessageRole::Tool, provider)
485 if matches!(
486 provider,
487 "openai" | "openrouter" | "xai" | "deepseek" | "zai"
488 ) && !has_tool_call_id =>
489 {
490 Err(format!("{} tool messages must have tool_call_id", provider))
491 }
492 (MessageRole::Tool, "gemini") if !has_tool_call_id => {
493 Err("Gemini tool messages need tool_call_id for function mapping".to_string())
494 }
495 _ => Ok(()),
496 }
497 }
498}
499
/// A tool the model is allowed to call (OpenAI-style definition).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    // Serialized as "type"; only "function" is supported (see `validate`).
    #[serde(rename = "type")]
    pub tool_type: String,

    pub function: FunctionDefinition,
}
511
/// Schema of a callable function exposed to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,

    pub description: String,

    /// JSON-schema object describing the parameters
    /// (must be a JSON object — see `ToolDefinition::validate`).
    pub parameters: Value,
}
524
525impl ToolDefinition {
526 pub fn function(name: String, description: String, parameters: Value) -> Self {
528 Self {
529 tool_type: "function".to_string(),
530 function: FunctionDefinition {
531 name,
532 description,
533 parameters,
534 },
535 }
536 }
537
538 pub fn function_name(&self) -> &str {
540 &self.function.name
541 }
542
543 pub fn validate(&self) -> Result<(), String> {
545 if self.tool_type != "function" {
546 return Err(format!(
547 "Only 'function' type is supported, got: {}",
548 self.tool_type
549 ));
550 }
551
552 if self.function.name.is_empty() {
553 return Err("Function name cannot be empty".to_string());
554 }
555
556 if self.function.description.is_empty() {
557 return Err("Function description cannot be empty".to_string());
558 }
559
560 if !self.function.parameters.is_object() {
562 return Err("Function parameters must be a JSON object".to_string());
563 }
564
565 Ok(())
566 }
567}
568
/// A tool invocation emitted by the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned id; echoed back via `Message::tool_response`.
    pub id: String,

    // Serialized as "type"; only "function" is supported (see `validate`).
    #[serde(rename = "type")]
    pub call_type: String,

    pub function: FunctionCall,
}
583
/// The function name and raw argument payload of a [`ToolCall`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,

    /// Arguments as a raw JSON string; parsed via `ToolCall::parsed_arguments`.
    pub arguments: String,
}
593
594impl ToolCall {
595 pub fn function(id: String, name: String, arguments: String) -> Self {
597 Self {
598 id,
599 call_type: "function".to_string(),
600 function: FunctionCall { name, arguments },
601 }
602 }
603
604 pub fn parsed_arguments(&self) -> Result<Value, serde_json::Error> {
606 serde_json::from_str(&self.function.arguments)
607 }
608
609 pub fn validate(&self) -> Result<(), String> {
611 if self.call_type != "function" {
612 return Err(format!(
613 "Only 'function' type is supported, got: {}",
614 self.call_type
615 ));
616 }
617
618 if self.id.is_empty() {
619 return Err("Tool call ID cannot be empty".to_string());
620 }
621
622 if self.function.name.is_empty() {
623 return Err("Function name cannot be empty".to_string());
624 }
625
626 if let Err(e) = self.parsed_arguments() {
628 return Err(format!("Invalid JSON in function arguments: {}", e));
629 }
630
631 Ok(())
632 }
633}
634
/// A completed (non-streaming) provider response.
#[derive(Debug, Clone)]
pub struct LLMResponse {
    pub content: Option<String>,
    pub tool_calls: Option<Vec<ToolCall>>,
    pub usage: Option<Usage>,
    pub finish_reason: FinishReason,
    /// Reasoning/thinking text, for models that expose it.
    pub reasoning: Option<String>,
}
644
/// Token accounting for a response.
#[derive(Debug, Clone)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    // The three cache fields below are provider-dependent; presumably
    // creation/read map to Anthropic-style prompt caching and
    // cached_prompt_tokens to OpenAI-style — confirm against the
    // provider implementations that populate them.
    pub cached_prompt_tokens: Option<u32>,
    pub cache_creation_tokens: Option<u32>,
    pub cache_read_tokens: Option<u32>,
}
654
/// Why generation stopped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// Natural end of output.
    Stop,
    /// Hit the max-token limit.
    Length,
    /// Stopped to emit tool calls.
    ToolCalls,
    /// Output suppressed by the provider's content filter.
    ContentFilter,
    /// Provider-reported error, with its message.
    Error(String),
}
663
/// Incremental events produced by a streaming response.
#[derive(Debug, Clone)]
pub enum LLMStreamEvent {
    /// A chunk of output text.
    Token { delta: String },
    /// A chunk of reasoning/thinking text.
    Reasoning { delta: String },
    /// Terminal event carrying the assembled response.
    Completed { response: LLMResponse },
}
670
/// Boxed, pinned stream of [`LLMStreamEvent`]s as returned by `LLMProvider::stream`.
pub type LLMStream = Pin<Box<dyn futures::Stream<Item = Result<LLMStreamEvent, LLMError>> + Send>>;

/// Common interface every LLM backend implements.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Provider identifier (e.g. used to select wire formats).
    fn name(&self) -> &str;

    /// Whether the provider implements true incremental streaming.
    /// Defaults to false; the default `stream` emulates streaming.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Whether `_model` can return reasoning/thinking output.
    fn supports_reasoning(&self, _model: &str) -> bool {
        false
    }

    /// Whether `_model` accepts a reasoning-effort setting.
    fn supports_reasoning_effort(&self, _model: &str) -> bool {
        false
    }

    /// Whether `_model` supports tool calling. Defaults to true.
    fn supports_tools(&self, _model: &str) -> bool {
        true
    }

    /// Perform a full (non-streaming) completion.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError>;

    /// Stream a completion. The default implementation falls back to
    /// `generate` and yields a single `Completed` event.
    async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
        let response = self.generate(request).await?;
        let stream = try_stream! {
            yield LLMStreamEvent::Completed { response };
        };
        Ok(Box::pin(stream))
    }

    /// Models this provider can serve.
    fn supported_models(&self) -> Vec<String>;

    /// Check a request against provider-specific constraints before sending.
    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError>;
}
718
/// Errors surfaced by provider implementations.
#[derive(Debug, thiserror::Error)]
pub enum LLMError {
    #[error("Authentication failed: {0}")]
    Authentication(String),
    #[error("Rate limit exceeded")]
    RateLimit,
    #[error("Invalid request: {0}")]
    InvalidRequest(String),
    #[error("Network error: {0}")]
    Network(String),
    /// Any other provider-reported failure.
    #[error("Provider error: {0}")]
    Provider(String),
}
732
733impl From<LLMError> for crate::llm::types::LLMError {
735 fn from(err: LLMError) -> crate::llm::types::LLMError {
736 match err {
737 LLMError::Authentication(msg) => crate::llm::types::LLMError::ApiError(msg),
738 LLMError::RateLimit => crate::llm::types::LLMError::RateLimit,
739 LLMError::InvalidRequest(msg) => crate::llm::types::LLMError::InvalidRequest(msg),
740 LLMError::Network(msg) => crate::llm::types::LLMError::NetworkError(msg),
741 LLMError::Provider(msg) => crate::llm::types::LLMError::ApiError(msg),
742 }
743 }
744}