1use anyhow::{anyhow, Result};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use tracing::{info, warn};
15
16use crate::streaming::{
17 CallbackStream, ContentDelta, MessageStream, OnChunkCallback, StreamEvent, StreamProvider,
18};
19
/// Which LLM backend an `LlmClient` talks to.
///
/// The variant selects both the default base URL (see `LlmClient::new`) and the wire
/// format used for requests and responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LlmProvider {
    /// Anthropic Messages API (default base URL `https://api.anthropic.com/v1`).
    Anthropic,
    /// OpenAI chat completions (default base URL `https://api.openai.com/v1`).
    OpenAI,
    /// Google Gemini `generateContent` API.
    Gemini,
    /// Azure OpenAI; the default base URL is a `YOUR_RESOURCE` placeholder and must be
    /// overridden with `LlmClient::with_base_url`.
    AzureOpenAI,
    /// Amazon Bedrock runtime (default endpoint is the `us-east-1` region).
    Bedrock,
    /// Local Ollama server (default `http://localhost:11434`).
    Ollama,
    /// Any server speaking the OpenAI chat-completions protocol.
    OpenAICompatible {
        /// Base URL of the compatible server.
        base_url: String,
    },
    /// Any server speaking the Anthropic Messages protocol.
    AnthropicCompatible {
        /// Base URL of the compatible server.
        base_url: String,
    },
    /// Arbitrary endpoint. The string is used as the base URL, but `send` and
    /// `send_stream` reject this provider with an error.
    Custom(String),
}
39
/// Per-request generation settings shared by every provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequestConfig {
    /// Model identifier. For Azure OpenAI this is used as the deployment name; for
    /// Bedrock it is the model id in the URL.
    pub model: String,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f32,
    /// Optional system prompt, forwarded to every provider either as a leading system
    /// message or as the provider's dedicated system field.
    pub system_prompt: Option<String>,
    /// Stop sequences; an empty list means none are sent.
    /// NOTE(review): the Anthropic requests and the OpenAI/Azure streaming requests in
    /// this file have no stop-sequence field, so these are not forwarded there.
    pub stop_sequences: Vec<String>,
}
54
55impl Default for LlmRequestConfig {
56 fn default() -> Self {
57 Self {
58 model: "claude-sonnet-4-6".to_string(),
59 max_tokens: 4096,
60 temperature: 0.7,
61 system_prompt: None,
62 stop_sequences: vec!["\n\n\n".to_string()],
63 }
64 }
65}
66
/// A completed model response: either non-streaming, or fully accumulated from a stream.
#[derive(Debug, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Generated text.
    pub content: String,
    /// Token accounting; counts may be 0 when the provider or stream did not report them.
    pub usage: TokenUsage,
    /// Model name reported by the provider, or the requested model when none is reported.
    pub model: String,
    /// Provider response id; empty for providers that do not return one (Gemini, Ollama).
    pub response_id: String,
}
79
/// Input/output token counts for a single request.
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt (system prompt plus messages).
    pub input_tokens: u32,
    /// Tokens generated in the response.
    pub output_tokens: u32,
}
88
/// Provider-agnostic interface for sending chat requests to an LLM.
#[async_trait]
pub trait LlmClientTrait {
    /// Sends `messages` with the settings in `config` and waits for the complete response.
    async fn send(&self, messages: Vec<Message>, config: &LlmRequestConfig) -> Result<LlmResponse>;

    /// Starts a streaming request and returns the event stream once the response
    /// headers have been received.
    async fn send_stream(
        &self,
        messages: Vec<Message>,
        config: &LlmRequestConfig,
    ) -> Result<MessageStream>;
}
102
/// One turn of a conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored the turn.
    pub role: MessageRole,
    /// Plain-text body of the turn.
    pub content: String,
}
109
/// Author of a `Message`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageRole {
    /// End-user input.
    User,
    /// Model output (sent as `"model"` to Gemini).
    Assistant,
    /// Instruction text (sent as a `"user"` turn to Gemini).
    System,
}
117
/// HTTP client bound to a single provider, holding the credential and base URL it targets.
pub struct LlmClient {
    /// reqwest client reused for every request.
    client: Client,
    /// Provider credential; sent as a header or query parameter depending on the provider
    /// (not sent to Ollama).
    api_key: String,
    /// Selects the request/response format used by `send` / `send_stream`.
    provider: LlmProvider,
    /// Root URL that endpoint paths are appended to.
    base_url: String,
}
129
130impl LlmClient {
131 pub fn new(provider: LlmProvider, api_key: String) -> Self {
132 let base_url = match &provider {
133 LlmProvider::Anthropic => "https://api.anthropic.com/v1".to_string(),
134 LlmProvider::OpenAI => "https://api.openai.com/v1".to_string(),
135 LlmProvider::Gemini => "https://generativelanguage.googleapis.com/v1".to_string(),
136 LlmProvider::AzureOpenAI => "https://YOUR_RESOURCE.openai.azure.com".to_string(),
137 LlmProvider::Bedrock => "https://bedrock-runtime.us-east-1.amazonaws.com".to_string(),
138 LlmProvider::Ollama => "http://localhost:11434".to_string(),
139 LlmProvider::OpenAICompatible { base_url } => base_url.clone(),
140 LlmProvider::AnthropicCompatible { base_url } => base_url.clone(),
141 LlmProvider::Custom(url) => url.clone(),
142 };
143
144 Self {
145 client: Client::new(),
146 api_key,
147 provider,
148 base_url,
149 }
150 }
151
152 pub fn with_base_url(mut self, base_url: String) -> Self {
154 self.base_url = base_url;
155 self
156 }
157
158 pub async fn send_stream_with_callback(
160 &self,
161 messages: Vec<Message>,
162 config: &LlmRequestConfig,
163 on_chunk: OnChunkCallback,
164 ) -> Result<LlmResponse> {
165 let message_stream = self.send_stream(messages, config).await?;
166 let mut callback_stream = CallbackStream::new(message_stream, Some(on_chunk));
167
168 let mut content = String::new();
169 let mut input_tokens = 0u32;
170 let mut output_tokens = 0u32;
171 let mut message_id = String::new();
172 let mut model = config.model.clone();
173
174 while let Some(event) = callback_stream.next_event().await? {
175 match event {
176 StreamEvent::MessageStart { id, model: m } => {
177 message_id = id;
178 model = m;
179 }
180 StreamEvent::ContentBlockDelta {
181 delta: ContentDelta::Text(t),
182 ..
183 } => {
184 content.push_str(&t);
185 }
186 StreamEvent::ContentBlockDelta { .. } => {}
187 StreamEvent::MessageDelta { usage, .. } => {
188 input_tokens = usage.input_tokens;
189 output_tokens = usage.output_tokens;
190 }
191 _ => {}
192 }
193 }
194
195 Ok(LlmResponse {
196 content,
197 usage: TokenUsage {
198 input_tokens,
199 output_tokens,
200 },
201 model,
202 response_id: message_id,
203 })
204 }
205
206 pub async fn send_stream_abortable(
208 &self,
209 messages: Vec<Message>,
210 config: &LlmRequestConfig,
211 abort_flag: Arc<AtomicBool>,
212 ) -> Result<LlmResponse> {
213 let message_stream = self.send_stream(messages, config).await?;
214 let mut callback_stream = CallbackStream::new(message_stream, None);
215
216 let mut content = String::new();
217 let mut input_tokens = 0u32;
218 let mut output_tokens = 0u32;
219 let mut message_id = String::new();
220 let mut model = config.model.clone();
221
222 while !abort_flag.load(Ordering::Relaxed) {
223 match callback_stream.next_event().await {
224 Ok(Some(event)) => match event {
225 StreamEvent::MessageStart { id, model: m } => {
226 message_id = id;
227 model = m;
228 }
229 StreamEvent::ContentBlockDelta {
230 delta: ContentDelta::Text(t),
231 ..
232 } => {
233 content.push_str(&t);
234 }
235 StreamEvent::ContentBlockDelta { .. } => {}
236 StreamEvent::MessageDelta { usage, .. } => {
237 input_tokens = usage.input_tokens;
238 output_tokens = usage.output_tokens;
239 }
240 StreamEvent::MessageStop => {
241 break;
242 }
243 _ => {}
244 },
245 Ok(None) => break,
246 Err(e) => {
247 if abort_flag.load(Ordering::Relaxed) {
248 info!("Stream aborted by user");
249 break;
250 }
251 return Err(e);
252 }
253 }
254 }
255
256 if abort_flag.load(Ordering::Relaxed) {
257 info!("Stream was aborted");
258 }
259
260 Ok(LlmResponse {
261 content,
262 usage: TokenUsage {
263 input_tokens,
264 output_tokens,
265 },
266 model,
267 response_id: message_id,
268 })
269 }
270
271 pub async fn send_with_retry(
273 &self,
274 messages: Vec<Message>,
275 config: &LlmRequestConfig,
276 max_retries: u32,
277 ) -> Result<LlmResponse> {
278 let mut attempts = 0;
279 let mut last_error: Option<anyhow::Error> = None;
280
281 while attempts < max_retries {
282 attempts += 1;
283
284 match self.send(messages.clone(), config).await {
285 Ok(response) => {
286 info!("LLM request succeeded after {} attempts", attempts);
287 return Ok(response);
288 }
289 Err(e) => {
290 let error_msg = e.to_string();
291
292 if error_msg.contains("rate limit")
293 || error_msg.contains("429")
294 || error_msg.contains("overloaded")
295 || error_msg.contains("timeout")
296 {
297 warn!(
298 "LLM request failed (attempt {}/{}): {}",
299 attempts, max_retries, e
300 );
301 last_error = Some(e);
302
303 let delay = std::cmp::min(1000 * 2u64.pow(attempts - 1), 30000);
304 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
305 } else {
306 return Err(e);
307 }
308 }
309 }
310 }
311
312 Err(last_error.unwrap_or_else(|| anyhow!("Max retries exceeded")))
313 }
314
315 pub async fn send_stream_with_retry(
317 &self,
318 messages: Vec<Message>,
319 config: &LlmRequestConfig,
320 max_retries: u32,
321 ) -> Result<LlmResponse> {
322 let mut attempts = 0;
323 let mut last_error: Option<anyhow::Error> = None;
324
325 while attempts < max_retries {
326 attempts += 1;
327
328 match self
329 .send_stream_with_callback(messages.clone(), config, Box::new(|_| {}))
330 .await
331 {
332 Ok(response) => {
333 info!("Stream request succeeded after {} attempts", attempts);
334 return Ok(response);
335 }
336 Err(e) => {
337 let error_msg = e.to_string();
338
339 if error_msg.contains("rate limit")
340 || error_msg.contains("429")
341 || error_msg.contains("overloaded")
342 || error_msg.contains("timeout")
343 || error_msg.contains("aborted")
344 {
345 warn!(
346 "Stream request failed (attempt {}/{}): {}",
347 attempts, max_retries, e
348 );
349 last_error = Some(e);
350
351 let delay = std::cmp::min(1000 * 2u64.pow(attempts - 1), 30000);
352 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
353 } else {
354 return Err(e);
355 }
356 }
357 }
358 }
359
360 Err(last_error.unwrap_or_else(|| anyhow!("Max retries exceeded")))
361 }
362}
363
364#[async_trait]
365impl LlmClientTrait for LlmClient {
366 async fn send(&self, messages: Vec<Message>, config: &LlmRequestConfig) -> Result<LlmResponse> {
367 match self.provider {
368 LlmProvider::Anthropic | LlmProvider::AnthropicCompatible { .. } => {
369 self.send_anthropic(messages, config).await
370 }
371 LlmProvider::OpenAI | LlmProvider::OpenAICompatible { .. } => {
372 self.send_openai(messages, config).await
373 }
374 LlmProvider::Gemini => self.send_gemini(messages, config).await,
375 LlmProvider::AzureOpenAI => self.send_azure_openai(messages, config).await,
376 LlmProvider::Bedrock => self.send_bedrock(messages, config).await,
377 LlmProvider::Ollama => self.send_ollama(messages, config).await,
378 LlmProvider::Custom(_) => {
379 Err(anyhow!("Custom provider requires custom implementation. Use an OpenAI-compatible provider instead."))
380 }
381 }
382 }
383
384 async fn send_stream(
385 &self,
386 messages: Vec<Message>,
387 config: &LlmRequestConfig,
388 ) -> Result<MessageStream> {
389 match self.provider {
390 LlmProvider::Anthropic | LlmProvider::AnthropicCompatible { .. } => {
391 self.stream_anthropic(messages, config).await
392 }
393 LlmProvider::OpenAI | LlmProvider::OpenAICompatible { .. } => {
394 self.stream_openai(messages, config).await
395 }
396 LlmProvider::Gemini => self.stream_gemini(messages, config).await,
397 LlmProvider::AzureOpenAI => self.stream_azure_openai(messages, config).await,
398 LlmProvider::Bedrock => self.stream_bedrock(messages, config).await,
399 LlmProvider::Ollama => self.stream_ollama(messages, config).await,
400 LlmProvider::Custom(_) => Err(anyhow!("Custom provider does not support streaming. Use an OpenAI-compatible provider instead.")),
401 }
402 }
403}
404
405impl LlmClient {
406 pub fn build_anthropic_messages_url(base_url: &str) -> String {
414 let base = base_url.trim_end_matches('/');
415
416 if base.ends_with("/messages") {
418 return base.to_string();
419 }
420
421 if base.ends_with("/v1") {
423 return format!("{}/messages", base);
424 }
425
426 if base.contains("/anthropic") {
429 return format!("{}/messages", base);
430 }
431
432 format!("{}/v1/messages", base)
434 }
435
436 async fn send_anthropic(
437 &self,
438 messages: Vec<Message>,
439 config: &LlmRequestConfig,
440 ) -> Result<LlmResponse> {
441 let url = Self::build_anthropic_messages_url(&self.base_url);
442
443 let request_body = AnthropicRequest {
444 model: config.model.clone(),
445 max_tokens: config.max_tokens,
446 messages: messages
447 .into_iter()
448 .map(|m| AnthropicMessage {
449 role: match m.role {
450 MessageRole::User => "user",
451 MessageRole::Assistant => "assistant",
452 MessageRole::System => "system",
453 },
454 content: AnthropicContent::Text(m.content),
455 })
456 .collect(),
457 system: config.system_prompt.clone(),
458 temperature: config.temperature,
459 };
460
461 let response = self
462 .client
463 .post(&url)
464 .header("x-api-key", &self.api_key)
465 .header("anthropic-version", "2023-06-01")
466 .json(&request_body)
467 .send()
468 .await?;
469
470 let response_text = response.text().await?;
471 tracing::debug!("Anthropic API response: {}", response_text);
472
473 let response_body: AnthropicResponse = serde_json::from_str(&response_text)?;
474
475 Ok(LlmResponse {
476 content: response_body
477 .content
478 .first()
479 .map(|c| c.text.clone())
480 .unwrap_or_default(),
481 usage: TokenUsage {
482 input_tokens: response_body.usage.input_tokens,
483 output_tokens: response_body.usage.output_tokens,
484 },
485 model: response_body.model,
486 response_id: response_body.id,
487 })
488 }
489
490 async fn stream_anthropic(
491 &self,
492 messages: Vec<Message>,
493 config: &LlmRequestConfig,
494 ) -> Result<MessageStream> {
495 let url = Self::build_anthropic_messages_url(&self.base_url);
496
497 let request_body = AnthropicStreamRequest {
498 model: config.model.clone(),
499 max_tokens: config.max_tokens,
500 messages: messages
501 .into_iter()
502 .map(|m| AnthropicMessage {
503 role: match m.role {
504 MessageRole::User => "user",
505 MessageRole::Assistant => "assistant",
506 MessageRole::System => "system",
507 },
508 content: AnthropicContent::Text(m.content),
509 })
510 .collect(),
511 system: config.system_prompt.clone(),
512 temperature: config.temperature,
513 stream: true,
514 };
515
516 let response = self
517 .client
518 .post(&url)
519 .header("x-api-key", &self.api_key)
520 .header("anthropic-version", "2023-06-01")
521 .header("Accept", "text/event-stream")
522 .json(&request_body)
523 .send()
524 .await?;
525
526 let status = response.status();
527 if !status.is_success() {
528 let error_text = response.text().await?;
529 return Err(anyhow!("Anthropic API error {}: {}", status, error_text));
530 }
531
532 Ok(MessageStream::new(
533 response,
534 match self.provider {
535 LlmProvider::Anthropic => StreamProvider::Anthropic,
536 LlmProvider::AnthropicCompatible { .. } => StreamProvider::AnthropicCompatible,
537 _ => StreamProvider::Anthropic, },
539 config.model.clone(),
540 ))
541 }
542
543 async fn send_openai(
544 &self,
545 messages: Vec<Message>,
546 config: &LlmRequestConfig,
547 ) -> Result<LlmResponse> {
548 let url = format!("{}/chat/completions", self.base_url);
549
550 let mut openai_messages: Vec<OpenAiMessage> = Vec::new();
551
552 if let Some(ref system) = config.system_prompt {
553 openai_messages.push(OpenAiMessage {
554 role: "system",
555 content: system.clone(),
556 });
557 }
558
559 for m in messages {
560 openai_messages.push(OpenAiMessage {
561 role: match m.role {
562 MessageRole::User => "user",
563 MessageRole::Assistant => "assistant",
564 MessageRole::System => "system",
565 },
566 content: m.content,
567 });
568 }
569
570 let request_body = OpenAiRequest {
571 model: config.model.clone(),
572 messages: openai_messages,
573 max_tokens: Some(config.max_tokens),
574 temperature: Some(config.temperature),
575 stop: if config.stop_sequences.is_empty() {
576 None
577 } else {
578 Some(config.stop_sequences.clone())
579 },
580 };
581
582 let response = self
583 .client
584 .post(&url)
585 .header("Authorization", format!("Bearer {}", self.api_key))
586 .json(&request_body)
587 .send()
588 .await?;
589
590 let response_body: OpenAiResponse = response.json().await?;
591
592 let choice = response_body
593 .choices
594 .first()
595 .ok_or_else(|| anyhow!("No response choices"))?;
596
597 Ok(LlmResponse {
598 content: choice.message.content.clone(),
599 usage: TokenUsage {
600 input_tokens: response_body.usage.prompt_tokens,
601 output_tokens: response_body.usage.completion_tokens,
602 },
603 model: response_body.model,
604 response_id: response_body.id,
605 })
606 }
607
608 async fn stream_openai(
609 &self,
610 messages: Vec<Message>,
611 config: &LlmRequestConfig,
612 ) -> Result<MessageStream> {
613 let url = format!("{}/chat/completions", self.base_url);
614
615 let mut openai_messages: Vec<OpenAiMessage> = Vec::new();
616 if let Some(ref system) = config.system_prompt {
617 openai_messages.push(OpenAiMessage {
618 role: "system",
619 content: system.clone(),
620 });
621 }
622 for m in messages {
623 openai_messages.push(OpenAiMessage {
624 role: match m.role {
625 MessageRole::User => "user",
626 MessageRole::Assistant => "assistant",
627 MessageRole::System => "system",
628 },
629 content: m.content,
630 });
631 }
632
633 let request_body = OpenAiStreamRequest {
634 model: config.model.clone(),
635 messages: openai_messages,
636 max_tokens: Some(config.max_tokens),
637 temperature: Some(config.temperature),
638 stream: true,
639 };
640
641 let response = self
642 .client
643 .post(&url)
644 .header("Authorization", format!("Bearer {}", self.api_key))
645 .header("Accept", "text/event-stream")
646 .json(&request_body)
647 .send()
648 .await?;
649
650 let status = response.status();
651 if !status.is_success() {
652 let error_text = response.text().await?;
653 return Err(anyhow!("OpenAI API error {}: {}", status, error_text));
654 }
655
656 Ok(MessageStream::new(
657 response,
658 match self.provider {
659 LlmProvider::OpenAI => StreamProvider::OpenAI,
660 LlmProvider::OpenAICompatible { .. } => StreamProvider::OpenAICompatible,
661 _ => StreamProvider::OpenAI, },
663 config.model.clone(),
664 ))
665 }
666
667 async fn send_gemini(
668 &self,
669 messages: Vec<Message>,
670 config: &LlmRequestConfig,
671 ) -> Result<LlmResponse> {
672 let url = format!(
673 "{}/models/{}:generateContent?key={}",
674 self.base_url, config.model, self.api_key
675 );
676
677 let mut contents: Vec<GeminiContent> = Vec::new();
678 let system_instruction = config.system_prompt.clone();
679
680 for m in messages {
681 contents.push(GeminiContent {
682 role: match m.role {
683 MessageRole::User => "user".to_string(),
684 MessageRole::Assistant => "model".to_string(),
685 MessageRole::System => "user".to_string(),
686 },
687 parts: vec![GeminiPart { text: m.content }],
688 });
689 }
690
691 let request_body = GeminiRequest {
692 contents,
693 generation_config: Some(GeminiGenerationConfig {
694 max_output_tokens: Some(config.max_tokens),
695 temperature: Some(config.temperature),
696 stop_sequences: if config.stop_sequences.is_empty() {
697 None
698 } else {
699 Some(config.stop_sequences.clone())
700 },
701 }),
702 system_instruction: system_instruction.map(|s| GeminiSystemInstruction {
703 parts: vec![GeminiPart { text: s }],
704 }),
705 };
706
707 let response = self.client.post(&url).json(&request_body).send().await?;
708
709 let response_body: GeminiResponse = response.json().await?;
710
711 let candidate = response_body
712 .candidates
713 .first()
714 .ok_or_else(|| anyhow!("No response candidates"))?;
715
716 let content = candidate
717 .content
718 .parts
719 .first()
720 .map(|p| p.text.clone())
721 .unwrap_or_default();
722
723 Ok(LlmResponse {
724 content,
725 usage: TokenUsage {
726 input_tokens: response_body.usage_metadata.prompt_token_count.unwrap_or(0),
727 output_tokens: response_body
728 .usage_metadata
729 .candidates_token_count
730 .unwrap_or(0),
731 },
732 model: config.model.clone(),
733 response_id: "".to_string(),
734 })
735 }
736
737 async fn stream_gemini(
738 &self,
739 messages: Vec<Message>,
740 config: &LlmRequestConfig,
741 ) -> Result<MessageStream> {
742 let url = format!(
743 "{}/models/{}:streamGenerateContent?key={}&alt=sse",
744 self.base_url, config.model, self.api_key
745 );
746
747 let mut contents: Vec<GeminiContent> = Vec::new();
748 let system_instruction = config.system_prompt.clone();
749
750 for m in messages {
751 contents.push(GeminiContent {
752 role: match m.role {
753 MessageRole::User => "user".to_string(),
754 MessageRole::Assistant => "model".to_string(),
755 MessageRole::System => "user".to_string(),
756 },
757 parts: vec![GeminiPart { text: m.content }],
758 });
759 }
760
761 let request_body = GeminiRequest {
762 contents,
763 generation_config: Some(GeminiGenerationConfig {
764 max_output_tokens: Some(config.max_tokens),
765 temperature: Some(config.temperature),
766 stop_sequences: if config.stop_sequences.is_empty() {
767 None
768 } else {
769 Some(config.stop_sequences.clone())
770 },
771 }),
772 system_instruction: system_instruction.map(|s| GeminiSystemInstruction {
773 parts: vec![GeminiPart { text: s }],
774 }),
775 };
776
777 let response = self.client.post(&url).json(&request_body).send().await?;
778
779 let status = response.status();
780 if !status.is_success() {
781 let error_text = response.text().await?;
782 return Err(anyhow!("Gemini API error {}: {}", status, error_text));
783 }
784
785 Ok(MessageStream::new(
786 response,
787 StreamProvider::Gemini,
788 config.model.clone(),
789 ))
790 }
791
792 async fn send_azure_openai(
797 &self,
798 messages: Vec<Message>,
799 config: &LlmRequestConfig,
800 ) -> Result<LlmResponse> {
801 let deployment = &config.model;
804 let url = format!(
805 "{}/openai/deployments/{}/chat/completions?api-version=2024-02-15-preview",
806 self.base_url, deployment
807 );
808
809 let mut azure_messages: Vec<OpenAiMessage> = Vec::new();
810 if let Some(ref system) = config.system_prompt {
811 azure_messages.push(OpenAiMessage {
812 role: "system",
813 content: system.clone(),
814 });
815 }
816 for m in messages {
817 azure_messages.push(OpenAiMessage {
818 role: match m.role {
819 MessageRole::User => "user",
820 MessageRole::Assistant => "assistant",
821 MessageRole::System => "system",
822 },
823 content: m.content,
824 });
825 }
826
827 let request_body = OpenAiRequest {
828 model: deployment.clone(), messages: azure_messages,
830 max_tokens: Some(config.max_tokens),
831 temperature: Some(config.temperature),
832 stop: if config.stop_sequences.is_empty() {
833 None
834 } else {
835 Some(config.stop_sequences.clone())
836 },
837 };
838
839 let response = self
840 .client
841 .post(&url)
842 .header("api-key", &self.api_key) .json(&request_body)
844 .send()
845 .await?;
846
847 let status = response.status();
848 if !status.is_success() {
849 let error_text = response.text().await?;
850 return Err(anyhow!("Azure OpenAI API error {}: {}", status, error_text));
851 }
852
853 let response_body: OpenAiResponse = response.json().await?;
854
855 let choice = response_body
856 .choices
857 .first()
858 .ok_or_else(|| anyhow!("No response choices"))?;
859
860 Ok(LlmResponse {
861 content: choice.message.content.clone(),
862 usage: TokenUsage {
863 input_tokens: response_body.usage.prompt_tokens,
864 output_tokens: response_body.usage.completion_tokens,
865 },
866 model: response_body.model,
867 response_id: response_body.id,
868 })
869 }
870
871 async fn stream_azure_openai(
872 &self,
873 messages: Vec<Message>,
874 config: &LlmRequestConfig,
875 ) -> Result<MessageStream> {
876 let deployment = &config.model;
877 let url = format!(
878 "{}/openai/deployments/{}/chat/completions?api-version=2024-02-15-preview",
879 self.base_url, deployment
880 );
881
882 let mut azure_messages: Vec<OpenAiMessage> = Vec::new();
883 if let Some(ref system) = config.system_prompt {
884 azure_messages.push(OpenAiMessage {
885 role: "system",
886 content: system.clone(),
887 });
888 }
889 for m in messages {
890 azure_messages.push(OpenAiMessage {
891 role: match m.role {
892 MessageRole::User => "user",
893 MessageRole::Assistant => "assistant",
894 MessageRole::System => "system",
895 },
896 content: m.content,
897 });
898 }
899
900 let request_body = OpenAiStreamRequest {
901 model: deployment.clone(),
902 messages: azure_messages,
903 max_tokens: Some(config.max_tokens),
904 temperature: Some(config.temperature),
905 stream: true,
906 };
907
908 let response = self
909 .client
910 .post(&url)
911 .header("api-key", &self.api_key)
912 .header("Accept", "text/event-stream")
913 .json(&request_body)
914 .send()
915 .await?;
916
917 let status = response.status();
918 if !status.is_success() {
919 let error_text = response.text().await?;
920 return Err(anyhow!("Azure OpenAI API error {}: {}", status, error_text));
921 }
922
923 Ok(MessageStream::new(
924 response,
925 StreamProvider::AzureOpenAI,
926 config.model.clone(),
927 ))
928 }
929
930 async fn send_bedrock(
935 &self,
936 messages: Vec<Message>,
937 config: &LlmRequestConfig,
938 ) -> Result<LlmResponse> {
939 let model_id = &config.model;
942 let url = format!("{}/model/{}/invoke", self.base_url, model_id);
943
944 let mut bedrock_messages: Vec<BedrockMessage> = Vec::new();
946 for m in messages {
947 bedrock_messages.push(BedrockMessage {
948 role: match m.role {
949 MessageRole::User => "user",
950 MessageRole::Assistant => "assistant",
951 MessageRole::System => "system",
952 },
953 content: vec![BedrockContent { text: m.content }],
954 });
955 }
956
957 let request_body = BedrockRequest {
958 messages: bedrock_messages,
959 system: config.system_prompt.clone(),
960 inference_config: Some(BedrockInferenceConfig {
961 max_tokens: config.max_tokens,
962 temperature: config.temperature,
963 top_p: None,
964 stop_sequences: if config.stop_sequences.is_empty() {
965 None
966 } else {
967 Some(config.stop_sequences.clone())
968 },
969 }),
970 };
971
972 let response = self
973 .client
974 .post(&url)
975 .header("Authorization", format!("Bearer {}", self.api_key))
976 .header("Content-Type", "application/json")
977 .json(&request_body)
978 .send()
979 .await?;
980
981 let status = response.status();
982 if !status.is_success() {
983 let error_text = response.text().await?;
984 return Err(anyhow!("Bedrock API error {}: {}", status, error_text));
985 }
986
987 let response_body: BedrockResponse = response.json().await?;
988
989 let content = response_body
990 .output
991 .message
992 .content
993 .first()
994 .map(|c| c.text.clone())
995 .unwrap_or_default();
996
997 Ok(LlmResponse {
998 content,
999 usage: TokenUsage {
1000 input_tokens: response_body.usage.input_tokens,
1001 output_tokens: response_body.usage.output_tokens,
1002 },
1003 model: config.model.clone(),
1004 response_id: response_body.request_id.unwrap_or_default(),
1005 })
1006 }
1007
1008 async fn stream_bedrock(
1009 &self,
1010 messages: Vec<Message>,
1011 config: &LlmRequestConfig,
1012 ) -> Result<MessageStream> {
1013 let model_id = &config.model;
1014 let url = format!(
1015 "{}/model/{}/invoke-with-response-stream",
1016 self.base_url, model_id
1017 );
1018
1019 let mut bedrock_messages: Vec<BedrockMessage> = Vec::new();
1020 for m in messages {
1021 bedrock_messages.push(BedrockMessage {
1022 role: match m.role {
1023 MessageRole::User => "user",
1024 MessageRole::Assistant => "assistant",
1025 MessageRole::System => "system",
1026 },
1027 content: vec![BedrockContent { text: m.content }],
1028 });
1029 }
1030
1031 let request_body = BedrockRequest {
1032 messages: bedrock_messages,
1033 system: config.system_prompt.clone(),
1034 inference_config: Some(BedrockInferenceConfig {
1035 max_tokens: config.max_tokens,
1036 temperature: config.temperature,
1037 top_p: None,
1038 stop_sequences: if config.stop_sequences.is_empty() {
1039 None
1040 } else {
1041 Some(config.stop_sequences.clone())
1042 },
1043 }),
1044 };
1045
1046 let response = self
1047 .client
1048 .post(&url)
1049 .header("Authorization", format!("Bearer {}", self.api_key))
1050 .header("Accept", "text/event-stream")
1051 .header("Content-Type", "application/json")
1052 .json(&request_body)
1053 .send()
1054 .await?;
1055
1056 let status = response.status();
1057 if !status.is_success() {
1058 let error_text = response.text().await?;
1059 return Err(anyhow!("Bedrock API error {}: {}", status, error_text));
1060 }
1061
1062 Ok(MessageStream::new(
1063 response,
1064 StreamProvider::Bedrock,
1065 config.model.clone(),
1066 ))
1067 }
1068
1069 async fn send_ollama(
1074 &self,
1075 messages: Vec<Message>,
1076 config: &LlmRequestConfig,
1077 ) -> Result<LlmResponse> {
1078 let url = format!("{}/api/chat", self.base_url);
1080
1081 let mut ollama_messages: Vec<OllamaMessage> = Vec::new();
1082 if let Some(ref system) = config.system_prompt {
1083 ollama_messages.push(OllamaMessage {
1084 role: "system",
1085 content: system.clone(),
1086 });
1087 }
1088 for m in messages {
1089 ollama_messages.push(OllamaMessage {
1090 role: match m.role {
1091 MessageRole::User => "user",
1092 MessageRole::Assistant => "assistant",
1093 MessageRole::System => "system",
1094 },
1095 content: m.content,
1096 });
1097 }
1098
1099 let request_body = OllamaChatRequest {
1100 model: config.model.clone(),
1101 messages: ollama_messages,
1102 stream: false,
1103 options: Some(OllamaOptions {
1104 num_predict: config.max_tokens as i32,
1105 temperature: config.temperature,
1106 stop: if config.stop_sequences.is_empty() {
1107 None
1108 } else {
1109 Some(config.stop_sequences.clone())
1110 },
1111 }),
1112 };
1113
1114 let response = self
1116 .client
1117 .post(&url)
1118 .header("Content-Type", "application/json")
1119 .json(&request_body)
1120 .send()
1121 .await?;
1122
1123 let status = response.status();
1124 if !status.is_success() {
1125 let error_text = response.text().await?;
1126 return Err(anyhow!("Ollama API error {}: {}", status, error_text));
1127 }
1128
1129 let response_body: OllamaChatResponse = response.json().await?;
1130
1131 Ok(LlmResponse {
1132 content: response_body.message.content,
1133 usage: TokenUsage {
1134 input_tokens: response_body.prompt_eval_count.unwrap_or(0),
1135 output_tokens: response_body.eval_count.unwrap_or(0),
1136 },
1137 model: response_body.model,
1138 response_id: "".to_string(),
1139 })
1140 }
1141
1142 async fn stream_ollama(
1143 &self,
1144 messages: Vec<Message>,
1145 config: &LlmRequestConfig,
1146 ) -> Result<MessageStream> {
1147 let url = format!("{}/api/chat", self.base_url);
1148
1149 let mut ollama_messages: Vec<OllamaMessage> = Vec::new();
1150 if let Some(ref system) = config.system_prompt {
1151 ollama_messages.push(OllamaMessage {
1152 role: "system",
1153 content: system.clone(),
1154 });
1155 }
1156 for m in messages {
1157 ollama_messages.push(OllamaMessage {
1158 role: match m.role {
1159 MessageRole::User => "user",
1160 MessageRole::Assistant => "assistant",
1161 MessageRole::System => "system",
1162 },
1163 content: m.content,
1164 });
1165 }
1166
1167 let request_body = OllamaChatRequest {
1168 model: config.model.clone(),
1169 messages: ollama_messages,
1170 stream: true,
1171 options: Some(OllamaOptions {
1172 num_predict: config.max_tokens as i32,
1173 temperature: config.temperature,
1174 stop: if config.stop_sequences.is_empty() {
1175 None
1176 } else {
1177 Some(config.stop_sequences.clone())
1178 },
1179 }),
1180 };
1181
1182 let response = self
1183 .client
1184 .post(&url)
1185 .header("Accept", "application/json")
1186 .header("Content-Type", "application/json")
1187 .json(&request_body)
1188 .send()
1189 .await?;
1190
1191 let status = response.status();
1192 if !status.is_success() {
1193 let error_text = response.text().await?;
1194 return Err(anyhow!("Ollama API error {}: {}", status, error_text));
1195 }
1196
1197 Ok(MessageStream::new(
1198 response,
1199 StreamProvider::Ollama,
1200 config.model.clone(),
1201 ))
1202 }
1203}
1204
1205#[derive(Serialize)]
1207struct AnthropicRequest {
1208 model: String,
1209 max_tokens: u32,
1210 messages: Vec<AnthropicMessage>,
1211 system: Option<String>,
1212 temperature: f32,
1213}
1214
1215#[derive(Serialize)]
1216struct AnthropicStreamRequest {
1217 model: String,
1218 max_tokens: u32,
1219 messages: Vec<AnthropicMessage>,
1220 system: Option<String>,
1221 temperature: f32,
1222 stream: bool,
1223}
1224
/// One entry in an Anthropic request's `messages` array.
#[derive(Serialize)]
struct AnthropicMessage {
    /// Role string, e.g. `"user"` or `"assistant"`.
    role: &'static str,
    /// Message body (plain string or typed blocks).
    content: AnthropicContent,
}
1230
/// Anthropic message content. Because it is `untagged`, it serializes as whichever
/// shape it holds: a bare string or an array of blocks.
#[derive(Serialize)]
#[serde(untagged)]
// Only `Text` is constructed in the visible code; `Blocks` is kept for block-style content.
#[allow(dead_code)]
enum AnthropicContent {
    /// Plain text content.
    Text(String),
    /// Explicit list of typed content blocks.
    Blocks(Vec<AnthropicContentBlock>),
}
1238
/// A typed request content block, serialized as `{"type": ..., "text": ...}`.
#[derive(Serialize)]
struct AnthropicContentBlock {
    /// Block type, emitted under the JSON key `type`.
    #[serde(rename = "type")]
    content_type: String,
    /// Block text.
    text: String,
}
1245
/// Non-streaming Anthropic Messages response body.
///
/// Every field is `#[serde(default)]`, so any JSON object parses successfully, including
/// an error payload. Callers must check the HTTP status before relying on the result.
#[derive(Deserialize)]
#[allow(dead_code)] // several fields are parsed for completeness but not read
struct AnthropicResponse {
    #[serde(default)]
    id: String,
    #[serde(default)]
    model: String,
    /// Content blocks in order.
    #[serde(default)]
    content: Vec<AnthropicContentResponse>,
    #[serde(default)]
    usage: AnthropicUsage,
    /// The JSON `type` field.
    #[serde(default)]
    #[serde(rename = "type")]
    response_type: Option<String>,
    #[serde(default)]
    role: Option<String>,
    #[serde(default)]
    stop_reason: Option<String>,
}
1265
/// One content block of an Anthropic response.
///
/// Both fields default to `""`, so blocks without a `text` key still deserialize.
#[derive(Deserialize)]
#[allow(dead_code)] // `content_type` is parsed but not read in the visible code
struct AnthropicContentResponse {
    /// The JSON `type` field.
    #[serde(rename = "type", default)]
    content_type: String,
    #[serde(default)]
    text: String,
}
1274
/// Token usage reported by Anthropic. Missing counts default to 0.
#[derive(Deserialize, Default)]
struct AnthropicUsage {
    #[serde(default)]
    input_tokens: u32,
    #[serde(default)]
    output_tokens: u32,
}
1282
/// Non-streaming chat-completions request body (also used for Azure OpenAI).
/// Optional fields are omitted from the JSON when `None`.
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
}
1295
1296#[derive(Serialize)]
1297struct OpenAiStreamRequest {
1298 model: String,
1299 messages: Vec<OpenAiMessage>,
1300 #[serde(skip_serializing_if = "Option::is_none")]
1301 max_tokens: Option<u32>,
1302 #[serde(skip_serializing_if = "Option::is_none")]
1303 temperature: Option<f32>,
1304 stream: bool,
1305}
1306
1307#[derive(Serialize)]
1308struct OpenAiMessage {
1309 role: &'static str,
1310 content: String,
1311}
1312
1313#[derive(Deserialize)]
1314struct OpenAiResponse {
1315 id: String,
1316 model: String,
1317 choices: Vec<OpenAiChoice>,
1318 usage: OpenAiUsage,
1319}
1320
1321#[derive(Deserialize)]
1322#[allow(dead_code)]
1323struct OpenAiChoice {
1324 message: OpenAiResponseMessage,
1325 finish_reason: String,
1326}
1327
1328#[derive(Deserialize)]
1329#[allow(dead_code)]
1330struct OpenAiResponseMessage {
1331 role: String,
1332 content: String,
1333}
1334
1335#[derive(Deserialize)]
1336#[allow(dead_code)]
1337struct OpenAiUsage {
1338 prompt_tokens: u32,
1339 completion_tokens: u32,
1340 total_tokens: u32,
1341}
1342
1343#[derive(Serialize)]
1345struct GeminiRequest {
1346 contents: Vec<GeminiContent>,
1347 #[serde(skip_serializing_if = "Option::is_none")]
1348 generation_config: Option<GeminiGenerationConfig>,
1349 #[serde(skip_serializing_if = "Option::is_none")]
1350 system_instruction: Option<GeminiSystemInstruction>,
1351}
1352
1353#[derive(Serialize)]
1354struct GeminiContent {
1355 role: String,
1356 parts: Vec<GeminiPart>,
1357}
1358
1359#[derive(Serialize)]
1360struct GeminiPart {
1361 text: String,
1362}
1363
1364#[derive(Serialize)]
1365struct GeminiGenerationConfig {
1366 #[serde(skip_serializing_if = "Option::is_none")]
1367 max_output_tokens: Option<u32>,
1368 #[serde(skip_serializing_if = "Option::is_none")]
1369 temperature: Option<f32>,
1370 #[serde(skip_serializing_if = "Option::is_none")]
1371 stop_sequences: Option<Vec<String>>,
1372}
1373
1374#[derive(Serialize)]
1375struct GeminiSystemInstruction {
1376 parts: Vec<GeminiPart>,
1377}
1378
1379#[derive(Deserialize)]
1380struct GeminiResponse {
1381 candidates: Vec<GeminiCandidate>,
1382 usage_metadata: GeminiUsageMetadata,
1383}
1384
1385#[derive(Deserialize)]
1386#[allow(dead_code)]
1387struct GeminiCandidate {
1388 content: GeminiContentResponse,
1389 finish_reason: String,
1390}
1391
1392#[derive(Deserialize)]
1393#[allow(dead_code)]
1394struct GeminiContentResponse {
1395 parts: Vec<GeminiPartResponse>,
1396 role: String,
1397}
1398
1399#[derive(Deserialize)]
1400struct GeminiPartResponse {
1401 text: String,
1402}
1403
1404#[derive(Deserialize)]
1405#[allow(dead_code)]
1406struct GeminiUsageMetadata {
1407 prompt_token_count: Option<u32>,
1408 candidates_token_count: Option<u32>,
1409 total_token_count: Option<u32>,
1410}
1411
1412#[derive(Serialize)]
1417struct BedrockRequest {
1418 messages: Vec<BedrockMessage>,
1419 #[serde(skip_serializing_if = "Option::is_none")]
1420 system: Option<String>,
1421 #[serde(skip_serializing_if = "Option::is_none")]
1422 inference_config: Option<BedrockInferenceConfig>,
1423}
1424
1425#[derive(Serialize)]
1426struct BedrockMessage {
1427 role: &'static str,
1428 content: Vec<BedrockContent>,
1429}
1430
1431#[derive(Serialize)]
1432struct BedrockContent {
1433 text: String,
1434}
1435
1436#[derive(Serialize)]
1437struct BedrockInferenceConfig {
1438 #[serde(rename = "maxTokens")]
1439 max_tokens: u32,
1440 temperature: f32,
1441 #[serde(skip_serializing_if = "Option::is_none")]
1442 top_p: Option<f32>,
1443 #[serde(skip_serializing_if = "Option::is_none")]
1444 stop_sequences: Option<Vec<String>>,
1445}
1446
1447#[derive(Deserialize)]
1448#[allow(dead_code)]
1449struct BedrockResponse {
1450 output: BedrockOutput,
1451 usage: BedrockUsage,
1452 #[serde(default)]
1453 request_id: Option<String>,
1454}
1455
1456#[derive(Deserialize)]
1457struct BedrockOutput {
1458 message: BedrockResponseMessage,
1459}
1460
1461#[derive(Deserialize)]
1462struct BedrockResponseMessage {
1463 content: Vec<BedrockResponseContent>,
1464}
1465
1466#[derive(Deserialize)]
1467struct BedrockResponseContent {
1468 text: String,
1469}
1470
1471#[derive(Deserialize)]
1472struct BedrockUsage {
1473 #[serde(default)]
1474 input_tokens: u32,
1475 #[serde(default)]
1476 output_tokens: u32,
1477}
1478
1479#[derive(Serialize)]
1484struct OllamaChatRequest {
1485 model: String,
1486 messages: Vec<OllamaMessage>,
1487 stream: bool,
1488 #[serde(skip_serializing_if = "Option::is_none")]
1489 options: Option<OllamaOptions>,
1490}
1491
1492#[derive(Serialize)]
1493struct OllamaMessage {
1494 role: &'static str,
1495 content: String,
1496}
1497
1498#[derive(Serialize)]
1499struct OllamaOptions {
1500 num_predict: i32,
1501 temperature: f32,
1502 #[serde(skip_serializing_if = "Option::is_none")]
1503 stop: Option<Vec<String>>,
1504}
1505
1506#[derive(Deserialize)]
1507struct OllamaChatResponse {
1508 model: String,
1509 message: OllamaResponseMessage,
1510 #[serde(default)]
1511 prompt_eval_count: Option<u32>,
1512 #[serde(default)]
1513 eval_count: Option<u32>,
1514}
1515
1516#[derive(Deserialize)]
1517struct OllamaResponseMessage {
1518 content: String,
1519}
1520
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client for `provider` using `key` as the API key.
    fn client_for(provider: LlmProvider, key: &str) -> LlmClient {
        LlmClient::new(provider, key.to_string())
    }

    /// Asserts that `build_anthropic_messages_url(base)` yields `expected`.
    fn assert_messages_url(base: &str, expected: &str) {
        let built = LlmClient::build_anthropic_messages_url(base);
        assert_eq!(built, expected);
    }

    #[test]
    fn test_default_config() {
        let LlmRequestConfig {
            model, max_tokens, ..
        } = LlmRequestConfig::default();
        assert_eq!(model, "claude-sonnet-4-6");
        assert_eq!(max_tokens, 4096);
    }

    #[test]
    fn test_client_creation() {
        let anthropic = client_for(LlmProvider::Anthropic, "test_key");
        assert_eq!(anthropic.base_url, "https://api.anthropic.com/v1");
    }

    #[test]
    fn test_openai_client_creation() {
        let openai = client_for(LlmProvider::OpenAI, "test_key");
        assert_eq!(openai.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_gemini_client_creation() {
        let gemini = client_for(LlmProvider::Gemini, "test_key");
        assert_eq!(
            gemini.base_url,
            "https://generativelanguage.googleapis.com/v1"
        );
    }

    #[test]
    fn test_custom_provider() {
        let provider = LlmProvider::Custom("https://custom.api.com/v1".to_string());
        let custom = client_for(provider, "test_key");
        assert_eq!(custom.base_url, "https://custom.api.com/v1");
    }

    #[test]
    fn test_openai_compatible_provider() {
        let provider = LlmProvider::OpenAICompatible {
            base_url: "https://api.deepseek.com/v1".to_string(),
        };
        let compatible = client_for(provider, "test_key");
        assert_eq!(compatible.base_url, "https://api.deepseek.com/v1");
    }

    #[test]
    fn test_azure_openai_client_creation() {
        let azure = client_for(LlmProvider::AzureOpenAI, "test_key");
        assert!(azure.base_url.contains("openai.azure.com"));
    }

    #[test]
    fn test_bedrock_client_creation() {
        let bedrock = client_for(LlmProvider::Bedrock, "test_key");
        assert!(bedrock.base_url.contains("bedrock-runtime"));
    }

    #[test]
    fn test_ollama_client_creation() {
        let ollama = client_for(LlmProvider::Ollama, "");
        assert_eq!(ollama.base_url, "http://localhost:11434");
    }

    #[test]
    fn test_azure_openai_with_custom_url() {
        let azure = client_for(LlmProvider::AzureOpenAI, "test_key")
            .with_base_url("https://myresource.openai.azure.com".to_string());
        assert_eq!(azure.base_url, "https://myresource.openai.azure.com");
    }

    #[test]
    fn test_ollama_with_custom_url() {
        let ollama = client_for(LlmProvider::Ollama, "")
            .with_base_url("http://192.168.1.100:11434".to_string());
        assert_eq!(ollama.base_url, "http://192.168.1.100:11434");
    }

    #[test]
    fn test_message_creation() {
        let greeting = Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
        };
        assert_eq!(greeting.content, "Hello");
    }

    #[test]
    fn test_config_with_system_prompt() {
        let cfg = LlmRequestConfig {
            model: "gpt-4".to_string(),
            max_tokens: 8192,
            temperature: 0.5,
            system_prompt: Some("You are a helpful assistant".to_string()),
            stop_sequences: vec![],
        };
        assert_eq!(cfg.model, "gpt-4");
        assert!(cfg.system_prompt.is_some());
    }

    #[test]
    fn test_llm_response_creation() {
        let reply = LlmResponse {
            content: "Hello".to_string(),
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 5,
            },
            model: "gpt-4".to_string(),
            response_id: "resp_123".to_string(),
        };
        assert_eq!(reply.content, "Hello");
        assert_eq!(reply.usage.input_tokens, 10);
    }

    #[test]
    fn test_provider_serialization() {
        let encoded = serde_json::to_string(&LlmProvider::Anthropic).unwrap();
        assert!(encoded.contains("Anthropic"));
    }

    #[test]
    fn test_message_role_serialization() {
        let encoded = serde_json::to_string(&MessageRole::User).unwrap();
        assert!(encoded.contains("User"));
    }

    #[test]
    fn test_anthropic_compatible_provider_creation() {
        let tencent_url = "https://api.lkeap.cloud.tencent.com/coding/anthropic";
        let provider = LlmProvider::AnthropicCompatible {
            base_url: tencent_url.to_string(),
        };
        let compatible = client_for(provider, "test_key");
        assert_eq!(compatible.base_url, tencent_url);
    }

    #[test]
    fn test_anthropic_compatible_provider_serialization() {
        let provider = LlmProvider::AnthropicCompatible {
            base_url: "https://example.com".to_string(),
        };
        let encoded = serde_json::to_string(&provider).unwrap();
        // Accept either tag spelling, depending on the enum's serde naming.
        let tagged = ["anthropic_compatible", "AnthropicCompatible"]
            .iter()
            .any(|tag| encoded.contains(*tag));
        assert!(tagged);
    }

    #[test]
    fn test_build_anthropic_messages_url_official_api() {
        assert_messages_url(
            "https://api.anthropic.com",
            "https://api.anthropic.com/v1/messages",
        );
    }

    #[test]
    fn test_build_anthropic_messages_url_already_has_v1() {
        assert_messages_url(
            "https://api.anthropic.com/v1",
            "https://api.anthropic.com/v1/messages",
        );
    }

    #[test]
    fn test_build_anthropic_messages_url_already_has_messages() {
        assert_messages_url(
            "https://api.example.com/anthropic/messages",
            "https://api.example.com/anthropic/messages",
        );
    }

    #[test]
    fn test_build_anthropic_messages_url_tencent_endpoint() {
        assert_messages_url(
            "https://api.lkeap.cloud.tencent.com/coding/anthropic",
            "https://api.lkeap.cloud.tencent.com/coding/anthropic/messages",
        );
    }

    #[test]
    fn test_build_anthropic_messages_url_with_trailing_slash() {
        assert_messages_url(
            "https://api.anthropic.com/v1/",
            "https://api.anthropic.com/v1/messages",
        );
    }

    #[test]
    fn test_provider_routing_anthropic_compatible() {
        let provider = LlmProvider::AnthropicCompatible {
            base_url: "https://example.com".to_string(),
        };
        let routes_to_anthropic = matches!(
            provider,
            LlmProvider::Anthropic | LlmProvider::AnthropicCompatible { .. }
        );
        assert!(routes_to_anthropic);
    }

    #[test]
    fn test_provider_routing_openai_compatible() {
        let provider = LlmProvider::OpenAICompatible {
            base_url: "https://example.com".to_string(),
        };
        let routes_to_openai = matches!(
            provider,
            LlmProvider::OpenAI | LlmProvider::OpenAICompatible { .. }
        );
        assert!(routes_to_openai);
    }
}