1use async_stream::try_stream;
50use async_trait::async_trait;
51use serde::{Deserialize, Serialize};
52use serde_json::{Value, json};
53use std::pin::Pin;
54
/// A provider-agnostic chat-completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    /// Conversation history, oldest message first.
    pub messages: Vec<Message>,
    /// Optional system prompt, kept separate from `messages` so providers
    /// that take it out-of-band (e.g. as a top-level field) can do so.
    pub system_prompt: Option<String>,
    /// Tools the model may call; `None` disables tool use entirely.
    pub tools: Option<Vec<ToolDefinition>>,
    /// Provider-specific model identifier.
    pub model: String,
    /// Upper bound on generated tokens, if the caller wants one.
    pub max_tokens: Option<u32>,
    /// Sampling temperature; `None` uses the provider default.
    pub temperature: Option<f32>,
    /// Whether the caller wants incremental streaming output.
    pub stream: bool,

    /// How the model is allowed to choose tools (auto/none/any/specific).
    pub tool_choice: Option<ToolChoice>,

    /// Provider-level toggle for parallel tool calls, where supported.
    pub parallel_tool_calls: Option<bool>,

    /// Finer-grained parallel-tool tuning (cap, encouragement, disable).
    pub parallel_tool_config: Option<ParallelToolConfig>,

    /// Reasoning-effort hint for models that accept one — presumably values
    /// like "low"/"medium"/"high"; TODO confirm against provider impls.
    pub reasoning_effort: Option<String>,
}
80
/// Strategy controlling whether, and which, tools the model may call.
///
/// NOTE(review): `#[serde(untagged)]` on an enum whose first three variants
/// are unit-like means `Auto`, `None`, and `Any` all serialize to JSON
/// `null`, and `null` deserializes back as `Auto` — round-tripping loses
/// `None`/`Any`. If this type only ever reaches the wire through
/// `to_provider_format` that may be intentional; confirm before relying on
/// direct serde (de)serialization of this enum.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Model decides whether to call tools.
    Auto,

    /// Tool use is disabled.
    None,

    /// Model must call at least one tool.
    Any,

    /// Model must call one specific, named tool.
    Specific(SpecificToolChoice),
}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct SpecificToolChoice {
107 #[serde(rename = "type")]
108 pub tool_type: String, pub function: SpecificFunctionChoice,
111}
112
/// Names the function a `SpecificToolChoice` points at.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificFunctionChoice {
    /// Function name, as declared in the corresponding `ToolDefinition`.
    pub name: String,
}
118
119impl ToolChoice {
120 pub fn auto() -> Self {
122 Self::Auto
123 }
124
125 pub fn none() -> Self {
127 Self::None
128 }
129
130 pub fn any() -> Self {
132 Self::Any
133 }
134
135 pub fn function(name: String) -> Self {
137 Self::Specific(SpecificToolChoice {
138 tool_type: "function".to_string(),
139 function: SpecificFunctionChoice { name },
140 })
141 }
142
143 pub fn allows_parallel_tools(&self) -> bool {
146 match self {
147 Self::Auto => true,
149 Self::Any => true,
151 Self::Specific(_) => false,
153 Self::None => false,
155 }
156 }
157
158 pub fn description(&self) -> &'static str {
160 match self {
161 Self::Auto => "Model decides when to use tools (allows parallel)",
162 Self::None => "No tools will be used",
163 Self::Any => "At least one tool must be used (allows parallel)",
164 Self::Specific(_) => "Specific tool must be used (no parallel)",
165 }
166 }
167
168 pub fn to_provider_format(&self, provider: &str) -> Value {
170 match (self, provider) {
171 (Self::Auto, "openai") | (Self::Auto, "deepseek") => json!("auto"),
172 (Self::None, "openai") | (Self::None, "deepseek") => json!("none"),
173 (Self::Any, "openai") | (Self::Any, "deepseek") => json!("required"),
174 (Self::Specific(choice), "openai") | (Self::Specific(choice), "deepseek") => {
175 json!(choice)
176 }
177
178 (Self::Auto, "anthropic") => json!({"type": "auto"}),
179 (Self::None, "anthropic") => json!({"type": "none"}),
180 (Self::Any, "anthropic") => json!({"type": "any"}),
181 (Self::Specific(choice), "anthropic") => {
182 json!({"type": "tool", "name": choice.function.name})
183 }
184
185 (Self::Auto, "gemini") => json!({"mode": "auto"}),
186 (Self::None, "gemini") => json!({"mode": "none"}),
187 (Self::Any, "gemini") => json!({"mode": "any"}),
188 (Self::Specific(choice), "gemini") => {
189 json!({"mode": "any", "allowed_function_names": [choice.function.name]})
190 }
191
192 _ => match self {
194 Self::Auto => json!("auto"),
195 Self::None => json!("none"),
196 Self::Any => json!("required"),
197 Self::Specific(choice) => json!(choice),
198 },
199 }
200 }
201}
202
/// Tuning knobs for parallel tool execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelToolConfig {
    /// Hard switch: when true, tools must run strictly sequentially.
    pub disable_parallel_tool_use: bool,

    /// Maximum number of tools to run concurrently; `None` means uncapped.
    pub max_parallel_tools: Option<usize>,

    /// Whether to actively encourage the model to emit parallel tool calls.
    pub encourage_parallel: bool,
}
218
219impl Default for ParallelToolConfig {
220 fn default() -> Self {
221 Self {
222 disable_parallel_tool_use: false,
223 max_parallel_tools: Some(5), encourage_parallel: true,
225 }
226 }
227}
228
229impl ParallelToolConfig {
230 pub fn anthropic_optimized() -> Self {
232 Self {
233 disable_parallel_tool_use: false,
234 max_parallel_tools: None, encourage_parallel: true,
236 }
237 }
238
239 pub fn sequential_only() -> Self {
241 Self {
242 disable_parallel_tool_use: true,
243 max_parallel_tools: Some(1),
244 encourage_parallel: false,
245 }
246 }
247}
248
/// One turn in a conversation, in a provider-neutral shape.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message.
    pub role: MessageRole,
    /// Message text (may be empty for pure tool-call turns).
    pub content: String,
    /// Tool calls made by the assistant, if any; only valid on
    /// `MessageRole::Assistant` (see `validate_for_provider`).
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `MessageRole::Tool` messages: the id of the call being answered.
    pub tool_call_id: Option<String>,
}
257
258impl Message {
259 pub fn user(content: String) -> Self {
261 Self {
262 role: MessageRole::User,
263 content,
264 tool_calls: None,
265 tool_call_id: None,
266 }
267 }
268
269 pub fn assistant(content: String) -> Self {
271 Self {
272 role: MessageRole::Assistant,
273 content,
274 tool_calls: None,
275 tool_call_id: None,
276 }
277 }
278
279 pub fn assistant_with_tools(content: String, tool_calls: Vec<ToolCall>) -> Self {
282 Self {
283 role: MessageRole::Assistant,
284 content,
285 tool_calls: Some(tool_calls),
286 tool_call_id: None,
287 }
288 }
289
290 pub fn system(content: String) -> Self {
292 Self {
293 role: MessageRole::System,
294 content,
295 tool_calls: None,
296 tool_call_id: None,
297 }
298 }
299
300 pub fn tool_response(tool_call_id: String, content: String) -> Self {
310 Self {
311 role: MessageRole::Tool,
312 content,
313 tool_calls: None,
314 tool_call_id: Some(tool_call_id),
315 }
316 }
317
318 pub fn tool_response_with_name(
321 tool_call_id: String,
322 _function_name: String,
323 content: String,
324 ) -> Self {
325 Self::tool_response(tool_call_id, content)
327 }
328
329 pub fn validate_for_provider(&self, provider: &str) -> Result<(), String> {
332 self.role
334 .validate_for_provider(provider, self.tool_call_id.is_some())?;
335
336 if let Some(tool_calls) = &self.tool_calls {
338 if !self.role.can_make_tool_calls() {
339 return Err(format!("Role {:?} cannot make tool calls", self.role));
340 }
341
342 if tool_calls.is_empty() {
343 return Err("Tool calls array should not be empty".to_string());
344 }
345
346 for tool_call in tool_calls {
348 tool_call.validate()?;
349 }
350 }
351
352 match provider {
354 "openai" | "openrouter" => {
355 if self.role == MessageRole::Tool && self.tool_call_id.is_none() {
356 return Err(format!(
357 "{} requires tool_call_id for tool messages",
358 provider
359 ));
360 }
361 }
362 "gemini" => {
363 if self.role == MessageRole::Tool && self.tool_call_id.is_none() {
364 return Err(
365 "Gemini tool responses need tool_call_id for function name mapping"
366 .to_string(),
367 );
368 }
369 if self.role == MessageRole::System && !self.content.is_empty() {
371 }
373 }
374 "anthropic" => {
375 }
378 _ => {} }
380
381 Ok(())
382 }
383
384 pub fn has_tool_calls(&self) -> bool {
386 self.tool_calls
387 .as_ref()
388 .map_or(false, |calls| !calls.is_empty())
389 }
390
391 pub fn get_tool_calls(&self) -> Option<&[ToolCall]> {
393 self.tool_calls.as_deref()
394 }
395
396 pub fn is_tool_response(&self) -> bool {
398 self.role == MessageRole::Tool
399 }
400}
401
/// Who authored a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
    /// System / instruction prompt.
    System,
    /// End-user input.
    User,
    /// Model output (the only role allowed to carry tool calls).
    Assistant,
    /// Result of executing a tool call.
    Tool,
}
409
410impl MessageRole {
411 pub fn as_gemini_str(&self) -> &'static str {
417 match self {
418 MessageRole::System => "system", MessageRole::User => "user",
420 MessageRole::Assistant => "model", MessageRole::Tool => "user", }
423 }
424
425 pub fn as_openai_str(&self) -> &'static str {
430 match self {
431 MessageRole::System => "system",
432 MessageRole::User => "user",
433 MessageRole::Assistant => "assistant",
434 MessageRole::Tool => "tool", }
436 }
437
438 pub fn as_anthropic_str(&self) -> &'static str {
444 match self {
445 MessageRole::System => "system", MessageRole::User => "user",
447 MessageRole::Assistant => "assistant",
448 MessageRole::Tool => "user", }
450 }
451
452 pub fn as_generic_str(&self) -> &'static str {
455 match self {
456 MessageRole::System => "system",
457 MessageRole::User => "user",
458 MessageRole::Assistant => "assistant",
459 MessageRole::Tool => "tool",
460 }
461 }
462
463 pub fn can_make_tool_calls(&self) -> bool {
466 matches!(self, MessageRole::Assistant)
467 }
468
469 pub fn is_tool_response(&self) -> bool {
471 matches!(self, MessageRole::Tool)
472 }
473
474 pub fn validate_for_provider(
477 &self,
478 provider: &str,
479 has_tool_call_id: bool,
480 ) -> Result<(), String> {
481 match (self, provider) {
482 (MessageRole::Tool, provider)
483 if matches!(provider, "openai" | "openrouter" | "xai" | "deepseek")
484 && !has_tool_call_id =>
485 {
486 Err(format!("{} tool messages must have tool_call_id", provider))
487 }
488 (MessageRole::Tool, "gemini") if !has_tool_call_id => {
489 Err("Gemini tool messages need tool_call_id for function mapping".to_string())
490 }
491 _ => Ok(()),
492 }
493 }
494}
495
/// A tool the model may call, in the OpenAI-style `{type, function}` shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool kind, serialized as `"type"`; `validate` only accepts
    /// `"function"`.
    #[serde(rename = "type")]
    pub tool_type: String,

    /// Schema of the callable function.
    pub function: FunctionDefinition,
}
507
/// Declaration of a callable function: name, description, and parameter
/// schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name the model will reference in tool calls.
    pub name: String,

    /// Natural-language description shown to the model.
    pub description: String,

    /// Parameter schema as a JSON object (JSON-Schema-shaped; `validate`
    /// only checks that it is an object).
    pub parameters: Value,
}
520
521impl ToolDefinition {
522 pub fn function(name: String, description: String, parameters: Value) -> Self {
524 Self {
525 tool_type: "function".to_string(),
526 function: FunctionDefinition {
527 name,
528 description,
529 parameters,
530 },
531 }
532 }
533
534 pub fn function_name(&self) -> &str {
536 &self.function.name
537 }
538
539 pub fn validate(&self) -> Result<(), String> {
541 if self.tool_type != "function" {
542 return Err(format!(
543 "Only 'function' type is supported, got: {}",
544 self.tool_type
545 ));
546 }
547
548 if self.function.name.is_empty() {
549 return Err("Function name cannot be empty".to_string());
550 }
551
552 if self.function.description.is_empty() {
553 return Err("Function description cannot be empty".to_string());
554 }
555
556 if !self.function.parameters.is_object() {
558 return Err("Function parameters must be a JSON object".to_string());
559 }
560
561 Ok(())
562 }
563}
564
/// A concrete tool invocation emitted by the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id; echoed back via `Message::tool_response`.
    pub id: String,

    /// Call kind, serialized as `"type"`; `validate` only accepts
    /// `"function"`.
    #[serde(rename = "type")]
    pub call_type: String,

    /// The function invocation itself.
    pub function: FunctionCall,
}
579
/// Function name plus raw JSON arguments for a [`ToolCall`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function being invoked.
    pub name: String,

    /// Arguments as a raw JSON string; parse via `ToolCall::parsed_arguments`.
    pub arguments: String,
}
589
590impl ToolCall {
591 pub fn function(id: String, name: String, arguments: String) -> Self {
593 Self {
594 id,
595 call_type: "function".to_string(),
596 function: FunctionCall { name, arguments },
597 }
598 }
599
600 pub fn parsed_arguments(&self) -> Result<Value, serde_json::Error> {
602 serde_json::from_str(&self.function.arguments)
603 }
604
605 pub fn validate(&self) -> Result<(), String> {
607 if self.call_type != "function" {
608 return Err(format!(
609 "Only 'function' type is supported, got: {}",
610 self.call_type
611 ));
612 }
613
614 if self.id.is_empty() {
615 return Err("Tool call ID cannot be empty".to_string());
616 }
617
618 if self.function.name.is_empty() {
619 return Err("Function name cannot be empty".to_string());
620 }
621
622 if let Err(e) = self.parsed_arguments() {
624 return Err(format!("Invalid JSON in function arguments: {}", e));
625 }
626
627 Ok(())
628 }
629}
630
/// A completed model response, normalized across providers.
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// Generated text, if any.
    pub content: Option<String>,
    /// Tool calls the model requested, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Token accounting, when the provider reports it.
    pub usage: Option<Usage>,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// Reasoning/chain-of-thought text, for models that expose it.
    pub reasoning: Option<String>,
}
640
/// Token-usage accounting for one request/response pair.
#[derive(Debug, Clone)]
pub struct Usage {
    /// Tokens in the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Prompt + completion tokens.
    pub total_tokens: u32,
    /// Prompt tokens served from cache, when reported.
    pub cached_prompt_tokens: Option<u32>,
    /// Tokens spent creating a prompt cache entry, when reported.
    pub cache_creation_tokens: Option<u32>,
    /// Tokens read from a prompt cache entry, when reported.
    pub cache_read_tokens: Option<u32>,
}
650
/// Why the model stopped generating.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// Natural end of output.
    Stop,
    /// Hit the max-token limit.
    Length,
    /// Stopped to make tool calls.
    ToolCalls,
    /// Output was suppressed by a content filter.
    ContentFilter,
    /// Provider-reported error, with its message.
    Error(String),
}
659
/// One event on an [`LLMStream`].
#[derive(Debug, Clone)]
pub enum LLMStreamEvent {
    /// Incremental chunk of generated text.
    Token { delta: String },
    /// Incremental chunk of reasoning text.
    Reasoning { delta: String },
    /// Terminal event carrying the full assembled response.
    Completed { response: LLMResponse },
}
666
/// Boxed, sendable stream of [`LLMStreamEvent`]s (or errors).
pub type LLMStream = Pin<Box<dyn futures::Stream<Item = Result<LLMStreamEvent, LLMError>> + Send>>;
668
/// Abstraction over a chat-completion backend.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Short provider name (e.g. the string matched by
    /// `ToolChoice::to_provider_format`).
    fn name(&self) -> &str;

    /// Whether `stream` yields true incremental output; defaults to false,
    /// matching the default `stream` implementation below.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Whether `_model` can emit reasoning output; defaults to false.
    fn supports_reasoning(&self, _model: &str) -> bool {
        false
    }

    /// Whether `_model` accepts a reasoning-effort hint; defaults to false.
    fn supports_reasoning_effort(&self, _model: &str) -> bool {
        false
    }

    /// Whether `_model` supports tool calling; defaults to true.
    fn supports_tools(&self, _model: &str) -> bool {
        true
    }

    /// Run one non-streaming completion.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError>;

    /// Streaming completion. The default implementation is a shim: it awaits
    /// the full `generate` result and emits it as a single `Completed` event,
    /// so callers get no incremental tokens unless the provider overrides
    /// this.
    async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
        let response = self.generate(request).await?;
        let stream = try_stream! {
            yield LLMStreamEvent::Completed { response };
        };
        Ok(Box::pin(stream))
    }

    /// Model identifiers this provider accepts.
    fn supported_models(&self) -> Vec<String>;

    /// Provider-specific request validation, run before sending.
    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError>;
}
714
/// Errors produced by [`LLMProvider`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum LLMError {
    /// Credentials were rejected by the provider.
    #[error("Authentication failed: {0}")]
    Authentication(String),
    /// Provider rate limit was hit.
    #[error("Rate limit exceeded")]
    RateLimit,
    /// The request failed validation (locally or provider-side).
    #[error("Invalid request: {0}")]
    InvalidRequest(String),
    /// Transport-level failure.
    #[error("Network error: {0}")]
    Network(String),
    /// Any other provider-reported failure.
    #[error("Provider error: {0}")]
    Provider(String),
}
728
// Bridge this module's error type into the crate-wide error enum.
impl From<LLMError> for crate::llm::types::LLMError {
    fn from(err: LLMError) -> crate::llm::types::LLMError {
        match err {
            // Authentication and generic provider failures both collapse
            // into ApiError — the distinction is lost at this boundary.
            LLMError::Authentication(msg) => crate::llm::types::LLMError::ApiError(msg),
            LLMError::RateLimit => crate::llm::types::LLMError::RateLimit,
            LLMError::InvalidRequest(msg) => crate::llm::types::LLMError::InvalidRequest(msg),
            LLMError::Network(msg) => crate::llm::types::LLMError::NetworkError(msg),
            LLMError::Provider(msg) => crate::llm::types::LLMError::ApiError(msg),
        }
    }
}