use async_stream::try_stream;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::pin::Pin;

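/// A provider-agnostic chat completion request.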
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub messages: Vec<Message>,
    pub system_prompt: Option<String>,
    pub tools: Option<Vec<ToolDefinition>>,
    pub model: String,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub stream: bool,

    /// How the model may select tools for this request.
    pub tool_choice: Option<ToolChoice>,

    /// Whether the provider may return several tool calls in a single turn.
    pub parallel_tool_calls: Option<bool>,

    /// Finer-grained control over parallel tool use.
    pub parallel_tool_config: Option<ParallelToolConfig>,

    /// Optional reasoning-effort hint for models that support it.
    pub reasoning_effort: Option<String>,
}

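/// Strategy for how the model should choose among the available tools.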
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Let the model decide whether and when to use tools.
    Auto,

    /// Do not use any tools.
    None,

    /// Require the model to use at least one tool.
    Any,

    /// Require the model to call one specific tool.
    Specific(SpecificToolChoice),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificToolChoice {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: SpecificFunctionChoice,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificFunctionChoice {
    pub name: String,
}

impl ToolChoice {
    pub fn auto() -> Self {
        Self::Auto
    }

    pub fn none() -> Self {
        Self::None
    }

    pub fn any() -> Self {
        Self::Any
    }

    pub fn function(name: String) -> Self {
        Self::Specific(SpecificToolChoice {
            tool_type: "function".to_string(),
            function: SpecificFunctionChoice { name },
        })
    }

    pub fn allows_parallel_tools(&self) -> bool {
        match self {
            Self::Auto => true,
            Self::Any => true,
            Self::Specific(_) => false,
            Self::None => false,
        }
    }

    pub fn description(&self) -> &'static str {
        match self {
            Self::Auto => "Model decides when to use tools (allows parallel)",
            Self::None => "No tools will be used",
            Self::Any => "At least one tool must be used (allows parallel)",
            Self::Specific(_) => "Specific tool must be used (no parallel)",
        }
    }

    pub fn to_provider_format(&self, provider: &str) -> Value {
        match (self, provider) {
            // OpenAI-style `tool_choice` values.
            (Self::Auto, "openai") => json!("auto"),
            (Self::None, "openai") => json!("none"),
            (Self::Any, "openai") => json!("required"),
            (Self::Specific(choice), "openai") => json!(choice),

            // Anthropic `tool_choice` objects.
            (Self::Auto, "anthropic") => json!({"type": "auto"}),
            (Self::None, "anthropic") => json!({"type": "none"}),
            (Self::Any, "anthropic") => json!({"type": "any"}),
            (Self::Specific(choice), "anthropic") => {
                json!({"type": "tool", "name": choice.function.name})
            }

            // Gemini function-calling config.
            (Self::Auto, "gemini") => json!({"mode": "auto"}),
            (Self::None, "gemini") => json!({"mode": "none"}),
            (Self::Any, "gemini") => json!({"mode": "any"}),
            (Self::Specific(choice), "gemini") => {
                json!({"mode": "any", "allowed_function_names": [choice.function.name]})
            }

            // Unknown providers fall back to the OpenAI-style format.
            _ => match self {
                Self::Auto => json!("auto"),
                Self::None => json!("none"),
                Self::Any => json!("required"),
                Self::Specific(choice) => json!(choice),
            },
        }
    }
}

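/// Configuration for parallel tool execution.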
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelToolConfig {
    /// When true, the model is asked to issue at most one tool call per turn.
    pub disable_parallel_tool_use: bool,

    /// Optional upper bound on the number of tools executed in parallel.
    pub max_parallel_tools: Option<usize>,

    /// Whether the model should be encouraged to batch independent tool calls.
    pub encourage_parallel: bool,
}

impl Default for ParallelToolConfig {
    fn default() -> Self {
        Self {
            disable_parallel_tool_use: false,
            max_parallel_tools: Some(5),
            encourage_parallel: true,
        }
    }
}

impl ParallelToolConfig {
    /// Configuration tuned for Anthropic models: parallel use allowed with no cap.
    pub fn anthropic_optimized() -> Self {
        Self {
            disable_parallel_tool_use: false,
            max_parallel_tools: None,
            encourage_parallel: true,
        }
    }

    /// Configuration that forces strictly sequential tool execution.
    pub fn sequential_only() -> Self {
        Self {
            disable_parallel_tool_use: true,
            max_parallel_tools: Some(1),
            encourage_parallel: false,
        }
    }
}

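/// A single message in a conversation.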
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
    pub role: MessageRole,
    pub content: String,
    pub tool_calls: Option<Vec<ToolCall>>,
    pub tool_call_id: Option<String>,
}

impl Message {
    pub fn user(content: String) -> Self {
        Self {
            role: MessageRole::User,
            content,
            tool_calls: None,
            tool_call_id: None,
        }
    }

    pub fn assistant(content: String) -> Self {
        Self {
            role: MessageRole::Assistant,
            content,
            tool_calls: None,
            tool_call_id: None,
        }
    }

    pub fn assistant_with_tools(content: String, tool_calls: Vec<ToolCall>) -> Self {
        Self {
            role: MessageRole::Assistant,
            content,
            tool_calls: Some(tool_calls),
            tool_call_id: None,
        }
    }

    pub fn system(content: String) -> Self {
        Self {
            role: MessageRole::System,
            content,
            tool_calls: None,
            tool_call_id: None,
        }
    }

    pub fn tool_response(tool_call_id: String, content: String) -> Self {
        Self {
            role: MessageRole::Tool,
            content,
            tool_calls: None,
            tool_call_id: Some(tool_call_id),
        }
    }

    /// Same as [`Message::tool_response`]; the function name is accepted for
    /// caller convenience but is currently unused.
    pub fn tool_response_with_name(
        tool_call_id: String,
        _function_name: String,
        content: String,
    ) -> Self {
        Self::tool_response(tool_call_id, content)
    }

    pub fn validate_for_provider(&self, provider: &str) -> Result<(), String> {
        self.role
            .validate_for_provider(provider, self.tool_call_id.is_some())?;

        if let Some(tool_calls) = &self.tool_calls {
            if !self.role.can_make_tool_calls() {
                return Err(format!("Role {:?} cannot make tool calls", self.role));
            }

            if tool_calls.is_empty() {
                return Err("Tool calls array should not be empty".to_string());
            }

            for tool_call in tool_calls {
                tool_call.validate()?;
            }
        }

        match provider {
            "openai" | "openrouter" => {
                if self.role == MessageRole::Tool && self.tool_call_id.is_none() {
                    return Err(format!(
                        "{} requires tool_call_id for tool messages",
                        provider
                    ));
                }
            }
            "gemini" => {
                if self.role == MessageRole::Tool && self.tool_call_id.is_none() {
                    return Err(
                        "Gemini tool responses need tool_call_id for function name mapping"
                            .to_string(),
                    );
                }
                if self.role == MessageRole::System && !self.content.is_empty() {
                    // Intentionally no additional validation for non-empty system messages.
                }
            }
            "anthropic" => {
                // No Anthropic-specific validation beyond the role checks above.
            }
            _ => {}
        }

        Ok(())
    }

    pub fn has_tool_calls(&self) -> bool {
        self.tool_calls
            .as_ref()
            .map_or(false, |calls| !calls.is_empty())
    }

    pub fn get_tool_calls(&self) -> Option<&[ToolCall]> {
        self.tool_calls.as_deref()
    }

    pub fn is_tool_response(&self) -> bool {
        self.role == MessageRole::Tool
    }
}

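/// The role of a message author.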
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
    System,
    User,
    Assistant,
    Tool,
}

impl MessageRole {
    /// Role string used by the Gemini API.
    pub fn as_gemini_str(&self) -> &'static str {
        match self {
            MessageRole::System => "system",
            MessageRole::User => "user",
            // Gemini calls the assistant role "model".
            MessageRole::Assistant => "model",
            // Gemini has no dedicated tool role; tool results are sent as user content.
            MessageRole::Tool => "user",
        }
    }

    /// Role string used by OpenAI-compatible APIs.
    pub fn as_openai_str(&self) -> &'static str {
        match self {
            MessageRole::System => "system",
            MessageRole::User => "user",
            MessageRole::Assistant => "assistant",
            MessageRole::Tool => "tool",
        }
    }

    /// Role string used for Anthropic requests.
    pub fn as_anthropic_str(&self) -> &'static str {
        match self {
            MessageRole::System => "system",
            MessageRole::User => "user",
            MessageRole::Assistant => "assistant",
            // Anthropic tool results are sent back inside user messages.
            MessageRole::Tool => "user",
        }
    }

    pub fn as_generic_str(&self) -> &'static str {
        match self {
            MessageRole::System => "system",
            MessageRole::User => "user",
            MessageRole::Assistant => "assistant",
            MessageRole::Tool => "tool",
        }
    }

    pub fn can_make_tool_calls(&self) -> bool {
        matches!(self, MessageRole::Assistant)
    }

    pub fn is_tool_response(&self) -> bool {
        matches!(self, MessageRole::Tool)
    }

    pub fn validate_for_provider(
        &self,
        provider: &str,
        has_tool_call_id: bool,
    ) -> Result<(), String> {
        match (self, provider) {
            (MessageRole::Tool, provider)
                if matches!(provider, "openai" | "openrouter" | "xai") && !has_tool_call_id =>
            {
                Err(format!("{} tool messages must have tool_call_id", provider))
            }
            (MessageRole::Tool, "gemini") if !has_tool_call_id => {
                Err("Gemini tool messages need tool_call_id for function mapping".to_string())
            }
            _ => Ok(()),
        }
    }
}

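/// A tool (function) exposed to the model.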
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool type; only "function" is currently supported.
    #[serde(rename = "type")]
    pub tool_type: String,

    pub function: FunctionDefinition,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,

    pub description: String,

    /// JSON object (typically a JSON Schema) describing the function's parameters.
    pub parameters: Value,
}

impl ToolDefinition {
    pub fn function(name: String, description: String, parameters: Value) -> Self {
        Self {
            tool_type: "function".to_string(),
            function: FunctionDefinition {
                name,
                description,
                parameters,
            },
        }
    }

    pub fn function_name(&self) -> &str {
        &self.function.name
    }

    pub fn validate(&self) -> Result<(), String> {
        if self.tool_type != "function" {
            return Err(format!(
                "Only 'function' type is supported, got: {}",
                self.tool_type
            ));
        }

        if self.function.name.is_empty() {
            return Err("Function name cannot be empty".to_string());
        }

        if self.function.description.is_empty() {
            return Err("Function description cannot be empty".to_string());
        }

        if !self.function.parameters.is_object() {
            return Err("Function parameters must be a JSON object".to_string());
        }

        Ok(())
    }
}

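/// A tool invocation requested by the model.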
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier linking this call to its tool response.
    pub id: String,

    /// Call type; only "function" is currently supported.
    #[serde(rename = "type")]
    pub call_type: String,

    pub function: FunctionCall,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,

    /// Arguments as a JSON-encoded string.
    pub arguments: String,
}

impl ToolCall {
    pub fn function(id: String, name: String, arguments: String) -> Self {
        Self {
            id,
            call_type: "function".to_string(),
            function: FunctionCall { name, arguments },
        }
    }

    pub fn parsed_arguments(&self) -> Result<Value, serde_json::Error> {
        serde_json::from_str(&self.function.arguments)
    }

    pub fn validate(&self) -> Result<(), String> {
        if self.call_type != "function" {
            return Err(format!(
                "Only 'function' type is supported, got: {}",
                self.call_type
            ));
        }

        if self.id.is_empty() {
            return Err("Tool call ID cannot be empty".to_string());
        }

        if self.function.name.is_empty() {
            return Err("Function name cannot be empty".to_string());
        }

        if let Err(e) = self.parsed_arguments() {
            return Err(format!("Invalid JSON in function arguments: {}", e));
        }

        Ok(())
    }
}

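/// A provider's response to an `LLMRequest`.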
#[derive(Debug, Clone)]
pub struct LLMResponse {
    pub content: Option<String>,
    pub tool_calls: Option<Vec<ToolCall>>,
    pub usage: Option<Usage>,
    pub finish_reason: FinishReason,
    pub reasoning: Option<String>,
}

#[derive(Debug, Clone)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    pub cached_prompt_tokens: Option<u32>,
    pub cache_creation_tokens: Option<u32>,
    pub cache_read_tokens: Option<u32>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    Stop,
    Length,
    ToolCalls,
    ContentFilter,
    Error(String),
}

#[derive(Debug, Clone)]
pub enum LLMStreamEvent {
    /// An incremental chunk of assistant content.
    Token { delta: String },
    /// An incremental chunk of reasoning output.
    Reasoning { delta: String },
    /// The final, fully assembled response.
    Completed { response: LLMResponse },
}

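/// A pinned, boxed stream of streaming events produced by a provider.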
pub type LLMStream = Pin<Box<dyn futures::Stream<Item = Result<LLMStreamEvent, LLMError>> + Send>>;

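/// Common interface implemented by each LLM provider backend.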
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Short identifier for this provider (e.g. "openai", "anthropic", "gemini").
    fn name(&self) -> &str;

    /// Whether this provider implements true token streaming.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Whether the given model emits reasoning output.
    fn supports_reasoning(&self, _model: &str) -> bool {
        false
    }

    /// Whether the given model accepts a reasoning-effort hint.
    fn supports_reasoning_effort(&self, _model: &str) -> bool {
        false
    }

    /// Perform a single, non-streaming completion.
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError>;

    /// Stream a completion. The default implementation falls back to `generate`
    /// and yields the full response as a single `Completed` event.
    async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
        let response = self.generate(request).await?;
        let stream = try_stream! {
            yield LLMStreamEvent::Completed { response };
        };
        Ok(Box::pin(stream))
    }

    /// Models this provider can serve.
    fn supported_models(&self) -> Vec<String>;

    /// Validate a request before sending it to the provider.
    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError>;
}

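/// Errors surfaced by provider implementations.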
#[derive(Debug, thiserror::Error)]
pub enum LLMError {
    #[error("Authentication failed: {0}")]
    Authentication(String),
    #[error("Rate limit exceeded")]
    RateLimit,
    #[error("Invalid request: {0}")]
    InvalidRequest(String),
    #[error("Network error: {0}")]
    Network(String),
    #[error("Provider error: {0}")]
    Provider(String),
}

impl From<LLMError> for crate::llm::types::LLMError {
    fn from(err: LLMError) -> crate::llm::types::LLMError {
        match err {
            LLMError::Authentication(msg) => crate::llm::types::LLMError::ApiError(msg),
            LLMError::RateLimit => crate::llm::types::LLMError::RateLimit,
            LLMError::InvalidRequest(msg) => crate::llm::types::LLMError::InvalidRequest(msg),
            LLMError::Network(msg) => crate::llm::types::LLMError::NetworkError(msg),
            LLMError::Provider(msg) => crate::llm::types::LLMError::ApiError(msg),
        }
    }
}
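
// Illustrative tests sketching how the types above are expected to behave; they
// only exercise behaviour defined in this module, and the tool names and call
// IDs used here ("read_file", "call_1") are placeholder values.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn tool_choice_maps_to_openai_strings() {
        assert_eq!(ToolChoice::auto().to_provider_format("openai"), json!("auto"));
        assert_eq!(ToolChoice::any().to_provider_format("openai"), json!("required"));
    }

    #[test]
    fn specific_tool_choice_maps_to_anthropic_object() {
        let choice = ToolChoice::function("read_file".to_string());
        assert_eq!(
            choice.to_provider_format("anthropic"),
            json!({"type": "tool", "name": "read_file"})
        );
    }

    #[test]
    fn tool_call_with_invalid_json_arguments_fails_validation() {
        let call = ToolCall::function(
            "call_1".to_string(),
            "read_file".to_string(),
            "{not json".to_string(),
        );
        assert!(call.validate().is_err());
    }
}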