Skip to main content

sh_layer1/
llm_client.rs

1//! LLM 客户端模块
2//!
3//! 统一的 LLM API 客户端,支持多提供商。
4//!
5//! [STABLE] 基础请求功能完整
6//! [STABLE] 流式响应支持 Anthropic/OpenAI 格式
7
8use anyhow::{anyhow, Result};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use tracing::{info, warn};
15
16use crate::streaming::{
17    CallbackStream, ContentDelta, MessageStream, OnChunkCallback, StreamEvent, StreamProvider,
18};
19
20/// LLM 提供商类型
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub enum LlmProvider {
23    Anthropic,
24    OpenAI,
25    Gemini,
26    AzureOpenAI,
27    Bedrock,
28    Ollama,
29    /// OpenAI-compatible provider with custom base_url (e.g. deepseek, glm, qwen, kimi, grok)
30    OpenAICompatible {
31        base_url: String,
32    },
33    /// Anthropic-compatible provider with custom base_url (e.g. tencent-coding, other Claude API proxies)
34    AnthropicCompatible {
35        base_url: String,
36    },
37    Custom(String),
38}
39
40/// LLM 请求配置
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct LlmRequestConfig {
43    /// 模型名称
44    pub model: String,
45    /// 最大 token 数
46    pub max_tokens: u32,
47    /// 温度参数
48    pub temperature: f32,
49    /// 系统提示
50    pub system_prompt: Option<String>,
51    /// 停止词
52    pub stop_sequences: Vec<String>,
53}
54
55impl Default for LlmRequestConfig {
56    fn default() -> Self {
57        Self {
58            model: "claude-sonnet-4-6".to_string(),
59            max_tokens: 4096,
60            temperature: 0.7,
61            system_prompt: None,
62            stop_sequences: vec!["\n\n\n".to_string()],
63        }
64    }
65}
66
67/// LLM 响应
68#[derive(Debug, Serialize, Deserialize)]
69pub struct LlmResponse {
70    /// 响应内容
71    pub content: String,
72    /// Token 使用情况
73    pub usage: TokenUsage,
74    /// 模型名称
75    pub model: String,
76    /// 响应 ID
77    pub response_id: String,
78}
79
80/// Token 使用情况
81#[derive(Debug, Serialize, Deserialize)]
82pub struct TokenUsage {
83    /// 输入 token 数
84    pub input_tokens: u32,
85    /// 输出 token 数
86    pub output_tokens: u32,
87}
88
89/// LLM 客户端 trait
90#[async_trait]
91pub trait LlmClientTrait {
92    /// 发送请求并获取响应
93    async fn send(&self, messages: Vec<Message>, config: &LlmRequestConfig) -> Result<LlmResponse>;
94
95    /// 发送请求并流式获取响应
96    async fn send_stream(
97        &self,
98        messages: Vec<Message>,
99        config: &LlmRequestConfig,
100    ) -> Result<MessageStream>;
101}
102
103/// 消息
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct Message {
106    pub role: MessageRole,
107    pub content: String,
108}
109
110/// 消息角色
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub enum MessageRole {
113    User,
114    Assistant,
115    System,
116}
117
118/// LLM 客户端实现
119pub struct LlmClient {
120    /// HTTP 客户端
121    client: Client,
122    /// API 密钥
123    api_key: String,
124    /// 提供商
125    provider: LlmProvider,
126    /// API 基础 URL
127    base_url: String,
128}
129
130impl LlmClient {
131    pub fn new(provider: LlmProvider, api_key: String) -> Self {
132        let base_url = match &provider {
133            LlmProvider::Anthropic => "https://api.anthropic.com/v1".to_string(),
134            LlmProvider::OpenAI => "https://api.openai.com/v1".to_string(),
135            LlmProvider::Gemini => "https://generativelanguage.googleapis.com/v1".to_string(),
136            LlmProvider::AzureOpenAI => "https://YOUR_RESOURCE.openai.azure.com".to_string(),
137            LlmProvider::Bedrock => "https://bedrock-runtime.us-east-1.amazonaws.com".to_string(),
138            LlmProvider::Ollama => "http://localhost:11434".to_string(),
139            LlmProvider::OpenAICompatible { base_url } => base_url.clone(),
140            LlmProvider::AnthropicCompatible { base_url } => base_url.clone(),
141            LlmProvider::Custom(url) => url.clone(),
142        };
143
144        Self {
145            client: Client::new(),
146            api_key,
147            provider,
148            base_url,
149        }
150    }
151
152    /// 创建客户端并指定自定义 base_url(覆盖 provider 默认值)
153    pub fn with_base_url(mut self, base_url: String) -> Self {
154        self.base_url = base_url;
155        self
156    }
157
158    /// 发送带回调的流式请求
159    pub async fn send_stream_with_callback(
160        &self,
161        messages: Vec<Message>,
162        config: &LlmRequestConfig,
163        on_chunk: OnChunkCallback,
164    ) -> Result<LlmResponse> {
165        let message_stream = self.send_stream(messages, config).await?;
166        let mut callback_stream = CallbackStream::new(message_stream, Some(on_chunk));
167
168        let mut content = String::new();
169        let mut input_tokens = 0u32;
170        let mut output_tokens = 0u32;
171        let mut message_id = String::new();
172        let mut model = config.model.clone();
173
174        while let Some(event) = callback_stream.next_event().await? {
175            match event {
176                StreamEvent::MessageStart { id, model: m } => {
177                    message_id = id;
178                    model = m;
179                }
180                StreamEvent::ContentBlockDelta {
181                    delta: ContentDelta::Text(t),
182                    ..
183                } => {
184                    content.push_str(&t);
185                }
186                StreamEvent::ContentBlockDelta { .. } => {}
187                StreamEvent::MessageDelta { usage, .. } => {
188                    input_tokens = usage.input_tokens;
189                    output_tokens = usage.output_tokens;
190                }
191                _ => {}
192            }
193        }
194
195        Ok(LlmResponse {
196            content,
197            usage: TokenUsage {
198                input_tokens,
199                output_tokens,
200            },
201            model,
202            response_id: message_id,
203        })
204    }
205
206    /// 发送可中断的流式请求
207    pub async fn send_stream_abortable(
208        &self,
209        messages: Vec<Message>,
210        config: &LlmRequestConfig,
211        abort_flag: Arc<AtomicBool>,
212    ) -> Result<LlmResponse> {
213        let message_stream = self.send_stream(messages, config).await?;
214        let mut callback_stream = CallbackStream::new(message_stream, None);
215
216        let mut content = String::new();
217        let mut input_tokens = 0u32;
218        let mut output_tokens = 0u32;
219        let mut message_id = String::new();
220        let mut model = config.model.clone();
221
222        while !abort_flag.load(Ordering::Relaxed) {
223            match callback_stream.next_event().await {
224                Ok(Some(event)) => match event {
225                    StreamEvent::MessageStart { id, model: m } => {
226                        message_id = id;
227                        model = m;
228                    }
229                    StreamEvent::ContentBlockDelta {
230                        delta: ContentDelta::Text(t),
231                        ..
232                    } => {
233                        content.push_str(&t);
234                    }
235                    StreamEvent::ContentBlockDelta { .. } => {}
236                    StreamEvent::MessageDelta { usage, .. } => {
237                        input_tokens = usage.input_tokens;
238                        output_tokens = usage.output_tokens;
239                    }
240                    StreamEvent::MessageStop => {
241                        break;
242                    }
243                    _ => {}
244                },
245                Ok(None) => break,
246                Err(e) => {
247                    if abort_flag.load(Ordering::Relaxed) {
248                        info!("Stream aborted by user");
249                        break;
250                    }
251                    return Err(e);
252                }
253            }
254        }
255
256        if abort_flag.load(Ordering::Relaxed) {
257            info!("Stream was aborted");
258        }
259
260        Ok(LlmResponse {
261            content,
262            usage: TokenUsage {
263                input_tokens,
264                output_tokens,
265            },
266            model,
267            response_id: message_id,
268        })
269    }
270
271    /// 带错误恢复的请求重试
272    pub async fn send_with_retry(
273        &self,
274        messages: Vec<Message>,
275        config: &LlmRequestConfig,
276        max_retries: u32,
277    ) -> Result<LlmResponse> {
278        let mut attempts = 0;
279        let mut last_error: Option<anyhow::Error> = None;
280
281        while attempts < max_retries {
282            attempts += 1;
283
284            match self.send(messages.clone(), config).await {
285                Ok(response) => {
286                    info!("LLM request succeeded after {} attempts", attempts);
287                    return Ok(response);
288                }
289                Err(e) => {
290                    let error_msg = e.to_string();
291
292                    if error_msg.contains("rate limit")
293                        || error_msg.contains("429")
294                        || error_msg.contains("overloaded")
295                        || error_msg.contains("timeout")
296                    {
297                        warn!(
298                            "LLM request failed (attempt {}/{}): {}",
299                            attempts, max_retries, e
300                        );
301                        last_error = Some(e);
302
303                        let delay = std::cmp::min(1000 * 2u64.pow(attempts - 1), 30000);
304                        tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
305                    } else {
306                        return Err(e);
307                    }
308                }
309            }
310        }
311
312        Err(last_error.unwrap_or_else(|| anyhow!("Max retries exceeded")))
313    }
314
315    /// 带错误恢复的流式请求重试
316    pub async fn send_stream_with_retry(
317        &self,
318        messages: Vec<Message>,
319        config: &LlmRequestConfig,
320        max_retries: u32,
321    ) -> Result<LlmResponse> {
322        let mut attempts = 0;
323        let mut last_error: Option<anyhow::Error> = None;
324
325        while attempts < max_retries {
326            attempts += 1;
327
328            match self
329                .send_stream_with_callback(messages.clone(), config, Box::new(|_| {}))
330                .await
331            {
332                Ok(response) => {
333                    info!("Stream request succeeded after {} attempts", attempts);
334                    return Ok(response);
335                }
336                Err(e) => {
337                    let error_msg = e.to_string();
338
339                    if error_msg.contains("rate limit")
340                        || error_msg.contains("429")
341                        || error_msg.contains("overloaded")
342                        || error_msg.contains("timeout")
343                        || error_msg.contains("aborted")
344                    {
345                        warn!(
346                            "Stream request failed (attempt {}/{}): {}",
347                            attempts, max_retries, e
348                        );
349                        last_error = Some(e);
350
351                        let delay = std::cmp::min(1000 * 2u64.pow(attempts - 1), 30000);
352                        tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
353                    } else {
354                        return Err(e);
355                    }
356                }
357            }
358        }
359
360        Err(last_error.unwrap_or_else(|| anyhow!("Max retries exceeded")))
361    }
362}
363
364#[async_trait]
365impl LlmClientTrait for LlmClient {
366    async fn send(&self, messages: Vec<Message>, config: &LlmRequestConfig) -> Result<LlmResponse> {
367        match self.provider {
368            LlmProvider::Anthropic | LlmProvider::AnthropicCompatible { .. } => {
369                self.send_anthropic(messages, config).await
370            }
371            LlmProvider::OpenAI | LlmProvider::OpenAICompatible { .. } => {
372                self.send_openai(messages, config).await
373            }
374            LlmProvider::Gemini => self.send_gemini(messages, config).await,
375            LlmProvider::AzureOpenAI => self.send_azure_openai(messages, config).await,
376            LlmProvider::Bedrock => self.send_bedrock(messages, config).await,
377            LlmProvider::Ollama => self.send_ollama(messages, config).await,
378            LlmProvider::Custom(_) => {
379                Err(anyhow!("Custom provider requires custom implementation. Use an OpenAI-compatible provider instead."))
380            }
381        }
382    }
383
384    async fn send_stream(
385        &self,
386        messages: Vec<Message>,
387        config: &LlmRequestConfig,
388    ) -> Result<MessageStream> {
389        match self.provider {
390            LlmProvider::Anthropic | LlmProvider::AnthropicCompatible { .. } => {
391                self.stream_anthropic(messages, config).await
392            }
393            LlmProvider::OpenAI | LlmProvider::OpenAICompatible { .. } => {
394                self.stream_openai(messages, config).await
395            }
396            LlmProvider::Gemini => self.stream_gemini(messages, config).await,
397            LlmProvider::AzureOpenAI => self.stream_azure_openai(messages, config).await,
398            LlmProvider::Bedrock => self.stream_bedrock(messages, config).await,
399            LlmProvider::Ollama => self.stream_ollama(messages, config).await,
400            LlmProvider::Custom(_) => Err(anyhow!("Custom provider does not support streaming. Use an OpenAI-compatible provider instead.")),
401        }
402    }
403}
404
405impl LlmClient {
406    /// Construct the messages endpoint URL for Anthropic API
407    ///
408    /// Handles three cases:
409    /// 1. Official Anthropic API: https://api.anthropic.com -> https://api.anthropic.com/v1/messages
410    /// 2. Already contains full path: https://api.example.com/anthropic/messages -> unchanged
411    /// 3. Anthropic-compatible endpoint (contains /anthropic): https://api.example.com/anthropic -> /messages
412    /// 4. Already contains v1: https://api.example.com/v1 -> https://api.example.com/v1/messages
413    pub fn build_anthropic_messages_url(base_url: &str) -> String {
414        let base = base_url.trim_end_matches('/');
415
416        // If URL already ends with /messages, return as-is
417        if base.ends_with("/messages") {
418            return base.to_string();
419        }
420
421        // If URL ends with /v1, just append /messages
422        if base.ends_with("/v1") {
423            return format!("{}/messages", base);
424        }
425
426        // If URL contains /anthropic (Anthropic-compatible endpoint), just append /messages
427        // This handles third-party Anthropic-compatible endpoints like Tencent Coding
428        if base.contains("/anthropic") {
429            return format!("{}/messages", base);
430        }
431
432        // Otherwise, append /v1/messages (official Anthropic API case)
433        format!("{}/v1/messages", base)
434    }
435
436    async fn send_anthropic(
437        &self,
438        messages: Vec<Message>,
439        config: &LlmRequestConfig,
440    ) -> Result<LlmResponse> {
441        let url = Self::build_anthropic_messages_url(&self.base_url);
442
443        let request_body = AnthropicRequest {
444            model: config.model.clone(),
445            max_tokens: config.max_tokens,
446            messages: messages
447                .into_iter()
448                .map(|m| AnthropicMessage {
449                    role: match m.role {
450                        MessageRole::User => "user",
451                        MessageRole::Assistant => "assistant",
452                        MessageRole::System => "system",
453                    },
454                    content: AnthropicContent::Text(m.content),
455                })
456                .collect(),
457            system: config.system_prompt.clone(),
458            temperature: config.temperature,
459        };
460
461        let response = self
462            .client
463            .post(&url)
464            .header("x-api-key", &self.api_key)
465            .header("anthropic-version", "2023-06-01")
466            .json(&request_body)
467            .send()
468            .await?;
469
470        let response_text = response.text().await?;
471        tracing::debug!("Anthropic API response: {}", response_text);
472
473        let response_body: AnthropicResponse = serde_json::from_str(&response_text)?;
474
475        Ok(LlmResponse {
476            content: response_body
477                .content
478                .first()
479                .map(|c| c.text.clone())
480                .unwrap_or_default(),
481            usage: TokenUsage {
482                input_tokens: response_body.usage.input_tokens,
483                output_tokens: response_body.usage.output_tokens,
484            },
485            model: response_body.model,
486            response_id: response_body.id,
487        })
488    }
489
490    async fn stream_anthropic(
491        &self,
492        messages: Vec<Message>,
493        config: &LlmRequestConfig,
494    ) -> Result<MessageStream> {
495        let url = Self::build_anthropic_messages_url(&self.base_url);
496
497        let request_body = AnthropicStreamRequest {
498            model: config.model.clone(),
499            max_tokens: config.max_tokens,
500            messages: messages
501                .into_iter()
502                .map(|m| AnthropicMessage {
503                    role: match m.role {
504                        MessageRole::User => "user",
505                        MessageRole::Assistant => "assistant",
506                        MessageRole::System => "system",
507                    },
508                    content: AnthropicContent::Text(m.content),
509                })
510                .collect(),
511            system: config.system_prompt.clone(),
512            temperature: config.temperature,
513            stream: true,
514        };
515
516        let response = self
517            .client
518            .post(&url)
519            .header("x-api-key", &self.api_key)
520            .header("anthropic-version", "2023-06-01")
521            .header("Accept", "text/event-stream")
522            .json(&request_body)
523            .send()
524            .await?;
525
526        let status = response.status();
527        if !status.is_success() {
528            let error_text = response.text().await?;
529            return Err(anyhow!("Anthropic API error {}: {}", status, error_text));
530        }
531
532        Ok(MessageStream::new(
533            response,
534            match self.provider {
535                LlmProvider::Anthropic => StreamProvider::Anthropic,
536                LlmProvider::AnthropicCompatible { .. } => StreamProvider::AnthropicCompatible,
537                _ => StreamProvider::Anthropic, // fallback
538            },
539            config.model.clone(),
540        ))
541    }
542
543    async fn send_openai(
544        &self,
545        messages: Vec<Message>,
546        config: &LlmRequestConfig,
547    ) -> Result<LlmResponse> {
548        let url = format!("{}/chat/completions", self.base_url);
549
550        let mut openai_messages: Vec<OpenAiMessage> = Vec::new();
551
552        if let Some(ref system) = config.system_prompt {
553            openai_messages.push(OpenAiMessage {
554                role: "system",
555                content: system.clone(),
556            });
557        }
558
559        for m in messages {
560            openai_messages.push(OpenAiMessage {
561                role: match m.role {
562                    MessageRole::User => "user",
563                    MessageRole::Assistant => "assistant",
564                    MessageRole::System => "system",
565                },
566                content: m.content,
567            });
568        }
569
570        let request_body = OpenAiRequest {
571            model: config.model.clone(),
572            messages: openai_messages,
573            max_tokens: Some(config.max_tokens),
574            temperature: Some(config.temperature),
575            stop: if config.stop_sequences.is_empty() {
576                None
577            } else {
578                Some(config.stop_sequences.clone())
579            },
580        };
581
582        let response = self
583            .client
584            .post(&url)
585            .header("Authorization", format!("Bearer {}", self.api_key))
586            .json(&request_body)
587            .send()
588            .await?;
589
590        let response_body: OpenAiResponse = response.json().await?;
591
592        let choice = response_body
593            .choices
594            .first()
595            .ok_or_else(|| anyhow!("No response choices"))?;
596
597        Ok(LlmResponse {
598            content: choice.message.content.clone(),
599            usage: TokenUsage {
600                input_tokens: response_body.usage.prompt_tokens,
601                output_tokens: response_body.usage.completion_tokens,
602            },
603            model: response_body.model,
604            response_id: response_body.id,
605        })
606    }
607
608    async fn stream_openai(
609        &self,
610        messages: Vec<Message>,
611        config: &LlmRequestConfig,
612    ) -> Result<MessageStream> {
613        let url = format!("{}/chat/completions", self.base_url);
614
615        let mut openai_messages: Vec<OpenAiMessage> = Vec::new();
616        if let Some(ref system) = config.system_prompt {
617            openai_messages.push(OpenAiMessage {
618                role: "system",
619                content: system.clone(),
620            });
621        }
622        for m in messages {
623            openai_messages.push(OpenAiMessage {
624                role: match m.role {
625                    MessageRole::User => "user",
626                    MessageRole::Assistant => "assistant",
627                    MessageRole::System => "system",
628                },
629                content: m.content,
630            });
631        }
632
633        let request_body = OpenAiStreamRequest {
634            model: config.model.clone(),
635            messages: openai_messages,
636            max_tokens: Some(config.max_tokens),
637            temperature: Some(config.temperature),
638            stream: true,
639        };
640
641        let response = self
642            .client
643            .post(&url)
644            .header("Authorization", format!("Bearer {}", self.api_key))
645            .header("Accept", "text/event-stream")
646            .json(&request_body)
647            .send()
648            .await?;
649
650        let status = response.status();
651        if !status.is_success() {
652            let error_text = response.text().await?;
653            return Err(anyhow!("OpenAI API error {}: {}", status, error_text));
654        }
655
656        Ok(MessageStream::new(
657            response,
658            match self.provider {
659                LlmProvider::OpenAI => StreamProvider::OpenAI,
660                LlmProvider::OpenAICompatible { .. } => StreamProvider::OpenAICompatible,
661                _ => StreamProvider::OpenAI, // fallback
662            },
663            config.model.clone(),
664        ))
665    }
666
667    async fn send_gemini(
668        &self,
669        messages: Vec<Message>,
670        config: &LlmRequestConfig,
671    ) -> Result<LlmResponse> {
672        let url = format!(
673            "{}/models/{}:generateContent?key={}",
674            self.base_url, config.model, self.api_key
675        );
676
677        let mut contents: Vec<GeminiContent> = Vec::new();
678        let system_instruction = config.system_prompt.clone();
679
680        for m in messages {
681            contents.push(GeminiContent {
682                role: match m.role {
683                    MessageRole::User => "user".to_string(),
684                    MessageRole::Assistant => "model".to_string(),
685                    MessageRole::System => "user".to_string(),
686                },
687                parts: vec![GeminiPart { text: m.content }],
688            });
689        }
690
691        let request_body = GeminiRequest {
692            contents,
693            generation_config: Some(GeminiGenerationConfig {
694                max_output_tokens: Some(config.max_tokens),
695                temperature: Some(config.temperature),
696                stop_sequences: if config.stop_sequences.is_empty() {
697                    None
698                } else {
699                    Some(config.stop_sequences.clone())
700                },
701            }),
702            system_instruction: system_instruction.map(|s| GeminiSystemInstruction {
703                parts: vec![GeminiPart { text: s }],
704            }),
705        };
706
707        let response = self.client.post(&url).json(&request_body).send().await?;
708
709        let response_body: GeminiResponse = response.json().await?;
710
711        let candidate = response_body
712            .candidates
713            .first()
714            .ok_or_else(|| anyhow!("No response candidates"))?;
715
716        let content = candidate
717            .content
718            .parts
719            .first()
720            .map(|p| p.text.clone())
721            .unwrap_or_default();
722
723        Ok(LlmResponse {
724            content,
725            usage: TokenUsage {
726                input_tokens: response_body.usage_metadata.prompt_token_count.unwrap_or(0),
727                output_tokens: response_body
728                    .usage_metadata
729                    .candidates_token_count
730                    .unwrap_or(0),
731            },
732            model: config.model.clone(),
733            response_id: "".to_string(),
734        })
735    }
736
737    async fn stream_gemini(
738        &self,
739        messages: Vec<Message>,
740        config: &LlmRequestConfig,
741    ) -> Result<MessageStream> {
742        let url = format!(
743            "{}/models/{}:streamGenerateContent?key={}&alt=sse",
744            self.base_url, config.model, self.api_key
745        );
746
747        let mut contents: Vec<GeminiContent> = Vec::new();
748        let system_instruction = config.system_prompt.clone();
749
750        for m in messages {
751            contents.push(GeminiContent {
752                role: match m.role {
753                    MessageRole::User => "user".to_string(),
754                    MessageRole::Assistant => "model".to_string(),
755                    MessageRole::System => "user".to_string(),
756                },
757                parts: vec![GeminiPart { text: m.content }],
758            });
759        }
760
761        let request_body = GeminiRequest {
762            contents,
763            generation_config: Some(GeminiGenerationConfig {
764                max_output_tokens: Some(config.max_tokens),
765                temperature: Some(config.temperature),
766                stop_sequences: if config.stop_sequences.is_empty() {
767                    None
768                } else {
769                    Some(config.stop_sequences.clone())
770                },
771            }),
772            system_instruction: system_instruction.map(|s| GeminiSystemInstruction {
773                parts: vec![GeminiPart { text: s }],
774            }),
775        };
776
777        let response = self.client.post(&url).json(&request_body).send().await?;
778
779        let status = response.status();
780        if !status.is_success() {
781            let error_text = response.text().await?;
782            return Err(anyhow!("Gemini API error {}: {}", status, error_text));
783        }
784
785        Ok(MessageStream::new(
786            response,
787            StreamProvider::Gemini,
788            config.model.clone(),
789        ))
790    }
791
792    // ========================================================================
793    // Azure OpenAI 实现
794    // ========================================================================
795
796    async fn send_azure_openai(
797        &self,
798        messages: Vec<Message>,
799        config: &LlmRequestConfig,
800    ) -> Result<LlmResponse> {
801        // Azure OpenAI 使用 deployment name 而非 model name
802        // URL format: {base_url}/openai/deployments/{deployment}/chat/completions?api-version=2024-02-15-preview
803        let deployment = &config.model;
804        let url = format!(
805            "{}/openai/deployments/{}/chat/completions?api-version=2024-02-15-preview",
806            self.base_url, deployment
807        );
808
809        let mut azure_messages: Vec<OpenAiMessage> = Vec::new();
810        if let Some(ref system) = config.system_prompt {
811            azure_messages.push(OpenAiMessage {
812                role: "system",
813                content: system.clone(),
814            });
815        }
816        for m in messages {
817            azure_messages.push(OpenAiMessage {
818                role: match m.role {
819                    MessageRole::User => "user",
820                    MessageRole::Assistant => "assistant",
821                    MessageRole::System => "system",
822                },
823                content: m.content,
824            });
825        }
826
827        let request_body = OpenAiRequest {
828            model: deployment.clone(), // Azure 使用 deployment name
829            messages: azure_messages,
830            max_tokens: Some(config.max_tokens),
831            temperature: Some(config.temperature),
832            stop: if config.stop_sequences.is_empty() {
833                None
834            } else {
835                Some(config.stop_sequences.clone())
836            },
837        };
838
839        let response = self
840            .client
841            .post(&url)
842            .header("api-key", &self.api_key) // Azure 使用 api-key header 而非 Authorization
843            .json(&request_body)
844            .send()
845            .await?;
846
847        let status = response.status();
848        if !status.is_success() {
849            let error_text = response.text().await?;
850            return Err(anyhow!("Azure OpenAI API error {}: {}", status, error_text));
851        }
852
853        let response_body: OpenAiResponse = response.json().await?;
854
855        let choice = response_body
856            .choices
857            .first()
858            .ok_or_else(|| anyhow!("No response choices"))?;
859
860        Ok(LlmResponse {
861            content: choice.message.content.clone(),
862            usage: TokenUsage {
863                input_tokens: response_body.usage.prompt_tokens,
864                output_tokens: response_body.usage.completion_tokens,
865            },
866            model: response_body.model,
867            response_id: response_body.id,
868        })
869    }
870
871    async fn stream_azure_openai(
872        &self,
873        messages: Vec<Message>,
874        config: &LlmRequestConfig,
875    ) -> Result<MessageStream> {
876        let deployment = &config.model;
877        let url = format!(
878            "{}/openai/deployments/{}/chat/completions?api-version=2024-02-15-preview",
879            self.base_url, deployment
880        );
881
882        let mut azure_messages: Vec<OpenAiMessage> = Vec::new();
883        if let Some(ref system) = config.system_prompt {
884            azure_messages.push(OpenAiMessage {
885                role: "system",
886                content: system.clone(),
887            });
888        }
889        for m in messages {
890            azure_messages.push(OpenAiMessage {
891                role: match m.role {
892                    MessageRole::User => "user",
893                    MessageRole::Assistant => "assistant",
894                    MessageRole::System => "system",
895                },
896                content: m.content,
897            });
898        }
899
900        let request_body = OpenAiStreamRequest {
901            model: deployment.clone(),
902            messages: azure_messages,
903            max_tokens: Some(config.max_tokens),
904            temperature: Some(config.temperature),
905            stream: true,
906        };
907
908        let response = self
909            .client
910            .post(&url)
911            .header("api-key", &self.api_key)
912            .header("Accept", "text/event-stream")
913            .json(&request_body)
914            .send()
915            .await?;
916
917        let status = response.status();
918        if !status.is_success() {
919            let error_text = response.text().await?;
920            return Err(anyhow!("Azure OpenAI API error {}: {}", status, error_text));
921        }
922
923        Ok(MessageStream::new(
924            response,
925            StreamProvider::AzureOpenAI,
926            config.model.clone(),
927        ))
928    }
929
930    // ========================================================================
931    // AWS Bedrock 实现
932    // ========================================================================
933
934    async fn send_bedrock(
935        &self,
936        messages: Vec<Message>,
937        config: &LlmRequestConfig,
938    ) -> Result<LlmResponse> {
939        // Bedrock URL: {base_url}/model/{model_id}/invoke
940        // 需要 AWS SigV4 签名认证(简化版使用 API key 作为临时方案)
941        let model_id = &config.model;
942        let url = format!("{}/model/{}/invoke", self.base_url, model_id);
943
944        // Bedrock 请求格式因模型不同而异,这里使用通用的 Converse API 格式
945        let mut bedrock_messages: Vec<BedrockMessage> = Vec::new();
946        for m in messages {
947            bedrock_messages.push(BedrockMessage {
948                role: match m.role {
949                    MessageRole::User => "user",
950                    MessageRole::Assistant => "assistant",
951                    MessageRole::System => "system",
952                },
953                content: vec![BedrockContent { text: m.content }],
954            });
955        }
956
957        let request_body = BedrockRequest {
958            messages: bedrock_messages,
959            system: config.system_prompt.clone(),
960            inference_config: Some(BedrockInferenceConfig {
961                max_tokens: config.max_tokens,
962                temperature: config.temperature,
963                top_p: None,
964                stop_sequences: if config.stop_sequences.is_empty() {
965                    None
966                } else {
967                    Some(config.stop_sequences.clone())
968                },
969            }),
970        };
971
972        let response = self
973            .client
974            .post(&url)
975            .header("Authorization", format!("Bearer {}", self.api_key))
976            .header("Content-Type", "application/json")
977            .json(&request_body)
978            .send()
979            .await?;
980
981        let status = response.status();
982        if !status.is_success() {
983            let error_text = response.text().await?;
984            return Err(anyhow!("Bedrock API error {}: {}", status, error_text));
985        }
986
987        let response_body: BedrockResponse = response.json().await?;
988
989        let content = response_body
990            .output
991            .message
992            .content
993            .first()
994            .map(|c| c.text.clone())
995            .unwrap_or_default();
996
997        Ok(LlmResponse {
998            content,
999            usage: TokenUsage {
1000                input_tokens: response_body.usage.input_tokens,
1001                output_tokens: response_body.usage.output_tokens,
1002            },
1003            model: config.model.clone(),
1004            response_id: response_body.request_id.unwrap_or_default(),
1005        })
1006    }
1007
1008    async fn stream_bedrock(
1009        &self,
1010        messages: Vec<Message>,
1011        config: &LlmRequestConfig,
1012    ) -> Result<MessageStream> {
1013        let model_id = &config.model;
1014        let url = format!(
1015            "{}/model/{}/invoke-with-response-stream",
1016            self.base_url, model_id
1017        );
1018
1019        let mut bedrock_messages: Vec<BedrockMessage> = Vec::new();
1020        for m in messages {
1021            bedrock_messages.push(BedrockMessage {
1022                role: match m.role {
1023                    MessageRole::User => "user",
1024                    MessageRole::Assistant => "assistant",
1025                    MessageRole::System => "system",
1026                },
1027                content: vec![BedrockContent { text: m.content }],
1028            });
1029        }
1030
1031        let request_body = BedrockRequest {
1032            messages: bedrock_messages,
1033            system: config.system_prompt.clone(),
1034            inference_config: Some(BedrockInferenceConfig {
1035                max_tokens: config.max_tokens,
1036                temperature: config.temperature,
1037                top_p: None,
1038                stop_sequences: if config.stop_sequences.is_empty() {
1039                    None
1040                } else {
1041                    Some(config.stop_sequences.clone())
1042                },
1043            }),
1044        };
1045
1046        let response = self
1047            .client
1048            .post(&url)
1049            .header("Authorization", format!("Bearer {}", self.api_key))
1050            .header("Accept", "text/event-stream")
1051            .header("Content-Type", "application/json")
1052            .json(&request_body)
1053            .send()
1054            .await?;
1055
1056        let status = response.status();
1057        if !status.is_success() {
1058            let error_text = response.text().await?;
1059            return Err(anyhow!("Bedrock API error {}: {}", status, error_text));
1060        }
1061
1062        Ok(MessageStream::new(
1063            response,
1064            StreamProvider::Bedrock,
1065            config.model.clone(),
1066        ))
1067    }
1068
1069    // ========================================================================
1070    // Ollama (本地) 实现
1071    // ========================================================================
1072
1073    async fn send_ollama(
1074        &self,
1075        messages: Vec<Message>,
1076        config: &LlmRequestConfig,
1077    ) -> Result<LlmResponse> {
1078        // Ollama API: POST /api/chat 或 /api/generate
1079        let url = format!("{}/api/chat", self.base_url);
1080
1081        let mut ollama_messages: Vec<OllamaMessage> = Vec::new();
1082        if let Some(ref system) = config.system_prompt {
1083            ollama_messages.push(OllamaMessage {
1084                role: "system",
1085                content: system.clone(),
1086            });
1087        }
1088        for m in messages {
1089            ollama_messages.push(OllamaMessage {
1090                role: match m.role {
1091                    MessageRole::User => "user",
1092                    MessageRole::Assistant => "assistant",
1093                    MessageRole::System => "system",
1094                },
1095                content: m.content,
1096            });
1097        }
1098
1099        let request_body = OllamaChatRequest {
1100            model: config.model.clone(),
1101            messages: ollama_messages,
1102            stream: false,
1103            options: Some(OllamaOptions {
1104                num_predict: config.max_tokens as i32,
1105                temperature: config.temperature,
1106                stop: if config.stop_sequences.is_empty() {
1107                    None
1108                } else {
1109                    Some(config.stop_sequences.clone())
1110                },
1111            }),
1112        };
1113
1114        // Ollama 本地运行,通常无需 API key
1115        let response = self
1116            .client
1117            .post(&url)
1118            .header("Content-Type", "application/json")
1119            .json(&request_body)
1120            .send()
1121            .await?;
1122
1123        let status = response.status();
1124        if !status.is_success() {
1125            let error_text = response.text().await?;
1126            return Err(anyhow!("Ollama API error {}: {}", status, error_text));
1127        }
1128
1129        let response_body: OllamaChatResponse = response.json().await?;
1130
1131        Ok(LlmResponse {
1132            content: response_body.message.content,
1133            usage: TokenUsage {
1134                input_tokens: response_body.prompt_eval_count.unwrap_or(0),
1135                output_tokens: response_body.eval_count.unwrap_or(0),
1136            },
1137            model: response_body.model,
1138            response_id: "".to_string(),
1139        })
1140    }
1141
1142    async fn stream_ollama(
1143        &self,
1144        messages: Vec<Message>,
1145        config: &LlmRequestConfig,
1146    ) -> Result<MessageStream> {
1147        let url = format!("{}/api/chat", self.base_url);
1148
1149        let mut ollama_messages: Vec<OllamaMessage> = Vec::new();
1150        if let Some(ref system) = config.system_prompt {
1151            ollama_messages.push(OllamaMessage {
1152                role: "system",
1153                content: system.clone(),
1154            });
1155        }
1156        for m in messages {
1157            ollama_messages.push(OllamaMessage {
1158                role: match m.role {
1159                    MessageRole::User => "user",
1160                    MessageRole::Assistant => "assistant",
1161                    MessageRole::System => "system",
1162                },
1163                content: m.content,
1164            });
1165        }
1166
1167        let request_body = OllamaChatRequest {
1168            model: config.model.clone(),
1169            messages: ollama_messages,
1170            stream: true,
1171            options: Some(OllamaOptions {
1172                num_predict: config.max_tokens as i32,
1173                temperature: config.temperature,
1174                stop: if config.stop_sequences.is_empty() {
1175                    None
1176                } else {
1177                    Some(config.stop_sequences.clone())
1178                },
1179            }),
1180        };
1181
1182        let response = self
1183            .client
1184            .post(&url)
1185            .header("Accept", "application/json")
1186            .header("Content-Type", "application/json")
1187            .json(&request_body)
1188            .send()
1189            .await?;
1190
1191        let status = response.status();
1192        if !status.is_success() {
1193            let error_text = response.text().await?;
1194            return Err(anyhow!("Ollama API error {}: {}", status, error_text));
1195        }
1196
1197        Ok(MessageStream::new(
1198            response,
1199            StreamProvider::Ollama,
1200            config.model.clone(),
1201        ))
1202    }
1203}
1204
1205// Anthropic API 结构
1206#[derive(Serialize)]
1207struct AnthropicRequest {
1208    model: String,
1209    max_tokens: u32,
1210    messages: Vec<AnthropicMessage>,
1211    system: Option<String>,
1212    temperature: f32,
1213}
1214
1215#[derive(Serialize)]
1216struct AnthropicStreamRequest {
1217    model: String,
1218    max_tokens: u32,
1219    messages: Vec<AnthropicMessage>,
1220    system: Option<String>,
1221    temperature: f32,
1222    stream: bool,
1223}
1224
1225#[derive(Serialize)]
1226struct AnthropicMessage {
1227    role: &'static str,
1228    content: AnthropicContent,
1229}
1230
1231#[derive(Serialize)]
1232#[serde(untagged)]
1233#[allow(dead_code)]
1234enum AnthropicContent {
1235    Text(String),
1236    Blocks(Vec<AnthropicContentBlock>),
1237}
1238
1239#[derive(Serialize)]
1240struct AnthropicContentBlock {
1241    #[serde(rename = "type")]
1242    content_type: String,
1243    text: String,
1244}
1245
1246#[derive(Deserialize)]
1247#[allow(dead_code)]
1248struct AnthropicResponse {
1249    #[serde(default)]
1250    id: String,
1251    #[serde(default)]
1252    model: String,
1253    #[serde(default)]
1254    content: Vec<AnthropicContentResponse>,
1255    #[serde(default)]
1256    usage: AnthropicUsage,
1257    #[serde(default)]
1258    #[serde(rename = "type")]
1259    response_type: Option<String>,
1260    #[serde(default)]
1261    role: Option<String>,
1262    #[serde(default)]
1263    stop_reason: Option<String>,
1264}
1265
1266#[derive(Deserialize)]
1267#[allow(dead_code)]
1268struct AnthropicContentResponse {
1269    #[serde(rename = "type", default)]
1270    content_type: String,
1271    #[serde(default)]
1272    text: String,
1273}
1274
1275#[derive(Deserialize, Default)]
1276struct AnthropicUsage {
1277    #[serde(default)]
1278    input_tokens: u32,
1279    #[serde(default)]
1280    output_tokens: u32,
1281}
1282
1283// OpenAI API 结构
1284#[derive(Serialize)]
1285struct OpenAiRequest {
1286    model: String,
1287    messages: Vec<OpenAiMessage>,
1288    #[serde(skip_serializing_if = "Option::is_none")]
1289    max_tokens: Option<u32>,
1290    #[serde(skip_serializing_if = "Option::is_none")]
1291    temperature: Option<f32>,
1292    #[serde(skip_serializing_if = "Option::is_none")]
1293    stop: Option<Vec<String>>,
1294}
1295
1296#[derive(Serialize)]
1297struct OpenAiStreamRequest {
1298    model: String,
1299    messages: Vec<OpenAiMessage>,
1300    #[serde(skip_serializing_if = "Option::is_none")]
1301    max_tokens: Option<u32>,
1302    #[serde(skip_serializing_if = "Option::is_none")]
1303    temperature: Option<f32>,
1304    stream: bool,
1305}
1306
1307#[derive(Serialize)]
1308struct OpenAiMessage {
1309    role: &'static str,
1310    content: String,
1311}
1312
1313#[derive(Deserialize)]
1314struct OpenAiResponse {
1315    id: String,
1316    model: String,
1317    choices: Vec<OpenAiChoice>,
1318    usage: OpenAiUsage,
1319}
1320
1321#[derive(Deserialize)]
1322#[allow(dead_code)]
1323struct OpenAiChoice {
1324    message: OpenAiResponseMessage,
1325    finish_reason: String,
1326}
1327
1328#[derive(Deserialize)]
1329#[allow(dead_code)]
1330struct OpenAiResponseMessage {
1331    role: String,
1332    content: String,
1333}
1334
1335#[derive(Deserialize)]
1336#[allow(dead_code)]
1337struct OpenAiUsage {
1338    prompt_tokens: u32,
1339    completion_tokens: u32,
1340    total_tokens: u32,
1341}
1342
1343// Gemini API 结构
1344#[derive(Serialize)]
1345struct GeminiRequest {
1346    contents: Vec<GeminiContent>,
1347    #[serde(skip_serializing_if = "Option::is_none")]
1348    generation_config: Option<GeminiGenerationConfig>,
1349    #[serde(skip_serializing_if = "Option::is_none")]
1350    system_instruction: Option<GeminiSystemInstruction>,
1351}
1352
1353#[derive(Serialize)]
1354struct GeminiContent {
1355    role: String,
1356    parts: Vec<GeminiPart>,
1357}
1358
1359#[derive(Serialize)]
1360struct GeminiPart {
1361    text: String,
1362}
1363
1364#[derive(Serialize)]
1365struct GeminiGenerationConfig {
1366    #[serde(skip_serializing_if = "Option::is_none")]
1367    max_output_tokens: Option<u32>,
1368    #[serde(skip_serializing_if = "Option::is_none")]
1369    temperature: Option<f32>,
1370    #[serde(skip_serializing_if = "Option::is_none")]
1371    stop_sequences: Option<Vec<String>>,
1372}
1373
1374#[derive(Serialize)]
1375struct GeminiSystemInstruction {
1376    parts: Vec<GeminiPart>,
1377}
1378
1379#[derive(Deserialize)]
1380struct GeminiResponse {
1381    candidates: Vec<GeminiCandidate>,
1382    usage_metadata: GeminiUsageMetadata,
1383}
1384
1385#[derive(Deserialize)]
1386#[allow(dead_code)]
1387struct GeminiCandidate {
1388    content: GeminiContentResponse,
1389    finish_reason: String,
1390}
1391
1392#[derive(Deserialize)]
1393#[allow(dead_code)]
1394struct GeminiContentResponse {
1395    parts: Vec<GeminiPartResponse>,
1396    role: String,
1397}
1398
1399#[derive(Deserialize)]
1400struct GeminiPartResponse {
1401    text: String,
1402}
1403
1404#[derive(Deserialize)]
1405#[allow(dead_code)]
1406struct GeminiUsageMetadata {
1407    prompt_token_count: Option<u32>,
1408    candidates_token_count: Option<u32>,
1409    total_token_count: Option<u32>,
1410}
1411
1412// ========================================================================
1413// AWS Bedrock API 结构
1414// ========================================================================
1415
1416#[derive(Serialize)]
1417struct BedrockRequest {
1418    messages: Vec<BedrockMessage>,
1419    #[serde(skip_serializing_if = "Option::is_none")]
1420    system: Option<String>,
1421    #[serde(skip_serializing_if = "Option::is_none")]
1422    inference_config: Option<BedrockInferenceConfig>,
1423}
1424
1425#[derive(Serialize)]
1426struct BedrockMessage {
1427    role: &'static str,
1428    content: Vec<BedrockContent>,
1429}
1430
1431#[derive(Serialize)]
1432struct BedrockContent {
1433    text: String,
1434}
1435
1436#[derive(Serialize)]
1437struct BedrockInferenceConfig {
1438    #[serde(rename = "maxTokens")]
1439    max_tokens: u32,
1440    temperature: f32,
1441    #[serde(skip_serializing_if = "Option::is_none")]
1442    top_p: Option<f32>,
1443    #[serde(skip_serializing_if = "Option::is_none")]
1444    stop_sequences: Option<Vec<String>>,
1445}
1446
1447#[derive(Deserialize)]
1448#[allow(dead_code)]
1449struct BedrockResponse {
1450    output: BedrockOutput,
1451    usage: BedrockUsage,
1452    #[serde(default)]
1453    request_id: Option<String>,
1454}
1455
1456#[derive(Deserialize)]
1457struct BedrockOutput {
1458    message: BedrockResponseMessage,
1459}
1460
1461#[derive(Deserialize)]
1462struct BedrockResponseMessage {
1463    content: Vec<BedrockResponseContent>,
1464}
1465
1466#[derive(Deserialize)]
1467struct BedrockResponseContent {
1468    text: String,
1469}
1470
1471#[derive(Deserialize)]
1472struct BedrockUsage {
1473    #[serde(default)]
1474    input_tokens: u32,
1475    #[serde(default)]
1476    output_tokens: u32,
1477}
1478
1479// ========================================================================
1480// Ollama API 结构
1481// ========================================================================
1482
1483#[derive(Serialize)]
1484struct OllamaChatRequest {
1485    model: String,
1486    messages: Vec<OllamaMessage>,
1487    stream: bool,
1488    #[serde(skip_serializing_if = "Option::is_none")]
1489    options: Option<OllamaOptions>,
1490}
1491
1492#[derive(Serialize)]
1493struct OllamaMessage {
1494    role: &'static str,
1495    content: String,
1496}
1497
1498#[derive(Serialize)]
1499struct OllamaOptions {
1500    num_predict: i32,
1501    temperature: f32,
1502    #[serde(skip_serializing_if = "Option::is_none")]
1503    stop: Option<Vec<String>>,
1504}
1505
1506#[derive(Deserialize)]
1507struct OllamaChatResponse {
1508    model: String,
1509    message: OllamaResponseMessage,
1510    #[serde(default)]
1511    prompt_eval_count: Option<u32>,
1512    #[serde(default)]
1513    eval_count: Option<u32>,
1514}
1515
1516#[derive(Deserialize)]
1517struct OllamaResponseMessage {
1518    content: String,
1519}
1520
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = LlmRequestConfig::default();
        assert_eq!(config.model, "claude-sonnet-4-6");
        assert_eq!(config.max_tokens, 4096);
        assert!(config.system_prompt.is_none());
        assert_eq!(config.stop_sequences, vec!["\n\n\n".to_string()]);
    }

    #[test]
    fn test_client_creation() {
        let client = LlmClient::new(LlmProvider::Anthropic, "test_key".to_string());
        assert_eq!(client.base_url, "https://api.anthropic.com/v1");
    }

    #[test]
    fn test_openai_client_creation() {
        let client = LlmClient::new(LlmProvider::OpenAI, "test_key".to_string());
        assert_eq!(client.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_gemini_client_creation() {
        let client = LlmClient::new(LlmProvider::Gemini, "test_key".to_string());
        assert_eq!(
            client.base_url,
            "https://generativelanguage.googleapis.com/v1"
        );
    }

    #[test]
    fn test_custom_provider() {
        let client = LlmClient::new(
            LlmProvider::Custom("https://custom.api.com/v1".to_string()),
            "test_key".to_string(),
        );
        assert_eq!(client.base_url, "https://custom.api.com/v1");
    }

    #[test]
    fn test_openai_compatible_provider() {
        let client = LlmClient::new(
            LlmProvider::OpenAICompatible {
                base_url: "https://api.deepseek.com/v1".to_string(),
            },
            "test_key".to_string(),
        );
        assert_eq!(client.base_url, "https://api.deepseek.com/v1");
    }

    #[test]
    fn test_azure_openai_client_creation() {
        let client = LlmClient::new(LlmProvider::AzureOpenAI, "test_key".to_string());
        assert!(client.base_url.contains("openai.azure.com"));
    }

    #[test]
    fn test_bedrock_client_creation() {
        let client = LlmClient::new(LlmProvider::Bedrock, "test_key".to_string());
        assert!(client.base_url.contains("bedrock-runtime"));
    }

    #[test]
    fn test_ollama_client_creation() {
        let client = LlmClient::new(LlmProvider::Ollama, "".to_string());
        assert_eq!(client.base_url, "http://localhost:11434");
    }

    #[test]
    fn test_azure_openai_with_custom_url() {
        let client = LlmClient::new(LlmProvider::AzureOpenAI, "test_key".to_string())
            .with_base_url("https://myresource.openai.azure.com".to_string());
        assert_eq!(client.base_url, "https://myresource.openai.azure.com");
    }

    #[test]
    fn test_ollama_with_custom_url() {
        let client = LlmClient::new(LlmProvider::Ollama, "".to_string())
            .with_base_url("http://192.168.1.100:11434".to_string());
        assert_eq!(client.base_url, "http://192.168.1.100:11434");
    }

    #[test]
    fn test_message_creation() {
        let message = Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
        };
        assert_eq!(message.content, "Hello");
    }

    #[test]
    fn test_config_with_system_prompt() {
        let config = LlmRequestConfig {
            model: "gpt-4".to_string(),
            max_tokens: 8192,
            temperature: 0.5,
            system_prompt: Some("You are a helpful assistant".to_string()),
            stop_sequences: vec![],
        };
        assert_eq!(config.model, "gpt-4");
        assert!(config.system_prompt.is_some());
    }

    #[test]
    fn test_llm_response_creation() {
        let response = LlmResponse {
            content: "Hello".to_string(),
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 5,
            },
            model: "gpt-4".to_string(),
            response_id: "resp_123".to_string(),
        };
        assert_eq!(response.content, "Hello");
        assert_eq!(response.usage.input_tokens, 10);
    }

    #[test]
    fn test_provider_serialization() {
        let provider = LlmProvider::Anthropic;
        let json = serde_json::to_string(&provider).unwrap();
        assert!(json.contains("Anthropic"));
    }

    #[test]
    fn test_message_role_serialization() {
        let role = MessageRole::User;
        let json = serde_json::to_string(&role).unwrap();
        assert!(json.contains("User"));
    }

    // AnthropicCompatible provider tests
    #[test]
    fn test_anthropic_compatible_provider_creation() {
        let client = LlmClient::new(
            LlmProvider::AnthropicCompatible {
                base_url: "https://api.lkeap.cloud.tencent.com/coding/anthropic".to_string(),
            },
            "test_key".to_string(),
        );
        assert_eq!(
            client.base_url,
            "https://api.lkeap.cloud.tencent.com/coding/anthropic"
        );
    }

    #[test]
    fn test_anthropic_compatible_provider_serialization() {
        let provider = LlmProvider::AnthropicCompatible {
            base_url: "https://example.com".to_string(),
        };
        let json = serde_json::to_string(&provider).unwrap();
        assert!(json.contains("anthropic_compatible") || json.contains("AnthropicCompatible"));
    }

    // URL construction tests for build_anthropic_messages_url
    #[test]
    fn test_build_anthropic_messages_url_official_api() {
        let url = LlmClient::build_anthropic_messages_url("https://api.anthropic.com");
        assert_eq!(url, "https://api.anthropic.com/v1/messages");
    }

    #[test]
    fn test_build_anthropic_messages_url_already_has_v1() {
        let url = LlmClient::build_anthropic_messages_url("https://api.anthropic.com/v1");
        assert_eq!(url, "https://api.anthropic.com/v1/messages");
    }

    #[test]
    fn test_build_anthropic_messages_url_already_has_messages() {
        let url =
            LlmClient::build_anthropic_messages_url("https://api.example.com/anthropic/messages");
        assert_eq!(url, "https://api.example.com/anthropic/messages");
    }

    #[test]
    fn test_build_anthropic_messages_url_tencent_endpoint() {
        let url = LlmClient::build_anthropic_messages_url(
            "https://api.lkeap.cloud.tencent.com/coding/anthropic",
        );
        assert_eq!(
            url,
            "https://api.lkeap.cloud.tencent.com/coding/anthropic/messages"
        );
    }

    #[test]
    fn test_build_anthropic_messages_url_with_trailing_slash() {
        let url = LlmClient::build_anthropic_messages_url("https://api.anthropic.com/v1/");
        assert_eq!(url, "https://api.anthropic.com/v1/messages");
    }

    // Provider routing tests.
    //
    // The previous versions only ran `matches!` against the value they had just
    // constructed, which can never fail. These instead check what routing depends on:
    // the compatible variants keep their custom base_url through a serde round-trip and
    // through client construction.
    #[test]
    fn test_provider_routing_anthropic_compatible() {
        let provider = LlmProvider::AnthropicCompatible {
            base_url: "https://example.com/anthropic".to_string(),
        };

        let json = serde_json::to_string(&provider).unwrap();
        let decoded: LlmProvider = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            &decoded,
            LlmProvider::AnthropicCompatible { base_url }
                if base_url == "https://example.com/anthropic"
        ));

        let client = LlmClient::new(provider, "test_key".to_string());
        assert_eq!(client.base_url, "https://example.com/anthropic");
        assert_eq!(
            LlmClient::build_anthropic_messages_url(&client.base_url),
            "https://example.com/anthropic/messages"
        );
    }

    #[test]
    fn test_provider_routing_openai_compatible() {
        let provider = LlmProvider::OpenAICompatible {
            base_url: "https://example.com/v1".to_string(),
        };

        let json = serde_json::to_string(&provider).unwrap();
        let decoded: LlmProvider = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            &decoded,
            LlmProvider::OpenAICompatible { base_url } if base_url == "https://example.com/v1"
        ));

        let client = LlmClient::new(provider, "test_key".to_string());
        assert_eq!(client.base_url, "https://example.com/v1");
    }
}