1use anyhow::{anyhow, Result};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use tracing::{info, warn};
15
16use crate::streaming::{
17 CallbackStream, ContentDelta, MessageStream, OnChunkCallback, StreamEvent, StreamProvider,
18};
19
/// Backend selector for [`LlmClient`].
///
/// `LlmClient::new` derives the default base URL from the variant, and the
/// `LlmClientTrait` implementation dispatches on it to pick the request format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LlmProvider {
    /// Anthropic Messages API (default base URL `https://api.anthropic.com/v1`).
    Anthropic,
    /// OpenAI chat completions (default base URL `https://api.openai.com/v1`).
    OpenAI,
    /// Google Gemini (default base URL `https://generativelanguage.googleapis.com/v1`).
    Gemini,
    /// Azure OpenAI. The default base URL is a `YOUR_RESOURCE` placeholder, so
    /// callers are expected to override it via `LlmClient::with_base_url`.
    AzureOpenAI,
    /// AWS Bedrock runtime (default base URL targets `us-east-1`).
    Bedrock,
    /// Ollama (default base URL `http://localhost:11434`).
    Ollama,
    /// Any server handled through the same code path as `OpenAI`.
    OpenAICompatible {
        /// Base URL used verbatim as the client's base URL.
        base_url: String,
    },
    /// Carries a URL that becomes the client's base URL; `send` and
    /// `send_stream` return an error for this variant.
    Custom(String),
}
35
/// Per-request generation settings shared by all providers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequestConfig {
    /// Model identifier. Azure uses it as the deployment name and Bedrock as
    /// the model id inside the request URL.
    pub model: String,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature forwarded to the provider.
    pub temperature: f32,
    /// Optional system prompt, sent in whatever slot the provider request
    /// format offers (top-level field or a leading `system` message).
    pub system_prompt: Option<String>,
    /// Stop sequences. An empty list is omitted from requests. Note that the
    /// Anthropic request structs and the streaming OpenAI/Azure request struct
    /// in this file have no field for them, so they are not sent on those paths.
    pub stop_sequences: Vec<String>,
}
50
51impl Default for LlmRequestConfig {
52 fn default() -> Self {
53 Self {
54 model: "claude-sonnet-4-6".to_string(),
55 max_tokens: 4096,
56 temperature: 0.7,
57 system_prompt: None,
58 stop_sequences: vec!["\n\n\n".to_string()],
59 }
60 }
61}
62
/// Provider-independent result of a completed request.
#[derive(Debug, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Generated text.
    pub content: String,
    /// Token counts reported by the provider (0 where none were reported).
    pub usage: TokenUsage,
    /// Model name, taken from the provider response when it carries one and
    /// from the request configuration otherwise.
    pub model: String,
    /// Provider-assigned response id; empty for providers whose response
    /// handling in this file does not extract one (Gemini, Ollama).
    pub response_id: String,
}
75
/// Token accounting for a single request.
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt.
    pub input_tokens: u32,
    /// Tokens produced in the completion.
    pub output_tokens: u32,
}
84
/// Provider-agnostic interface for sending a conversation to an LLM.
#[async_trait]
pub trait LlmClientTrait {
    /// Sends `messages` with the given `config` and resolves to the complete
    /// response.
    async fn send(&self, messages: Vec<Message>, config: &LlmRequestConfig) -> Result<LlmResponse>;

    /// Sends `messages` with the given `config` and resolves to a
    /// [`MessageStream`] from which the response events are read.
    async fn send_stream(
        &self,
        messages: Vec<Message>,
        config: &LlmRequestConfig,
    ) -> Result<MessageStream>;
}
98
/// One turn of a conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored this turn.
    pub role: MessageRole,
    /// Plain-text body of the turn.
    pub content: String,
}
105
/// Author of a [`Message`]. Each provider method maps these to its own role
/// strings (for example, Gemini maps `Assistant` to `"model"` and `System` to
/// `"user"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageRole {
    /// End-user input.
    User,
    /// Model output from an earlier turn.
    Assistant,
    /// System-level instruction supplied inline in the message list.
    System,
}
113
/// HTTP client that talks to one configured LLM provider.
pub struct LlmClient {
    /// Underlying `reqwest` client used for every request.
    client: Client,
    /// Credential placed in the provider-specific header or query parameter.
    api_key: String,
    /// Selects which request format and endpoint layout is used.
    provider: LlmProvider,
    /// Root URL that endpoint paths are appended to.
    base_url: String,
}
125
126impl LlmClient {
127 pub fn new(provider: LlmProvider, api_key: String) -> Self {
128 let base_url = match &provider {
129 LlmProvider::Anthropic => "https://api.anthropic.com/v1".to_string(),
130 LlmProvider::OpenAI => "https://api.openai.com/v1".to_string(),
131 LlmProvider::Gemini => "https://generativelanguage.googleapis.com/v1".to_string(),
132 LlmProvider::AzureOpenAI => "https://YOUR_RESOURCE.openai.azure.com".to_string(),
133 LlmProvider::Bedrock => "https://bedrock-runtime.us-east-1.amazonaws.com".to_string(),
134 LlmProvider::Ollama => "http://localhost:11434".to_string(),
135 LlmProvider::OpenAICompatible { base_url } => base_url.clone(),
136 LlmProvider::Custom(url) => url.clone(),
137 };
138
139 Self {
140 client: Client::new(),
141 api_key,
142 provider,
143 base_url,
144 }
145 }
146
147 pub fn with_base_url(mut self, base_url: String) -> Self {
149 self.base_url = base_url;
150 self
151 }
152
153 pub async fn send_stream_with_callback(
155 &self,
156 messages: Vec<Message>,
157 config: &LlmRequestConfig,
158 on_chunk: OnChunkCallback,
159 ) -> Result<LlmResponse> {
160 let message_stream = self.send_stream(messages, config).await?;
161 let mut callback_stream = CallbackStream::new(message_stream, Some(on_chunk));
162
163 let mut content = String::new();
164 let mut input_tokens = 0u32;
165 let mut output_tokens = 0u32;
166 let mut message_id = String::new();
167 let mut model = config.model.clone();
168
169 while let Some(event) = callback_stream.next_event().await? {
170 match event {
171 StreamEvent::MessageStart { id, model: m } => {
172 message_id = id;
173 model = m;
174 }
175 StreamEvent::ContentBlockDelta {
176 delta: ContentDelta::Text(t),
177 ..
178 } => {
179 content.push_str(&t);
180 }
181 StreamEvent::ContentBlockDelta { .. } => {}
182 StreamEvent::MessageDelta { usage, .. } => {
183 input_tokens = usage.input_tokens;
184 output_tokens = usage.output_tokens;
185 }
186 _ => {}
187 }
188 }
189
190 Ok(LlmResponse {
191 content,
192 usage: TokenUsage {
193 input_tokens,
194 output_tokens,
195 },
196 model,
197 response_id: message_id,
198 })
199 }
200
201 pub async fn send_stream_abortable(
203 &self,
204 messages: Vec<Message>,
205 config: &LlmRequestConfig,
206 abort_flag: Arc<AtomicBool>,
207 ) -> Result<LlmResponse> {
208 let message_stream = self.send_stream(messages, config).await?;
209 let mut callback_stream = CallbackStream::new(message_stream, None);
210
211 let mut content = String::new();
212 let mut input_tokens = 0u32;
213 let mut output_tokens = 0u32;
214 let mut message_id = String::new();
215 let mut model = config.model.clone();
216
217 while !abort_flag.load(Ordering::Relaxed) {
218 match callback_stream.next_event().await {
219 Ok(Some(event)) => match event {
220 StreamEvent::MessageStart { id, model: m } => {
221 message_id = id;
222 model = m;
223 }
224 StreamEvent::ContentBlockDelta {
225 delta: ContentDelta::Text(t),
226 ..
227 } => {
228 content.push_str(&t);
229 }
230 StreamEvent::ContentBlockDelta { .. } => {}
231 StreamEvent::MessageDelta { usage, .. } => {
232 input_tokens = usage.input_tokens;
233 output_tokens = usage.output_tokens;
234 }
235 StreamEvent::MessageStop => {
236 break;
237 }
238 _ => {}
239 },
240 Ok(None) => break,
241 Err(e) => {
242 if abort_flag.load(Ordering::Relaxed) {
243 info!("Stream aborted by user");
244 break;
245 }
246 return Err(e);
247 }
248 }
249 }
250
251 if abort_flag.load(Ordering::Relaxed) {
252 info!("Stream was aborted");
253 }
254
255 Ok(LlmResponse {
256 content,
257 usage: TokenUsage {
258 input_tokens,
259 output_tokens,
260 },
261 model,
262 response_id: message_id,
263 })
264 }
265
266 pub async fn send_with_retry(
268 &self,
269 messages: Vec<Message>,
270 config: &LlmRequestConfig,
271 max_retries: u32,
272 ) -> Result<LlmResponse> {
273 let mut attempts = 0;
274 let mut last_error: Option<anyhow::Error> = None;
275
276 while attempts < max_retries {
277 attempts += 1;
278
279 match self.send(messages.clone(), config).await {
280 Ok(response) => {
281 info!("LLM request succeeded after {} attempts", attempts);
282 return Ok(response);
283 }
284 Err(e) => {
285 let error_msg = e.to_string();
286
287 if error_msg.contains("rate limit")
288 || error_msg.contains("429")
289 || error_msg.contains("overloaded")
290 || error_msg.contains("timeout")
291 {
292 warn!(
293 "LLM request failed (attempt {}/{}): {}",
294 attempts, max_retries, e
295 );
296 last_error = Some(e);
297
298 let delay = std::cmp::min(1000 * 2u64.pow(attempts - 1), 30000);
299 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
300 } else {
301 return Err(e);
302 }
303 }
304 }
305 }
306
307 Err(last_error.unwrap_or_else(|| anyhow!("Max retries exceeded")))
308 }
309
310 pub async fn send_stream_with_retry(
312 &self,
313 messages: Vec<Message>,
314 config: &LlmRequestConfig,
315 max_retries: u32,
316 ) -> Result<LlmResponse> {
317 let mut attempts = 0;
318 let mut last_error: Option<anyhow::Error> = None;
319
320 while attempts < max_retries {
321 attempts += 1;
322
323 match self
324 .send_stream_with_callback(messages.clone(), config, Box::new(|_| {}))
325 .await
326 {
327 Ok(response) => {
328 info!("Stream request succeeded after {} attempts", attempts);
329 return Ok(response);
330 }
331 Err(e) => {
332 let error_msg = e.to_string();
333
334 if error_msg.contains("rate limit")
335 || error_msg.contains("429")
336 || error_msg.contains("overloaded")
337 || error_msg.contains("timeout")
338 || error_msg.contains("aborted")
339 {
340 warn!(
341 "Stream request failed (attempt {}/{}): {}",
342 attempts, max_retries, e
343 );
344 last_error = Some(e);
345
346 let delay = std::cmp::min(1000 * 2u64.pow(attempts - 1), 30000);
347 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
348 } else {
349 return Err(e);
350 }
351 }
352 }
353 }
354
355 Err(last_error.unwrap_or_else(|| anyhow!("Max retries exceeded")))
356 }
357}
358
359#[async_trait]
360impl LlmClientTrait for LlmClient {
361 async fn send(&self, messages: Vec<Message>, config: &LlmRequestConfig) -> Result<LlmResponse> {
362 match self.provider {
363 LlmProvider::Anthropic => self.send_anthropic(messages, config).await,
364 LlmProvider::OpenAI | LlmProvider::OpenAICompatible { .. } => {
365 self.send_openai(messages, config).await
366 }
367 LlmProvider::Gemini => self.send_gemini(messages, config).await,
368 LlmProvider::AzureOpenAI => self.send_azure_openai(messages, config).await,
369 LlmProvider::Bedrock => self.send_bedrock(messages, config).await,
370 LlmProvider::Ollama => self.send_ollama(messages, config).await,
371 LlmProvider::Custom(_) => {
372 Err(anyhow!("Custom provider requires custom implementation. Use an OpenAI-compatible provider instead."))
373 }
374 }
375 }
376
377 async fn send_stream(
378 &self,
379 messages: Vec<Message>,
380 config: &LlmRequestConfig,
381 ) -> Result<MessageStream> {
382 match self.provider {
383 LlmProvider::Anthropic => self.stream_anthropic(messages, config).await,
384 LlmProvider::OpenAI | LlmProvider::OpenAICompatible { .. } => {
385 self.stream_openai(messages, config).await
386 }
387 LlmProvider::Gemini => self.stream_gemini(messages, config).await,
388 LlmProvider::AzureOpenAI => self.stream_azure_openai(messages, config).await,
389 LlmProvider::Bedrock => self.stream_bedrock(messages, config).await,
390 LlmProvider::Ollama => self.stream_ollama(messages, config).await,
391 LlmProvider::Custom(_) => Err(anyhow!("Custom provider does not support streaming. Use an OpenAI-compatible provider instead.")),
392 }
393 }
394}
395
396impl LlmClient {
397 async fn send_anthropic(
398 &self,
399 messages: Vec<Message>,
400 config: &LlmRequestConfig,
401 ) -> Result<LlmResponse> {
402 let url = if self.base_url.ends_with("/v1") || self.base_url.ends_with("/v1/") {
403 format!("{}/messages", self.base_url.trim_end_matches('/'))
404 } else {
405 format!("{}/v1/messages", self.base_url.trim_end_matches('/'))
406 };
407
408 let request_body = AnthropicRequest {
409 model: config.model.clone(),
410 max_tokens: config.max_tokens,
411 messages: messages
412 .into_iter()
413 .map(|m| AnthropicMessage {
414 role: match m.role {
415 MessageRole::User => "user",
416 MessageRole::Assistant => "assistant",
417 MessageRole::System => "system",
418 },
419 content: AnthropicContent::Text(m.content),
420 })
421 .collect(),
422 system: config.system_prompt.clone(),
423 temperature: config.temperature,
424 };
425
426 let response = self
427 .client
428 .post(&url)
429 .header("x-api-key", &self.api_key)
430 .header("anthropic-version", "2023-06-01")
431 .json(&request_body)
432 .send()
433 .await?;
434
435 let response_text = response.text().await?;
436 tracing::debug!("Anthropic API response: {}", response_text);
437
438 let response_body: AnthropicResponse = serde_json::from_str(&response_text)?;
439
440 Ok(LlmResponse {
441 content: response_body
442 .content
443 .first()
444 .map(|c| c.text.clone())
445 .unwrap_or_default(),
446 usage: TokenUsage {
447 input_tokens: response_body.usage.input_tokens,
448 output_tokens: response_body.usage.output_tokens,
449 },
450 model: response_body.model,
451 response_id: response_body.id,
452 })
453 }
454
455 async fn stream_anthropic(
456 &self,
457 messages: Vec<Message>,
458 config: &LlmRequestConfig,
459 ) -> Result<MessageStream> {
460 let url = if self.base_url.ends_with("/v1") || self.base_url.ends_with("/v1/") {
461 format!("{}/messages", self.base_url.trim_end_matches('/'))
462 } else {
463 format!("{}/v1/messages", self.base_url.trim_end_matches('/'))
464 };
465
466 let request_body = AnthropicStreamRequest {
467 model: config.model.clone(),
468 max_tokens: config.max_tokens,
469 messages: messages
470 .into_iter()
471 .map(|m| AnthropicMessage {
472 role: match m.role {
473 MessageRole::User => "user",
474 MessageRole::Assistant => "assistant",
475 MessageRole::System => "system",
476 },
477 content: AnthropicContent::Text(m.content),
478 })
479 .collect(),
480 system: config.system_prompt.clone(),
481 temperature: config.temperature,
482 stream: true,
483 };
484
485 let response = self
486 .client
487 .post(&url)
488 .header("x-api-key", &self.api_key)
489 .header("anthropic-version", "2023-06-01")
490 .header("Accept", "text/event-stream")
491 .json(&request_body)
492 .send()
493 .await?;
494
495 let status = response.status();
496 if !status.is_success() {
497 let error_text = response.text().await?;
498 return Err(anyhow!("Anthropic API error {}: {}", status, error_text));
499 }
500
501 Ok(MessageStream::new(
502 response,
503 StreamProvider::Anthropic,
504 config.model.clone(),
505 ))
506 }
507
508 async fn send_openai(
509 &self,
510 messages: Vec<Message>,
511 config: &LlmRequestConfig,
512 ) -> Result<LlmResponse> {
513 let url = format!("{}/chat/completions", self.base_url);
514
515 let mut openai_messages: Vec<OpenAiMessage> = Vec::new();
516
517 if let Some(ref system) = config.system_prompt {
518 openai_messages.push(OpenAiMessage {
519 role: "system",
520 content: system.clone(),
521 });
522 }
523
524 for m in messages {
525 openai_messages.push(OpenAiMessage {
526 role: match m.role {
527 MessageRole::User => "user",
528 MessageRole::Assistant => "assistant",
529 MessageRole::System => "system",
530 },
531 content: m.content,
532 });
533 }
534
535 let request_body = OpenAiRequest {
536 model: config.model.clone(),
537 messages: openai_messages,
538 max_tokens: Some(config.max_tokens),
539 temperature: Some(config.temperature),
540 stop: if config.stop_sequences.is_empty() {
541 None
542 } else {
543 Some(config.stop_sequences.clone())
544 },
545 };
546
547 let response = self
548 .client
549 .post(&url)
550 .header("Authorization", format!("Bearer {}", self.api_key))
551 .json(&request_body)
552 .send()
553 .await?;
554
555 let response_body: OpenAiResponse = response.json().await?;
556
557 let choice = response_body
558 .choices
559 .first()
560 .ok_or_else(|| anyhow!("No response choices"))?;
561
562 Ok(LlmResponse {
563 content: choice.message.content.clone(),
564 usage: TokenUsage {
565 input_tokens: response_body.usage.prompt_tokens,
566 output_tokens: response_body.usage.completion_tokens,
567 },
568 model: response_body.model,
569 response_id: response_body.id,
570 })
571 }
572
573 async fn stream_openai(
574 &self,
575 messages: Vec<Message>,
576 config: &LlmRequestConfig,
577 ) -> Result<MessageStream> {
578 let url = format!("{}/chat/completions", self.base_url);
579
580 let mut openai_messages: Vec<OpenAiMessage> = Vec::new();
581 if let Some(ref system) = config.system_prompt {
582 openai_messages.push(OpenAiMessage {
583 role: "system",
584 content: system.clone(),
585 });
586 }
587 for m in messages {
588 openai_messages.push(OpenAiMessage {
589 role: match m.role {
590 MessageRole::User => "user",
591 MessageRole::Assistant => "assistant",
592 MessageRole::System => "system",
593 },
594 content: m.content,
595 });
596 }
597
598 let request_body = OpenAiStreamRequest {
599 model: config.model.clone(),
600 messages: openai_messages,
601 max_tokens: Some(config.max_tokens),
602 temperature: Some(config.temperature),
603 stream: true,
604 };
605
606 let response = self
607 .client
608 .post(&url)
609 .header("Authorization", format!("Bearer {}", self.api_key))
610 .header("Accept", "text/event-stream")
611 .json(&request_body)
612 .send()
613 .await?;
614
615 let status = response.status();
616 if !status.is_success() {
617 let error_text = response.text().await?;
618 return Err(anyhow!("OpenAI API error {}: {}", status, error_text));
619 }
620
621 Ok(MessageStream::new(
622 response,
623 StreamProvider::OpenAI,
624 config.model.clone(),
625 ))
626 }
627
628 async fn send_gemini(
629 &self,
630 messages: Vec<Message>,
631 config: &LlmRequestConfig,
632 ) -> Result<LlmResponse> {
633 let url = format!(
634 "{}/models/{}:generateContent?key={}",
635 self.base_url, config.model, self.api_key
636 );
637
638 let mut contents: Vec<GeminiContent> = Vec::new();
639 let system_instruction = config.system_prompt.clone();
640
641 for m in messages {
642 contents.push(GeminiContent {
643 role: match m.role {
644 MessageRole::User => "user".to_string(),
645 MessageRole::Assistant => "model".to_string(),
646 MessageRole::System => "user".to_string(),
647 },
648 parts: vec![GeminiPart { text: m.content }],
649 });
650 }
651
652 let request_body = GeminiRequest {
653 contents,
654 generation_config: Some(GeminiGenerationConfig {
655 max_output_tokens: Some(config.max_tokens),
656 temperature: Some(config.temperature),
657 stop_sequences: if config.stop_sequences.is_empty() {
658 None
659 } else {
660 Some(config.stop_sequences.clone())
661 },
662 }),
663 system_instruction: system_instruction.map(|s| GeminiSystemInstruction {
664 parts: vec![GeminiPart { text: s }],
665 }),
666 };
667
668 let response = self.client.post(&url).json(&request_body).send().await?;
669
670 let response_body: GeminiResponse = response.json().await?;
671
672 let candidate = response_body
673 .candidates
674 .first()
675 .ok_or_else(|| anyhow!("No response candidates"))?;
676
677 let content = candidate
678 .content
679 .parts
680 .first()
681 .map(|p| p.text.clone())
682 .unwrap_or_default();
683
684 Ok(LlmResponse {
685 content,
686 usage: TokenUsage {
687 input_tokens: response_body.usage_metadata.prompt_token_count.unwrap_or(0),
688 output_tokens: response_body
689 .usage_metadata
690 .candidates_token_count
691 .unwrap_or(0),
692 },
693 model: config.model.clone(),
694 response_id: "".to_string(),
695 })
696 }
697
698 async fn stream_gemini(
699 &self,
700 messages: Vec<Message>,
701 config: &LlmRequestConfig,
702 ) -> Result<MessageStream> {
703 let url = format!(
704 "{}/models/{}:streamGenerateContent?key={}&alt=sse",
705 self.base_url, config.model, self.api_key
706 );
707
708 let mut contents: Vec<GeminiContent> = Vec::new();
709 let system_instruction = config.system_prompt.clone();
710
711 for m in messages {
712 contents.push(GeminiContent {
713 role: match m.role {
714 MessageRole::User => "user".to_string(),
715 MessageRole::Assistant => "model".to_string(),
716 MessageRole::System => "user".to_string(),
717 },
718 parts: vec![GeminiPart { text: m.content }],
719 });
720 }
721
722 let request_body = GeminiRequest {
723 contents,
724 generation_config: Some(GeminiGenerationConfig {
725 max_output_tokens: Some(config.max_tokens),
726 temperature: Some(config.temperature),
727 stop_sequences: if config.stop_sequences.is_empty() {
728 None
729 } else {
730 Some(config.stop_sequences.clone())
731 },
732 }),
733 system_instruction: system_instruction.map(|s| GeminiSystemInstruction {
734 parts: vec![GeminiPart { text: s }],
735 }),
736 };
737
738 let response = self.client.post(&url).json(&request_body).send().await?;
739
740 let status = response.status();
741 if !status.is_success() {
742 let error_text = response.text().await?;
743 return Err(anyhow!("Gemini API error {}: {}", status, error_text));
744 }
745
746 Ok(MessageStream::new(
747 response,
748 StreamProvider::Gemini,
749 config.model.clone(),
750 ))
751 }
752
753 async fn send_azure_openai(
758 &self,
759 messages: Vec<Message>,
760 config: &LlmRequestConfig,
761 ) -> Result<LlmResponse> {
762 let deployment = &config.model;
765 let url = format!(
766 "{}/openai/deployments/{}/chat/completions?api-version=2024-02-15-preview",
767 self.base_url, deployment
768 );
769
770 let mut azure_messages: Vec<OpenAiMessage> = Vec::new();
771 if let Some(ref system) = config.system_prompt {
772 azure_messages.push(OpenAiMessage {
773 role: "system",
774 content: system.clone(),
775 });
776 }
777 for m in messages {
778 azure_messages.push(OpenAiMessage {
779 role: match m.role {
780 MessageRole::User => "user",
781 MessageRole::Assistant => "assistant",
782 MessageRole::System => "system",
783 },
784 content: m.content,
785 });
786 }
787
788 let request_body = OpenAiRequest {
789 model: deployment.clone(), messages: azure_messages,
791 max_tokens: Some(config.max_tokens),
792 temperature: Some(config.temperature),
793 stop: if config.stop_sequences.is_empty() {
794 None
795 } else {
796 Some(config.stop_sequences.clone())
797 },
798 };
799
800 let response = self
801 .client
802 .post(&url)
803 .header("api-key", &self.api_key) .json(&request_body)
805 .send()
806 .await?;
807
808 let status = response.status();
809 if !status.is_success() {
810 let error_text = response.text().await?;
811 return Err(anyhow!("Azure OpenAI API error {}: {}", status, error_text));
812 }
813
814 let response_body: OpenAiResponse = response.json().await?;
815
816 let choice = response_body
817 .choices
818 .first()
819 .ok_or_else(|| anyhow!("No response choices"))?;
820
821 Ok(LlmResponse {
822 content: choice.message.content.clone(),
823 usage: TokenUsage {
824 input_tokens: response_body.usage.prompt_tokens,
825 output_tokens: response_body.usage.completion_tokens,
826 },
827 model: response_body.model,
828 response_id: response_body.id,
829 })
830 }
831
832 async fn stream_azure_openai(
833 &self,
834 messages: Vec<Message>,
835 config: &LlmRequestConfig,
836 ) -> Result<MessageStream> {
837 let deployment = &config.model;
838 let url = format!(
839 "{}/openai/deployments/{}/chat/completions?api-version=2024-02-15-preview",
840 self.base_url, deployment
841 );
842
843 let mut azure_messages: Vec<OpenAiMessage> = Vec::new();
844 if let Some(ref system) = config.system_prompt {
845 azure_messages.push(OpenAiMessage {
846 role: "system",
847 content: system.clone(),
848 });
849 }
850 for m in messages {
851 azure_messages.push(OpenAiMessage {
852 role: match m.role {
853 MessageRole::User => "user",
854 MessageRole::Assistant => "assistant",
855 MessageRole::System => "system",
856 },
857 content: m.content,
858 });
859 }
860
861 let request_body = OpenAiStreamRequest {
862 model: deployment.clone(),
863 messages: azure_messages,
864 max_tokens: Some(config.max_tokens),
865 temperature: Some(config.temperature),
866 stream: true,
867 };
868
869 let response = self
870 .client
871 .post(&url)
872 .header("api-key", &self.api_key)
873 .header("Accept", "text/event-stream")
874 .json(&request_body)
875 .send()
876 .await?;
877
878 let status = response.status();
879 if !status.is_success() {
880 let error_text = response.text().await?;
881 return Err(anyhow!("Azure OpenAI API error {}: {}", status, error_text));
882 }
883
884 Ok(MessageStream::new(
885 response,
886 StreamProvider::AzureOpenAI,
887 config.model.clone(),
888 ))
889 }
890
891 async fn send_bedrock(
896 &self,
897 messages: Vec<Message>,
898 config: &LlmRequestConfig,
899 ) -> Result<LlmResponse> {
900 let model_id = &config.model;
903 let url = format!("{}/model/{}/invoke", self.base_url, model_id);
904
905 let mut bedrock_messages: Vec<BedrockMessage> = Vec::new();
907 for m in messages {
908 bedrock_messages.push(BedrockMessage {
909 role: match m.role {
910 MessageRole::User => "user",
911 MessageRole::Assistant => "assistant",
912 MessageRole::System => "system",
913 },
914 content: vec![BedrockContent { text: m.content }],
915 });
916 }
917
918 let request_body = BedrockRequest {
919 messages: bedrock_messages,
920 system: config.system_prompt.clone(),
921 inference_config: Some(BedrockInferenceConfig {
922 max_tokens: config.max_tokens,
923 temperature: config.temperature,
924 top_p: None,
925 stop_sequences: if config.stop_sequences.is_empty() {
926 None
927 } else {
928 Some(config.stop_sequences.clone())
929 },
930 }),
931 };
932
933 let response = self
934 .client
935 .post(&url)
936 .header("Authorization", format!("Bearer {}", self.api_key))
937 .header("Content-Type", "application/json")
938 .json(&request_body)
939 .send()
940 .await?;
941
942 let status = response.status();
943 if !status.is_success() {
944 let error_text = response.text().await?;
945 return Err(anyhow!("Bedrock API error {}: {}", status, error_text));
946 }
947
948 let response_body: BedrockResponse = response.json().await?;
949
950 let content = response_body
951 .output
952 .message
953 .content
954 .first()
955 .map(|c| c.text.clone())
956 .unwrap_or_default();
957
958 Ok(LlmResponse {
959 content,
960 usage: TokenUsage {
961 input_tokens: response_body.usage.input_tokens,
962 output_tokens: response_body.usage.output_tokens,
963 },
964 model: config.model.clone(),
965 response_id: response_body.request_id.unwrap_or_default(),
966 })
967 }
968
969 async fn stream_bedrock(
970 &self,
971 messages: Vec<Message>,
972 config: &LlmRequestConfig,
973 ) -> Result<MessageStream> {
974 let model_id = &config.model;
975 let url = format!(
976 "{}/model/{}/invoke-with-response-stream",
977 self.base_url, model_id
978 );
979
980 let mut bedrock_messages: Vec<BedrockMessage> = Vec::new();
981 for m in messages {
982 bedrock_messages.push(BedrockMessage {
983 role: match m.role {
984 MessageRole::User => "user",
985 MessageRole::Assistant => "assistant",
986 MessageRole::System => "system",
987 },
988 content: vec![BedrockContent { text: m.content }],
989 });
990 }
991
992 let request_body = BedrockRequest {
993 messages: bedrock_messages,
994 system: config.system_prompt.clone(),
995 inference_config: Some(BedrockInferenceConfig {
996 max_tokens: config.max_tokens,
997 temperature: config.temperature,
998 top_p: None,
999 stop_sequences: if config.stop_sequences.is_empty() {
1000 None
1001 } else {
1002 Some(config.stop_sequences.clone())
1003 },
1004 }),
1005 };
1006
1007 let response = self
1008 .client
1009 .post(&url)
1010 .header("Authorization", format!("Bearer {}", self.api_key))
1011 .header("Accept", "text/event-stream")
1012 .header("Content-Type", "application/json")
1013 .json(&request_body)
1014 .send()
1015 .await?;
1016
1017 let status = response.status();
1018 if !status.is_success() {
1019 let error_text = response.text().await?;
1020 return Err(anyhow!("Bedrock API error {}: {}", status, error_text));
1021 }
1022
1023 Ok(MessageStream::new(
1024 response,
1025 StreamProvider::Bedrock,
1026 config.model.clone(),
1027 ))
1028 }
1029
1030 async fn send_ollama(
1035 &self,
1036 messages: Vec<Message>,
1037 config: &LlmRequestConfig,
1038 ) -> Result<LlmResponse> {
1039 let url = format!("{}/api/chat", self.base_url);
1041
1042 let mut ollama_messages: Vec<OllamaMessage> = Vec::new();
1043 if let Some(ref system) = config.system_prompt {
1044 ollama_messages.push(OllamaMessage {
1045 role: "system",
1046 content: system.clone(),
1047 });
1048 }
1049 for m in messages {
1050 ollama_messages.push(OllamaMessage {
1051 role: match m.role {
1052 MessageRole::User => "user",
1053 MessageRole::Assistant => "assistant",
1054 MessageRole::System => "system",
1055 },
1056 content: m.content,
1057 });
1058 }
1059
1060 let request_body = OllamaChatRequest {
1061 model: config.model.clone(),
1062 messages: ollama_messages,
1063 stream: false,
1064 options: Some(OllamaOptions {
1065 num_predict: config.max_tokens as i32,
1066 temperature: config.temperature,
1067 stop: if config.stop_sequences.is_empty() {
1068 None
1069 } else {
1070 Some(config.stop_sequences.clone())
1071 },
1072 }),
1073 };
1074
1075 let response = self
1077 .client
1078 .post(&url)
1079 .header("Content-Type", "application/json")
1080 .json(&request_body)
1081 .send()
1082 .await?;
1083
1084 let status = response.status();
1085 if !status.is_success() {
1086 let error_text = response.text().await?;
1087 return Err(anyhow!("Ollama API error {}: {}", status, error_text));
1088 }
1089
1090 let response_body: OllamaChatResponse = response.json().await?;
1091
1092 Ok(LlmResponse {
1093 content: response_body.message.content,
1094 usage: TokenUsage {
1095 input_tokens: response_body.prompt_eval_count.unwrap_or(0),
1096 output_tokens: response_body.eval_count.unwrap_or(0),
1097 },
1098 model: response_body.model,
1099 response_id: "".to_string(),
1100 })
1101 }
1102
1103 async fn stream_ollama(
1104 &self,
1105 messages: Vec<Message>,
1106 config: &LlmRequestConfig,
1107 ) -> Result<MessageStream> {
1108 let url = format!("{}/api/chat", self.base_url);
1109
1110 let mut ollama_messages: Vec<OllamaMessage> = Vec::new();
1111 if let Some(ref system) = config.system_prompt {
1112 ollama_messages.push(OllamaMessage {
1113 role: "system",
1114 content: system.clone(),
1115 });
1116 }
1117 for m in messages {
1118 ollama_messages.push(OllamaMessage {
1119 role: match m.role {
1120 MessageRole::User => "user",
1121 MessageRole::Assistant => "assistant",
1122 MessageRole::System => "system",
1123 },
1124 content: m.content,
1125 });
1126 }
1127
1128 let request_body = OllamaChatRequest {
1129 model: config.model.clone(),
1130 messages: ollama_messages,
1131 stream: true,
1132 options: Some(OllamaOptions {
1133 num_predict: config.max_tokens as i32,
1134 temperature: config.temperature,
1135 stop: if config.stop_sequences.is_empty() {
1136 None
1137 } else {
1138 Some(config.stop_sequences.clone())
1139 },
1140 }),
1141 };
1142
1143 let response = self
1144 .client
1145 .post(&url)
1146 .header("Accept", "application/json")
1147 .header("Content-Type", "application/json")
1148 .json(&request_body)
1149 .send()
1150 .await?;
1151
1152 let status = response.status();
1153 if !status.is_success() {
1154 let error_text = response.text().await?;
1155 return Err(anyhow!("Ollama API error {}: {}", status, error_text));
1156 }
1157
1158 Ok(MessageStream::new(
1159 response,
1160 StreamProvider::Ollama,
1161 config.model.clone(),
1162 ))
1163 }
1164}
1165
1166#[derive(Serialize)]
1168struct AnthropicRequest {
1169 model: String,
1170 max_tokens: u32,
1171 messages: Vec<AnthropicMessage>,
1172 system: Option<String>,
1173 temperature: f32,
1174}
1175
1176#[derive(Serialize)]
1177struct AnthropicStreamRequest {
1178 model: String,
1179 max_tokens: u32,
1180 messages: Vec<AnthropicMessage>,
1181 system: Option<String>,
1182 temperature: f32,
1183 stream: bool,
1184}
1185
/// One conversation turn in an Anthropic request.
#[derive(Serialize)]
struct AnthropicMessage {
    /// Role string as written by the caller (`"user"`, `"assistant"` or
    /// `"system"` in this file).
    role: &'static str,
    /// Body of the turn.
    content: AnthropicContent,
}
1191
/// Message content for an Anthropic request. Serialized untagged, so `Text`
/// becomes a bare JSON string and `Blocks` a JSON array.
#[derive(Serialize)]
#[serde(untagged)]
#[allow(dead_code)]
enum AnthropicContent {
    /// Plain-text content.
    Text(String),
    /// Structured content blocks.
    Blocks(Vec<AnthropicContentBlock>),
}
1199
/// A single structured content block of an Anthropic request message.
#[derive(Serialize)]
struct AnthropicContentBlock {
    /// Block type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    content_type: String,
    /// Text payload of the block.
    text: String,
}
1206
/// JSON body of a non-streaming Anthropic response.
///
/// Every field carries `#[serde(default)]`, so any JSON object — including
/// one that has none of these keys — deserializes successfully with
/// empty/zero values.
#[derive(Deserialize)]
#[allow(dead_code)]
struct AnthropicResponse {
    /// Response id.
    #[serde(default)]
    id: String,
    /// Model that produced the response.
    #[serde(default)]
    model: String,
    /// Content blocks of the reply.
    #[serde(default)]
    content: Vec<AnthropicContentResponse>,
    /// Token counts.
    #[serde(default)]
    usage: AnthropicUsage,
    /// Value of the JSON key `type`.
    #[serde(default)]
    #[serde(rename = "type")]
    response_type: Option<String>,
    /// Role reported for the reply.
    #[serde(default)]
    role: Option<String>,
    /// Stop reason reported for the reply, if any.
    #[serde(default)]
    stop_reason: Option<String>,
}
1226
/// One content block of an Anthropic response; both fields default to an
/// empty string when absent.
#[derive(Deserialize)]
#[allow(dead_code)]
struct AnthropicContentResponse {
    /// Block type, read from the JSON key `type`.
    #[serde(rename = "type", default)]
    content_type: String,
    /// Text payload of the block.
    #[serde(default)]
    text: String,
}
1235
/// Token counts of an Anthropic response; each count defaults to 0 when absent.
#[derive(Deserialize, Default)]
struct AnthropicUsage {
    /// Tokens consumed by the prompt.
    #[serde(default)]
    input_tokens: u32,
    /// Tokens produced in the completion.
    #[serde(default)]
    output_tokens: u32,
}
1243
/// JSON body of a non-streaming chat-completions request (also used for Azure
/// OpenAI). Optional fields are left out of the JSON when `None`.
#[derive(Serialize)]
struct OpenAiRequest {
    /// Model identifier (the deployment name on the Azure path).
    model: String,
    /// Conversation turns, including any leading system message.
    messages: Vec<OpenAiMessage>,
    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
}
1256
/// JSON body of a streaming chat-completions request (also used for Azure
/// OpenAI). Unlike `OpenAiRequest` it has no `stop` field.
#[derive(Serialize)]
struct OpenAiStreamRequest {
    /// Model identifier (the deployment name on the Azure path).
    model: String,
    /// Conversation turns, including any leading system message.
    messages: Vec<OpenAiMessage>,
    /// Upper bound on generated tokens; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Sampling temperature; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Requests a streamed response when `true`.
    stream: bool,
}
1267
/// One conversation turn in a chat-completions request.
#[derive(Serialize)]
struct OpenAiMessage {
    /// Role string (`"system"`, `"user"` or `"assistant"` in this file).
    role: &'static str,
    /// Plain-text body of the turn.
    content: String,
}
1273
/// JSON body of a non-streaming chat-completions response. All four fields
/// are required: deserialization fails if any is missing.
#[derive(Deserialize)]
struct OpenAiResponse {
    /// Response id.
    id: String,
    /// Model that produced the response.
    model: String,
    /// Generated alternatives; callers in this file read the first one.
    choices: Vec<OpenAiChoice>,
    /// Token counts.
    usage: OpenAiUsage,
}
1281
1282#[derive(Deserialize)]
1283#[allow(dead_code)]
1284struct OpenAiChoice {
1285 message: OpenAiResponseMessage,
1286 finish_reason: String,
1287}
1288
1289#[derive(Deserialize)]
1290#[allow(dead_code)]
1291struct OpenAiResponseMessage {
1292 role: String,
1293 content: String,
1294}
1295
1296#[derive(Deserialize)]
1297#[allow(dead_code)]
1298struct OpenAiUsage {
1299 prompt_tokens: u32,
1300 completion_tokens: u32,
1301 total_tokens: u32,
1302}
1303
1304#[derive(Serialize)]
1306struct GeminiRequest {
1307 contents: Vec<GeminiContent>,
1308 #[serde(skip_serializing_if = "Option::is_none")]
1309 generation_config: Option<GeminiGenerationConfig>,
1310 #[serde(skip_serializing_if = "Option::is_none")]
1311 system_instruction: Option<GeminiSystemInstruction>,
1312}
1313
1314#[derive(Serialize)]
1315struct GeminiContent {
1316 role: String,
1317 parts: Vec<GeminiPart>,
1318}
1319
1320#[derive(Serialize)]
1321struct GeminiPart {
1322 text: String,
1323}
1324
1325#[derive(Serialize)]
1326struct GeminiGenerationConfig {
1327 #[serde(skip_serializing_if = "Option::is_none")]
1328 max_output_tokens: Option<u32>,
1329 #[serde(skip_serializing_if = "Option::is_none")]
1330 temperature: Option<f32>,
1331 #[serde(skip_serializing_if = "Option::is_none")]
1332 stop_sequences: Option<Vec<String>>,
1333}
1334
1335#[derive(Serialize)]
1336struct GeminiSystemInstruction {
1337 parts: Vec<GeminiPart>,
1338}
1339
1340#[derive(Deserialize)]
1341struct GeminiResponse {
1342 candidates: Vec<GeminiCandidate>,
1343 usage_metadata: GeminiUsageMetadata,
1344}
1345
1346#[derive(Deserialize)]
1347#[allow(dead_code)]
1348struct GeminiCandidate {
1349 content: GeminiContentResponse,
1350 finish_reason: String,
1351}
1352
1353#[derive(Deserialize)]
1354#[allow(dead_code)]
1355struct GeminiContentResponse {
1356 parts: Vec<GeminiPartResponse>,
1357 role: String,
1358}
1359
1360#[derive(Deserialize)]
1361struct GeminiPartResponse {
1362 text: String,
1363}
1364
1365#[derive(Deserialize)]
1366#[allow(dead_code)]
1367struct GeminiUsageMetadata {
1368 prompt_token_count: Option<u32>,
1369 candidates_token_count: Option<u32>,
1370 total_token_count: Option<u32>,
1371}
1372
1373#[derive(Serialize)]
1378struct BedrockRequest {
1379 messages: Vec<BedrockMessage>,
1380 #[serde(skip_serializing_if = "Option::is_none")]
1381 system: Option<String>,
1382 #[serde(skip_serializing_if = "Option::is_none")]
1383 inference_config: Option<BedrockInferenceConfig>,
1384}
1385
1386#[derive(Serialize)]
1387struct BedrockMessage {
1388 role: &'static str,
1389 content: Vec<BedrockContent>,
1390}
1391
1392#[derive(Serialize)]
1393struct BedrockContent {
1394 text: String,
1395}
1396
1397#[derive(Serialize)]
1398struct BedrockInferenceConfig {
1399 #[serde(rename = "maxTokens")]
1400 max_tokens: u32,
1401 temperature: f32,
1402 #[serde(skip_serializing_if = "Option::is_none")]
1403 top_p: Option<f32>,
1404 #[serde(skip_serializing_if = "Option::is_none")]
1405 stop_sequences: Option<Vec<String>>,
1406}
1407
1408#[derive(Deserialize)]
1409#[allow(dead_code)]
1410struct BedrockResponse {
1411 output: BedrockOutput,
1412 usage: BedrockUsage,
1413 #[serde(default)]
1414 request_id: Option<String>,
1415}
1416
1417#[derive(Deserialize)]
1418struct BedrockOutput {
1419 message: BedrockResponseMessage,
1420}
1421
1422#[derive(Deserialize)]
1423struct BedrockResponseMessage {
1424 content: Vec<BedrockResponseContent>,
1425}
1426
1427#[derive(Deserialize)]
1428struct BedrockResponseContent {
1429 text: String,
1430}
1431
1432#[derive(Deserialize)]
1433struct BedrockUsage {
1434 #[serde(default)]
1435 input_tokens: u32,
1436 #[serde(default)]
1437 output_tokens: u32,
1438}
1439
1440#[derive(Serialize)]
1445struct OllamaChatRequest {
1446 model: String,
1447 messages: Vec<OllamaMessage>,
1448 stream: bool,
1449 #[serde(skip_serializing_if = "Option::is_none")]
1450 options: Option<OllamaOptions>,
1451}
1452
1453#[derive(Serialize)]
1454struct OllamaMessage {
1455 role: &'static str,
1456 content: String,
1457}
1458
1459#[derive(Serialize)]
1460struct OllamaOptions {
1461 num_predict: i32,
1462 temperature: f32,
1463 #[serde(skip_serializing_if = "Option::is_none")]
1464 stop: Option<Vec<String>>,
1465}
1466
1467#[derive(Deserialize)]
1468struct OllamaChatResponse {
1469 model: String,
1470 message: OllamaResponseMessage,
1471 #[serde(default)]
1472 prompt_eval_count: Option<u32>,
1473 #[serde(default)]
1474 eval_count: Option<u32>,
1475}
1476
1477#[derive(Deserialize)]
1478struct OllamaResponseMessage {
1479 content: String,
1480}
1481
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `LlmClient::new` that takes the API key as a `&str`.
    fn client_for(provider: LlmProvider, api_key: &str) -> LlmClient {
        LlmClient::new(provider, api_key.to_string())
    }

    /// The default request config names the expected model and token limit.
    #[test]
    fn test_default_config() {
        let LlmRequestConfig {
            model, max_tokens, ..
        } = LlmRequestConfig::default();
        assert_eq!(model, "claude-sonnet-4-6");
        assert_eq!(max_tokens, 4096);
    }

    /// An Anthropic client starts with the Anthropic base URL.
    #[test]
    fn test_client_creation() {
        let anthropic = client_for(LlmProvider::Anthropic, "test_key");
        assert_eq!(anthropic.base_url, "https://api.anthropic.com/v1");
    }

    /// An OpenAI client starts with the OpenAI base URL.
    #[test]
    fn test_openai_client_creation() {
        let openai = client_for(LlmProvider::OpenAI, "test_key");
        assert_eq!(openai.base_url, "https://api.openai.com/v1");
    }

    /// A Gemini client starts with the Google generative-language base URL.
    #[test]
    fn test_gemini_client_creation() {
        let gemini = client_for(LlmProvider::Gemini, "test_key");
        assert_eq!(
            gemini.base_url,
            "https://generativelanguage.googleapis.com/v1"
        );
    }

    /// A custom provider uses the URL it was constructed with.
    #[test]
    fn test_custom_provider() {
        let provider = LlmProvider::Custom("https://custom.api.com/v1".to_string());
        let custom = client_for(provider, "test_key");
        assert_eq!(custom.base_url, "https://custom.api.com/v1");
    }

    /// An OpenAI-compatible provider uses its embedded base URL.
    #[test]
    fn test_openai_compatible_provider() {
        let provider = LlmProvider::OpenAICompatible {
            base_url: "https://api.deepseek.com/v1".to_string(),
        };
        let compatible = client_for(provider, "test_key");
        assert_eq!(compatible.base_url, "https://api.deepseek.com/v1");
    }

    /// The default Azure OpenAI base URL points at an Azure host.
    #[test]
    fn test_azure_openai_client_creation() {
        let azure = client_for(LlmProvider::AzureOpenAI, "test_key");
        assert!(azure.base_url.contains("openai.azure.com"));
    }

    /// The default Bedrock base URL points at a bedrock-runtime host.
    #[test]
    fn test_bedrock_client_creation() {
        let bedrock = client_for(LlmProvider::Bedrock, "test_key");
        assert!(bedrock.base_url.contains("bedrock-runtime"));
    }

    /// An Ollama client defaults to the local daemon address, even with an empty key.
    #[test]
    fn test_ollama_client_creation() {
        let ollama = client_for(LlmProvider::Ollama, "");
        assert_eq!(ollama.base_url, "http://localhost:11434");
    }

    /// `with_base_url` replaces the Azure default.
    #[test]
    fn test_azure_openai_with_custom_url() {
        let azure = client_for(LlmProvider::AzureOpenAI, "test_key")
            .with_base_url("https://myresource.openai.azure.com".to_string());
        assert_eq!(azure.base_url, "https://myresource.openai.azure.com");
    }

    /// `with_base_url` replaces the Ollama default.
    #[test]
    fn test_ollama_with_custom_url() {
        let ollama = client_for(LlmProvider::Ollama, "")
            .with_base_url("http://192.168.1.100:11434".to_string());
        assert_eq!(ollama.base_url, "http://192.168.1.100:11434");
    }

    /// A message keeps the content it was built with.
    #[test]
    fn test_message_creation() {
        let greeting = Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
        };
        assert_eq!(greeting.content, "Hello");
    }

    /// A config built by hand keeps its model and system prompt.
    #[test]
    fn test_config_with_system_prompt() {
        let custom = LlmRequestConfig {
            model: "gpt-4".to_string(),
            max_tokens: 8192,
            temperature: 0.5,
            system_prompt: Some("You are a helpful assistant".to_string()),
            stop_sequences: vec![],
        };
        assert_eq!(custom.model, "gpt-4");
        assert!(custom.system_prompt.is_some());
    }

    /// A response built by hand keeps its content and usage numbers.
    #[test]
    fn test_llm_response_creation() {
        let usage = TokenUsage {
            input_tokens: 10,
            output_tokens: 5,
        };
        let reply = LlmResponse {
            content: "Hello".to_string(),
            usage,
            model: "gpt-4".to_string(),
            response_id: "resp_123".to_string(),
        };
        assert_eq!(reply.content, "Hello");
        assert_eq!(reply.usage.input_tokens, 10);
    }

    /// A unit provider variant serializes to JSON containing its name.
    #[test]
    fn test_provider_serialization() {
        let encoded = serde_json::to_string(&LlmProvider::Anthropic).unwrap();
        assert!(encoded.contains("Anthropic"));
    }

    /// A message role serializes to JSON containing its name.
    #[test]
    fn test_message_role_serialization() {
        let encoded = serde_json::to_string(&MessageRole::User).unwrap();
        assert!(encoded.contains("User"));
    }
}