1use async_trait::async_trait;
50use serde::{Deserialize, Serialize};
51use serde_json::{Value, json};
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct LLMRequest {
56 pub messages: Vec<Message>,
57 pub system_prompt: Option<String>,
58 pub tools: Option<Vec<ToolDefinition>>,
59 pub model: String,
60 pub max_tokens: Option<u32>,
61 pub temperature: Option<f32>,
62 pub stream: bool,
63
64 pub tool_choice: Option<ToolChoice>,
67
68 pub parallel_tool_calls: Option<bool>,
70
71 pub parallel_tool_config: Option<ParallelToolConfig>,
73
74 pub reasoning_effort: Option<String>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
83#[serde(untagged)]
84pub enum ToolChoice {
85 Auto,
88
89 None,
92
93 Any,
96
97 Specific(SpecificToolChoice),
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct SpecificToolChoice {
105 #[serde(rename = "type")]
106 pub tool_type: String, pub function: SpecificFunctionChoice,
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct SpecificFunctionChoice {
114 pub name: String,
115}
116
117impl ToolChoice {
118 pub fn auto() -> Self {
120 Self::Auto
121 }
122
123 pub fn none() -> Self {
125 Self::None
126 }
127
128 pub fn any() -> Self {
130 Self::Any
131 }
132
133 pub fn function(name: String) -> Self {
135 Self::Specific(SpecificToolChoice {
136 tool_type: "function".to_string(),
137 function: SpecificFunctionChoice { name },
138 })
139 }
140
141 pub fn allows_parallel_tools(&self) -> bool {
144 match self {
145 Self::Auto => true,
147 Self::Any => true,
149 Self::Specific(_) => false,
151 Self::None => false,
153 }
154 }
155
156 pub fn description(&self) -> &'static str {
158 match self {
159 Self::Auto => "Model decides when to use tools (allows parallel)",
160 Self::None => "No tools will be used",
161 Self::Any => "At least one tool must be used (allows parallel)",
162 Self::Specific(_) => "Specific tool must be used (no parallel)",
163 }
164 }
165
166 pub fn to_provider_format(&self, provider: &str) -> Value {
168 match (self, provider) {
169 (Self::Auto, "openai") => json!("auto"),
170 (Self::None, "openai") => json!("none"),
171 (Self::Any, "openai") => json!("required"), (Self::Specific(choice), "openai") => json!(choice),
173
174 (Self::Auto, "anthropic") => json!({"type": "auto"}),
175 (Self::None, "anthropic") => json!({"type": "none"}),
176 (Self::Any, "anthropic") => json!({"type": "any"}),
177 (Self::Specific(choice), "anthropic") => {
178 json!({"type": "tool", "name": choice.function.name})
179 }
180
181 (Self::Auto, "gemini") => json!({"mode": "auto"}),
182 (Self::None, "gemini") => json!({"mode": "none"}),
183 (Self::Any, "gemini") => json!({"mode": "any"}),
184 (Self::Specific(choice), "gemini") => {
185 json!({"mode": "any", "allowed_function_names": [choice.function.name]})
186 }
187
188 _ => match self {
190 Self::Auto => json!("auto"),
191 Self::None => json!("none"),
192 Self::Any => json!("required"),
193 Self::Specific(choice) => json!(choice),
194 },
195 }
196 }
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct ParallelToolConfig {
203 pub disable_parallel_tool_use: bool,
206
207 pub max_parallel_tools: Option<usize>,
210
211 pub encourage_parallel: bool,
213}
214
215impl Default for ParallelToolConfig {
216 fn default() -> Self {
217 Self {
218 disable_parallel_tool_use: false,
219 max_parallel_tools: Some(5), encourage_parallel: true,
221 }
222 }
223}
224
225impl ParallelToolConfig {
226 pub fn anthropic_optimized() -> Self {
228 Self {
229 disable_parallel_tool_use: false,
230 max_parallel_tools: None, encourage_parallel: true,
232 }
233 }
234
235 pub fn sequential_only() -> Self {
237 Self {
238 disable_parallel_tool_use: true,
239 max_parallel_tools: Some(1),
240 encourage_parallel: false,
241 }
242 }
243}
244
245#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
247pub struct Message {
248 pub role: MessageRole,
249 pub content: String,
250 pub tool_calls: Option<Vec<ToolCall>>,
251 pub tool_call_id: Option<String>,
252}
253
254impl Message {
255 pub fn user(content: String) -> Self {
257 Self {
258 role: MessageRole::User,
259 content,
260 tool_calls: None,
261 tool_call_id: None,
262 }
263 }
264
265 pub fn assistant(content: String) -> Self {
267 Self {
268 role: MessageRole::Assistant,
269 content,
270 tool_calls: None,
271 tool_call_id: None,
272 }
273 }
274
275 pub fn assistant_with_tools(content: String, tool_calls: Vec<ToolCall>) -> Self {
278 Self {
279 role: MessageRole::Assistant,
280 content,
281 tool_calls: Some(tool_calls),
282 tool_call_id: None,
283 }
284 }
285
286 pub fn system(content: String) -> Self {
288 Self {
289 role: MessageRole::System,
290 content,
291 tool_calls: None,
292 tool_call_id: None,
293 }
294 }
295
296 pub fn tool_response(tool_call_id: String, content: String) -> Self {
306 Self {
307 role: MessageRole::Tool,
308 content,
309 tool_calls: None,
310 tool_call_id: Some(tool_call_id),
311 }
312 }
313
314 pub fn tool_response_with_name(
317 tool_call_id: String,
318 _function_name: String,
319 content: String,
320 ) -> Self {
321 Self::tool_response(tool_call_id, content)
323 }
324
325 pub fn validate_for_provider(&self, provider: &str) -> Result<(), String> {
328 self.role
330 .validate_for_provider(provider, self.tool_call_id.is_some())?;
331
332 if let Some(tool_calls) = &self.tool_calls {
334 if !self.role.can_make_tool_calls() {
335 return Err(format!("Role {:?} cannot make tool calls", self.role));
336 }
337
338 if tool_calls.is_empty() {
339 return Err("Tool calls array should not be empty".to_string());
340 }
341
342 for tool_call in tool_calls {
344 tool_call.validate()?;
345 }
346 }
347
348 match provider {
350 "openai" => {
351 if self.role == MessageRole::Tool && self.tool_call_id.is_none() {
352 return Err(format!(
353 "{} requires tool_call_id for tool messages",
354 provider
355 ));
356 }
357 }
358 "gemini" => {
359 if self.role == MessageRole::Tool && self.tool_call_id.is_none() {
360 return Err(
361 "Gemini tool responses need tool_call_id for function name mapping"
362 .to_string(),
363 );
364 }
365 if self.role == MessageRole::System && !self.content.is_empty() {
367 }
369 }
370 "anthropic" => {
371 }
374 _ => {} }
376
377 Ok(())
378 }
379
380 pub fn has_tool_calls(&self) -> bool {
382 self.tool_calls
383 .as_ref()
384 .map_or(false, |calls| !calls.is_empty())
385 }
386
387 pub fn get_tool_calls(&self) -> Option<&[ToolCall]> {
389 self.tool_calls.as_deref()
390 }
391
392 pub fn is_tool_response(&self) -> bool {
394 self.role == MessageRole::Tool
395 }
396}
397
398#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
399pub enum MessageRole {
400 System,
401 User,
402 Assistant,
403 Tool,
404}
405
406impl MessageRole {
407 pub fn as_gemini_str(&self) -> &'static str {
413 match self {
414 MessageRole::System => "system", MessageRole::User => "user",
416 MessageRole::Assistant => "model", MessageRole::Tool => "user", }
419 }
420
421 pub fn as_openai_str(&self) -> &'static str {
426 match self {
427 MessageRole::System => "system",
428 MessageRole::User => "user",
429 MessageRole::Assistant => "assistant",
430 MessageRole::Tool => "tool", }
432 }
433
434 pub fn as_anthropic_str(&self) -> &'static str {
440 match self {
441 MessageRole::System => "system", MessageRole::User => "user",
443 MessageRole::Assistant => "assistant",
444 MessageRole::Tool => "user", }
446 }
447
448 pub fn as_generic_str(&self) -> &'static str {
451 match self {
452 MessageRole::System => "system",
453 MessageRole::User => "user",
454 MessageRole::Assistant => "assistant",
455 MessageRole::Tool => "tool",
456 }
457 }
458
459 pub fn can_make_tool_calls(&self) -> bool {
462 matches!(self, MessageRole::Assistant)
463 }
464
465 pub fn is_tool_response(&self) -> bool {
467 matches!(self, MessageRole::Tool)
468 }
469
470 pub fn validate_for_provider(
473 &self,
474 provider: &str,
475 has_tool_call_id: bool,
476 ) -> Result<(), String> {
477 match (self, provider) {
478 (MessageRole::Tool, "openai") if !has_tool_call_id => {
479 Err("OpenAI tool messages must have tool_call_id".to_string())
480 }
481 (MessageRole::Tool, "gemini") if !has_tool_call_id => {
482 Err("Gemini tool messages need tool_call_id for function mapping".to_string())
483 }
484 _ => Ok(()),
485 }
486 }
487}
488
489#[derive(Debug, Clone, Serialize, Deserialize)]
492pub struct ToolDefinition {
493 #[serde(rename = "type")]
495 pub tool_type: String,
496
497 pub function: FunctionDefinition,
499}
500
501#[derive(Debug, Clone, Serialize, Deserialize)]
503pub struct FunctionDefinition {
504 pub name: String,
506
507 pub description: String,
509
510 pub parameters: Value,
512}
513
514impl ToolDefinition {
515 pub fn function(name: String, description: String, parameters: Value) -> Self {
517 Self {
518 tool_type: "function".to_string(),
519 function: FunctionDefinition {
520 name,
521 description,
522 parameters,
523 },
524 }
525 }
526
527 pub fn function_name(&self) -> &str {
529 &self.function.name
530 }
531
532 pub fn validate(&self) -> Result<(), String> {
534 if self.tool_type != "function" {
535 return Err(format!(
536 "Only 'function' type is supported, got: {}",
537 self.tool_type
538 ));
539 }
540
541 if self.function.name.is_empty() {
542 return Err("Function name cannot be empty".to_string());
543 }
544
545 if self.function.description.is_empty() {
546 return Err("Function description cannot be empty".to_string());
547 }
548
549 if !self.function.parameters.is_object() {
551 return Err("Function parameters must be a JSON object".to_string());
552 }
553
554 Ok(())
555 }
556}
557
558#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
561pub struct ToolCall {
562 pub id: String,
564
565 #[serde(rename = "type")]
567 pub call_type: String,
568
569 pub function: FunctionCall,
571}
572
573#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
575pub struct FunctionCall {
576 pub name: String,
578
579 pub arguments: String,
581}
582
583impl ToolCall {
584 pub fn function(id: String, name: String, arguments: String) -> Self {
586 Self {
587 id,
588 call_type: "function".to_string(),
589 function: FunctionCall { name, arguments },
590 }
591 }
592
593 pub fn parsed_arguments(&self) -> Result<Value, serde_json::Error> {
595 serde_json::from_str(&self.function.arguments)
596 }
597
598 pub fn validate(&self) -> Result<(), String> {
600 if self.call_type != "function" {
601 return Err(format!(
602 "Only 'function' type is supported, got: {}",
603 self.call_type
604 ));
605 }
606
607 if self.id.is_empty() {
608 return Err("Tool call ID cannot be empty".to_string());
609 }
610
611 if self.function.name.is_empty() {
612 return Err("Function name cannot be empty".to_string());
613 }
614
615 if let Err(e) = self.parsed_arguments() {
617 return Err(format!("Invalid JSON in function arguments: {}", e));
618 }
619
620 Ok(())
621 }
622}
623
624#[derive(Debug, Clone)]
626pub struct LLMResponse {
627 pub content: Option<String>,
628 pub tool_calls: Option<Vec<ToolCall>>,
629 pub usage: Option<Usage>,
630 pub finish_reason: FinishReason,
631}
632
/// Token counts for one request.
#[derive(Debug, Clone)]
pub struct Usage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total as reported; this type does not enforce that it equals the sum.
    pub total_tokens: u32,
}
638}
639
/// Why a generation ended.
#[derive(Debug, Clone)]
pub enum FinishReason {
    /// Normal stop.
    Stop,
    /// Length limit reached.
    Length,
    /// Stopped to issue tool calls.
    ToolCalls,
    /// Stopped by content filtering.
    ContentFilter,
    /// Any other termination, with a message.
    Error(String),
}
648
649#[async_trait]
651pub trait LLMProvider: Send + Sync {
652 fn name(&self) -> &str;
654
655 async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError>;
657
658 async fn stream(
660 &self,
661 request: LLMRequest,
662 ) -> Result<Box<dyn futures::Stream<Item = LLMResponse> + Unpin + Send>, LLMError> {
663 let response = self.generate(request).await?;
665 Ok(Box::new(futures::stream::once(async { response }).boxed()))
666 }
667
668 fn supported_models(&self) -> Vec<String>;
670
671 fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError>;
673}
674
675#[derive(Debug, thiserror::Error)]
676pub enum LLMError {
677 #[error("Authentication failed: {0}")]
678 Authentication(String),
679 #[error("Rate limit exceeded")]
680 RateLimit,
681 #[error("Invalid request: {0}")]
682 InvalidRequest(String),
683 #[error("Network error: {0}")]
684 Network(String),
685 #[error("Provider error: {0}")]
686 Provider(String),
687}
688
689impl From<LLMError> for crate::llm::types::LLMError {
691 fn from(err: LLMError) -> crate::llm::types::LLMError {
692 match err {
693 LLMError::Authentication(msg) => crate::llm::types::LLMError::ApiError(msg),
694 LLMError::RateLimit => crate::llm::types::LLMError::RateLimit,
695 LLMError::InvalidRequest(msg) => crate::llm::types::LLMError::InvalidRequest(msg),
696 LLMError::Network(msg) => crate::llm::types::LLMError::NetworkError(msg),
697 LLMError::Provider(msg) => crate::llm::types::LLMError::ApiError(msg),
698 }
699 }
700}
701
702use futures::StreamExt;