Skip to main content

sh_layer1/
llm_client.rs

1//! LLM 客户端模块
2//!
3//! 统一的 LLM API 客户端,支持多提供商。
4//!
5//! [STABLE] 基础请求功能完整
6//! [STABLE] 流式响应支持 Anthropic/OpenAI 格式
7
8use anyhow::{anyhow, Result};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use tracing::{info, warn};
15
16use crate::streaming::{
17    CallbackStream, ContentDelta, MessageStream, OnChunkCallback, StreamEvent, StreamProvider,
18};
19
20/// LLM 提供商类型
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub enum LlmProvider {
23    Anthropic,
24    OpenAI,
25    Gemini,
26    AzureOpenAI,
27    Bedrock,
28    Ollama,
29    /// OpenAI-compatible provider with custom base_url (e.g. deepseek, glm, qwen, kimi, grok)
30    OpenAICompatible {
31        base_url: String,
32    },
33    Custom(String),
34}
35
36/// LLM 请求配置
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct LlmRequestConfig {
39    /// 模型名称
40    pub model: String,
41    /// 最大 token 数
42    pub max_tokens: u32,
43    /// 温度参数
44    pub temperature: f32,
45    /// 系统提示
46    pub system_prompt: Option<String>,
47    /// 停止词
48    pub stop_sequences: Vec<String>,
49}
50
51impl Default for LlmRequestConfig {
52    fn default() -> Self {
53        Self {
54            model: "claude-sonnet-4-6".to_string(),
55            max_tokens: 4096,
56            temperature: 0.7,
57            system_prompt: None,
58            stop_sequences: vec!["\n\n\n".to_string()],
59        }
60    }
61}
62
63/// LLM 响应
64#[derive(Debug, Serialize, Deserialize)]
65pub struct LlmResponse {
66    /// 响应内容
67    pub content: String,
68    /// Token 使用情况
69    pub usage: TokenUsage,
70    /// 模型名称
71    pub model: String,
72    /// 响应 ID
73    pub response_id: String,
74}
75
76/// Token 使用情况
77#[derive(Debug, Serialize, Deserialize)]
78pub struct TokenUsage {
79    /// 输入 token 数
80    pub input_tokens: u32,
81    /// 输出 token 数
82    pub output_tokens: u32,
83}
84
85/// LLM 客户端 trait
86#[async_trait]
87pub trait LlmClientTrait {
88    /// 发送请求并获取响应
89    async fn send(&self, messages: Vec<Message>, config: &LlmRequestConfig) -> Result<LlmResponse>;
90
91    /// 发送请求并流式获取响应
92    async fn send_stream(
93        &self,
94        messages: Vec<Message>,
95        config: &LlmRequestConfig,
96    ) -> Result<MessageStream>;
97}
98
99/// 消息
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct Message {
102    pub role: MessageRole,
103    pub content: String,
104}
105
106/// 消息角色
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub enum MessageRole {
109    User,
110    Assistant,
111    System,
112}
113
114/// LLM 客户端实现
115pub struct LlmClient {
116    /// HTTP 客户端
117    client: Client,
118    /// API 密钥
119    api_key: String,
120    /// 提供商
121    provider: LlmProvider,
122    /// API 基础 URL
123    base_url: String,
124}
125
126impl LlmClient {
127    pub fn new(provider: LlmProvider, api_key: String) -> Self {
128        let base_url = match &provider {
129            LlmProvider::Anthropic => "https://api.anthropic.com/v1".to_string(),
130            LlmProvider::OpenAI => "https://api.openai.com/v1".to_string(),
131            LlmProvider::Gemini => "https://generativelanguage.googleapis.com/v1".to_string(),
132            LlmProvider::AzureOpenAI => "https://YOUR_RESOURCE.openai.azure.com".to_string(),
133            LlmProvider::Bedrock => "https://bedrock-runtime.us-east-1.amazonaws.com".to_string(),
134            LlmProvider::Ollama => "http://localhost:11434".to_string(),
135            LlmProvider::OpenAICompatible { base_url } => base_url.clone(),
136            LlmProvider::Custom(url) => url.clone(),
137        };
138
139        Self {
140            client: Client::new(),
141            api_key,
142            provider,
143            base_url,
144        }
145    }
146
147    /// 创建客户端并指定自定义 base_url(覆盖 provider 默认值)
148    pub fn with_base_url(mut self, base_url: String) -> Self {
149        self.base_url = base_url;
150        self
151    }
152
153    /// 发送带回调的流式请求
154    pub async fn send_stream_with_callback(
155        &self,
156        messages: Vec<Message>,
157        config: &LlmRequestConfig,
158        on_chunk: OnChunkCallback,
159    ) -> Result<LlmResponse> {
160        let message_stream = self.send_stream(messages, config).await?;
161        let mut callback_stream = CallbackStream::new(message_stream, Some(on_chunk));
162
163        let mut content = String::new();
164        let mut input_tokens = 0u32;
165        let mut output_tokens = 0u32;
166        let mut message_id = String::new();
167        let mut model = config.model.clone();
168
169        while let Some(event) = callback_stream.next_event().await? {
170            match event {
171                StreamEvent::MessageStart { id, model: m } => {
172                    message_id = id;
173                    model = m;
174                }
175                StreamEvent::ContentBlockDelta {
176                    delta: ContentDelta::Text(t),
177                    ..
178                } => {
179                    content.push_str(&t);
180                }
181                StreamEvent::ContentBlockDelta { .. } => {}
182                StreamEvent::MessageDelta { usage, .. } => {
183                    input_tokens = usage.input_tokens;
184                    output_tokens = usage.output_tokens;
185                }
186                _ => {}
187            }
188        }
189
190        Ok(LlmResponse {
191            content,
192            usage: TokenUsage {
193                input_tokens,
194                output_tokens,
195            },
196            model,
197            response_id: message_id,
198        })
199    }
200
201    /// 发送可中断的流式请求
202    pub async fn send_stream_abortable(
203        &self,
204        messages: Vec<Message>,
205        config: &LlmRequestConfig,
206        abort_flag: Arc<AtomicBool>,
207    ) -> Result<LlmResponse> {
208        let message_stream = self.send_stream(messages, config).await?;
209        let mut callback_stream = CallbackStream::new(message_stream, None);
210
211        let mut content = String::new();
212        let mut input_tokens = 0u32;
213        let mut output_tokens = 0u32;
214        let mut message_id = String::new();
215        let mut model = config.model.clone();
216
217        while !abort_flag.load(Ordering::Relaxed) {
218            match callback_stream.next_event().await {
219                Ok(Some(event)) => match event {
220                    StreamEvent::MessageStart { id, model: m } => {
221                        message_id = id;
222                        model = m;
223                    }
224                    StreamEvent::ContentBlockDelta {
225                        delta: ContentDelta::Text(t),
226                        ..
227                    } => {
228                        content.push_str(&t);
229                    }
230                    StreamEvent::ContentBlockDelta { .. } => {}
231                    StreamEvent::MessageDelta { usage, .. } => {
232                        input_tokens = usage.input_tokens;
233                        output_tokens = usage.output_tokens;
234                    }
235                    StreamEvent::MessageStop => {
236                        break;
237                    }
238                    _ => {}
239                },
240                Ok(None) => break,
241                Err(e) => {
242                    if abort_flag.load(Ordering::Relaxed) {
243                        info!("Stream aborted by user");
244                        break;
245                    }
246                    return Err(e);
247                }
248            }
249        }
250
251        if abort_flag.load(Ordering::Relaxed) {
252            info!("Stream was aborted");
253        }
254
255        Ok(LlmResponse {
256            content,
257            usage: TokenUsage {
258                input_tokens,
259                output_tokens,
260            },
261            model,
262            response_id: message_id,
263        })
264    }
265
266    /// 带错误恢复的请求重试
267    pub async fn send_with_retry(
268        &self,
269        messages: Vec<Message>,
270        config: &LlmRequestConfig,
271        max_retries: u32,
272    ) -> Result<LlmResponse> {
273        let mut attempts = 0;
274        let mut last_error: Option<anyhow::Error> = None;
275
276        while attempts < max_retries {
277            attempts += 1;
278
279            match self.send(messages.clone(), config).await {
280                Ok(response) => {
281                    info!("LLM request succeeded after {} attempts", attempts);
282                    return Ok(response);
283                }
284                Err(e) => {
285                    let error_msg = e.to_string();
286
287                    if error_msg.contains("rate limit")
288                        || error_msg.contains("429")
289                        || error_msg.contains("overloaded")
290                        || error_msg.contains("timeout")
291                    {
292                        warn!(
293                            "LLM request failed (attempt {}/{}): {}",
294                            attempts, max_retries, e
295                        );
296                        last_error = Some(e);
297
298                        let delay = std::cmp::min(1000 * 2u64.pow(attempts - 1), 30000);
299                        tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
300                    } else {
301                        return Err(e);
302                    }
303                }
304            }
305        }
306
307        Err(last_error.unwrap_or_else(|| anyhow!("Max retries exceeded")))
308    }
309
310    /// 带错误恢复的流式请求重试
311    pub async fn send_stream_with_retry(
312        &self,
313        messages: Vec<Message>,
314        config: &LlmRequestConfig,
315        max_retries: u32,
316    ) -> Result<LlmResponse> {
317        let mut attempts = 0;
318        let mut last_error: Option<anyhow::Error> = None;
319
320        while attempts < max_retries {
321            attempts += 1;
322
323            match self
324                .send_stream_with_callback(messages.clone(), config, Box::new(|_| {}))
325                .await
326            {
327                Ok(response) => {
328                    info!("Stream request succeeded after {} attempts", attempts);
329                    return Ok(response);
330                }
331                Err(e) => {
332                    let error_msg = e.to_string();
333
334                    if error_msg.contains("rate limit")
335                        || error_msg.contains("429")
336                        || error_msg.contains("overloaded")
337                        || error_msg.contains("timeout")
338                        || error_msg.contains("aborted")
339                    {
340                        warn!(
341                            "Stream request failed (attempt {}/{}): {}",
342                            attempts, max_retries, e
343                        );
344                        last_error = Some(e);
345
346                        let delay = std::cmp::min(1000 * 2u64.pow(attempts - 1), 30000);
347                        tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
348                    } else {
349                        return Err(e);
350                    }
351                }
352            }
353        }
354
355        Err(last_error.unwrap_or_else(|| anyhow!("Max retries exceeded")))
356    }
357}
358
359#[async_trait]
360impl LlmClientTrait for LlmClient {
361    async fn send(&self, messages: Vec<Message>, config: &LlmRequestConfig) -> Result<LlmResponse> {
362        match self.provider {
363            LlmProvider::Anthropic => self.send_anthropic(messages, config).await,
364            LlmProvider::OpenAI | LlmProvider::OpenAICompatible { .. } => {
365                self.send_openai(messages, config).await
366            }
367            LlmProvider::Gemini => self.send_gemini(messages, config).await,
368            LlmProvider::AzureOpenAI => self.send_azure_openai(messages, config).await,
369            LlmProvider::Bedrock => self.send_bedrock(messages, config).await,
370            LlmProvider::Ollama => self.send_ollama(messages, config).await,
371            LlmProvider::Custom(_) => {
372                Err(anyhow!("Custom provider requires custom implementation. Use an OpenAI-compatible provider instead."))
373            }
374        }
375    }
376
377    async fn send_stream(
378        &self,
379        messages: Vec<Message>,
380        config: &LlmRequestConfig,
381    ) -> Result<MessageStream> {
382        match self.provider {
383            LlmProvider::Anthropic => self.stream_anthropic(messages, config).await,
384            LlmProvider::OpenAI | LlmProvider::OpenAICompatible { .. } => {
385                self.stream_openai(messages, config).await
386            }
387            LlmProvider::Gemini => self.stream_gemini(messages, config).await,
388            LlmProvider::AzureOpenAI => self.stream_azure_openai(messages, config).await,
389            LlmProvider::Bedrock => self.stream_bedrock(messages, config).await,
390            LlmProvider::Ollama => self.stream_ollama(messages, config).await,
391            LlmProvider::Custom(_) => Err(anyhow!("Custom provider does not support streaming. Use an OpenAI-compatible provider instead.")),
392        }
393    }
394}
395
396impl LlmClient {
397    async fn send_anthropic(
398        &self,
399        messages: Vec<Message>,
400        config: &LlmRequestConfig,
401    ) -> Result<LlmResponse> {
402        let url = if self.base_url.ends_with("/v1") || self.base_url.ends_with("/v1/") {
403            format!("{}/messages", self.base_url.trim_end_matches('/'))
404        } else {
405            format!("{}/v1/messages", self.base_url.trim_end_matches('/'))
406        };
407
408        let request_body = AnthropicRequest {
409            model: config.model.clone(),
410            max_tokens: config.max_tokens,
411            messages: messages
412                .into_iter()
413                .map(|m| AnthropicMessage {
414                    role: match m.role {
415                        MessageRole::User => "user",
416                        MessageRole::Assistant => "assistant",
417                        MessageRole::System => "system",
418                    },
419                    content: AnthropicContent::Text(m.content),
420                })
421                .collect(),
422            system: config.system_prompt.clone(),
423            temperature: config.temperature,
424        };
425
426        let response = self
427            .client
428            .post(&url)
429            .header("x-api-key", &self.api_key)
430            .header("anthropic-version", "2023-06-01")
431            .json(&request_body)
432            .send()
433            .await?;
434
435        let response_text = response.text().await?;
436        tracing::debug!("Anthropic API response: {}", response_text);
437
438        let response_body: AnthropicResponse = serde_json::from_str(&response_text)?;
439
440        Ok(LlmResponse {
441            content: response_body
442                .content
443                .first()
444                .map(|c| c.text.clone())
445                .unwrap_or_default(),
446            usage: TokenUsage {
447                input_tokens: response_body.usage.input_tokens,
448                output_tokens: response_body.usage.output_tokens,
449            },
450            model: response_body.model,
451            response_id: response_body.id,
452        })
453    }
454
455    async fn stream_anthropic(
456        &self,
457        messages: Vec<Message>,
458        config: &LlmRequestConfig,
459    ) -> Result<MessageStream> {
460        let url = if self.base_url.ends_with("/v1") || self.base_url.ends_with("/v1/") {
461            format!("{}/messages", self.base_url.trim_end_matches('/'))
462        } else {
463            format!("{}/v1/messages", self.base_url.trim_end_matches('/'))
464        };
465
466        let request_body = AnthropicStreamRequest {
467            model: config.model.clone(),
468            max_tokens: config.max_tokens,
469            messages: messages
470                .into_iter()
471                .map(|m| AnthropicMessage {
472                    role: match m.role {
473                        MessageRole::User => "user",
474                        MessageRole::Assistant => "assistant",
475                        MessageRole::System => "system",
476                    },
477                    content: AnthropicContent::Text(m.content),
478                })
479                .collect(),
480            system: config.system_prompt.clone(),
481            temperature: config.temperature,
482            stream: true,
483        };
484
485        let response = self
486            .client
487            .post(&url)
488            .header("x-api-key", &self.api_key)
489            .header("anthropic-version", "2023-06-01")
490            .header("Accept", "text/event-stream")
491            .json(&request_body)
492            .send()
493            .await?;
494
495        let status = response.status();
496        if !status.is_success() {
497            let error_text = response.text().await?;
498            return Err(anyhow!("Anthropic API error {}: {}", status, error_text));
499        }
500
501        Ok(MessageStream::new(
502            response,
503            StreamProvider::Anthropic,
504            config.model.clone(),
505        ))
506    }
507
508    async fn send_openai(
509        &self,
510        messages: Vec<Message>,
511        config: &LlmRequestConfig,
512    ) -> Result<LlmResponse> {
513        let url = format!("{}/chat/completions", self.base_url);
514
515        let mut openai_messages: Vec<OpenAiMessage> = Vec::new();
516
517        if let Some(ref system) = config.system_prompt {
518            openai_messages.push(OpenAiMessage {
519                role: "system",
520                content: system.clone(),
521            });
522        }
523
524        for m in messages {
525            openai_messages.push(OpenAiMessage {
526                role: match m.role {
527                    MessageRole::User => "user",
528                    MessageRole::Assistant => "assistant",
529                    MessageRole::System => "system",
530                },
531                content: m.content,
532            });
533        }
534
535        let request_body = OpenAiRequest {
536            model: config.model.clone(),
537            messages: openai_messages,
538            max_tokens: Some(config.max_tokens),
539            temperature: Some(config.temperature),
540            stop: if config.stop_sequences.is_empty() {
541                None
542            } else {
543                Some(config.stop_sequences.clone())
544            },
545        };
546
547        let response = self
548            .client
549            .post(&url)
550            .header("Authorization", format!("Bearer {}", self.api_key))
551            .json(&request_body)
552            .send()
553            .await?;
554
555        let response_body: OpenAiResponse = response.json().await?;
556
557        let choice = response_body
558            .choices
559            .first()
560            .ok_or_else(|| anyhow!("No response choices"))?;
561
562        Ok(LlmResponse {
563            content: choice.message.content.clone(),
564            usage: TokenUsage {
565                input_tokens: response_body.usage.prompt_tokens,
566                output_tokens: response_body.usage.completion_tokens,
567            },
568            model: response_body.model,
569            response_id: response_body.id,
570        })
571    }
572
573    async fn stream_openai(
574        &self,
575        messages: Vec<Message>,
576        config: &LlmRequestConfig,
577    ) -> Result<MessageStream> {
578        let url = format!("{}/chat/completions", self.base_url);
579
580        let mut openai_messages: Vec<OpenAiMessage> = Vec::new();
581        if let Some(ref system) = config.system_prompt {
582            openai_messages.push(OpenAiMessage {
583                role: "system",
584                content: system.clone(),
585            });
586        }
587        for m in messages {
588            openai_messages.push(OpenAiMessage {
589                role: match m.role {
590                    MessageRole::User => "user",
591                    MessageRole::Assistant => "assistant",
592                    MessageRole::System => "system",
593                },
594                content: m.content,
595            });
596        }
597
598        let request_body = OpenAiStreamRequest {
599            model: config.model.clone(),
600            messages: openai_messages,
601            max_tokens: Some(config.max_tokens),
602            temperature: Some(config.temperature),
603            stream: true,
604        };
605
606        let response = self
607            .client
608            .post(&url)
609            .header("Authorization", format!("Bearer {}", self.api_key))
610            .header("Accept", "text/event-stream")
611            .json(&request_body)
612            .send()
613            .await?;
614
615        let status = response.status();
616        if !status.is_success() {
617            let error_text = response.text().await?;
618            return Err(anyhow!("OpenAI API error {}: {}", status, error_text));
619        }
620
621        Ok(MessageStream::new(
622            response,
623            StreamProvider::OpenAI,
624            config.model.clone(),
625        ))
626    }
627
628    async fn send_gemini(
629        &self,
630        messages: Vec<Message>,
631        config: &LlmRequestConfig,
632    ) -> Result<LlmResponse> {
633        let url = format!(
634            "{}/models/{}:generateContent?key={}",
635            self.base_url, config.model, self.api_key
636        );
637
638        let mut contents: Vec<GeminiContent> = Vec::new();
639        let system_instruction = config.system_prompt.clone();
640
641        for m in messages {
642            contents.push(GeminiContent {
643                role: match m.role {
644                    MessageRole::User => "user".to_string(),
645                    MessageRole::Assistant => "model".to_string(),
646                    MessageRole::System => "user".to_string(),
647                },
648                parts: vec![GeminiPart { text: m.content }],
649            });
650        }
651
652        let request_body = GeminiRequest {
653            contents,
654            generation_config: Some(GeminiGenerationConfig {
655                max_output_tokens: Some(config.max_tokens),
656                temperature: Some(config.temperature),
657                stop_sequences: if config.stop_sequences.is_empty() {
658                    None
659                } else {
660                    Some(config.stop_sequences.clone())
661                },
662            }),
663            system_instruction: system_instruction.map(|s| GeminiSystemInstruction {
664                parts: vec![GeminiPart { text: s }],
665            }),
666        };
667
668        let response = self.client.post(&url).json(&request_body).send().await?;
669
670        let response_body: GeminiResponse = response.json().await?;
671
672        let candidate = response_body
673            .candidates
674            .first()
675            .ok_or_else(|| anyhow!("No response candidates"))?;
676
677        let content = candidate
678            .content
679            .parts
680            .first()
681            .map(|p| p.text.clone())
682            .unwrap_or_default();
683
684        Ok(LlmResponse {
685            content,
686            usage: TokenUsage {
687                input_tokens: response_body.usage_metadata.prompt_token_count.unwrap_or(0),
688                output_tokens: response_body
689                    .usage_metadata
690                    .candidates_token_count
691                    .unwrap_or(0),
692            },
693            model: config.model.clone(),
694            response_id: "".to_string(),
695        })
696    }
697
698    async fn stream_gemini(
699        &self,
700        messages: Vec<Message>,
701        config: &LlmRequestConfig,
702    ) -> Result<MessageStream> {
703        let url = format!(
704            "{}/models/{}:streamGenerateContent?key={}&alt=sse",
705            self.base_url, config.model, self.api_key
706        );
707
708        let mut contents: Vec<GeminiContent> = Vec::new();
709        let system_instruction = config.system_prompt.clone();
710
711        for m in messages {
712            contents.push(GeminiContent {
713                role: match m.role {
714                    MessageRole::User => "user".to_string(),
715                    MessageRole::Assistant => "model".to_string(),
716                    MessageRole::System => "user".to_string(),
717                },
718                parts: vec![GeminiPart { text: m.content }],
719            });
720        }
721
722        let request_body = GeminiRequest {
723            contents,
724            generation_config: Some(GeminiGenerationConfig {
725                max_output_tokens: Some(config.max_tokens),
726                temperature: Some(config.temperature),
727                stop_sequences: if config.stop_sequences.is_empty() {
728                    None
729                } else {
730                    Some(config.stop_sequences.clone())
731                },
732            }),
733            system_instruction: system_instruction.map(|s| GeminiSystemInstruction {
734                parts: vec![GeminiPart { text: s }],
735            }),
736        };
737
738        let response = self.client.post(&url).json(&request_body).send().await?;
739
740        let status = response.status();
741        if !status.is_success() {
742            let error_text = response.text().await?;
743            return Err(anyhow!("Gemini API error {}: {}", status, error_text));
744        }
745
746        Ok(MessageStream::new(
747            response,
748            StreamProvider::Gemini,
749            config.model.clone(),
750        ))
751    }
752
753    // ========================================================================
754    // Azure OpenAI 实现
755    // ========================================================================
756
757    async fn send_azure_openai(
758        &self,
759        messages: Vec<Message>,
760        config: &LlmRequestConfig,
761    ) -> Result<LlmResponse> {
762        // Azure OpenAI 使用 deployment name 而非 model name
763        // URL format: {base_url}/openai/deployments/{deployment}/chat/completions?api-version=2024-02-15-preview
764        let deployment = &config.model;
765        let url = format!(
766            "{}/openai/deployments/{}/chat/completions?api-version=2024-02-15-preview",
767            self.base_url, deployment
768        );
769
770        let mut azure_messages: Vec<OpenAiMessage> = Vec::new();
771        if let Some(ref system) = config.system_prompt {
772            azure_messages.push(OpenAiMessage {
773                role: "system",
774                content: system.clone(),
775            });
776        }
777        for m in messages {
778            azure_messages.push(OpenAiMessage {
779                role: match m.role {
780                    MessageRole::User => "user",
781                    MessageRole::Assistant => "assistant",
782                    MessageRole::System => "system",
783                },
784                content: m.content,
785            });
786        }
787
788        let request_body = OpenAiRequest {
789            model: deployment.clone(), // Azure 使用 deployment name
790            messages: azure_messages,
791            max_tokens: Some(config.max_tokens),
792            temperature: Some(config.temperature),
793            stop: if config.stop_sequences.is_empty() {
794                None
795            } else {
796                Some(config.stop_sequences.clone())
797            },
798        };
799
800        let response = self
801            .client
802            .post(&url)
803            .header("api-key", &self.api_key) // Azure 使用 api-key header 而非 Authorization
804            .json(&request_body)
805            .send()
806            .await?;
807
808        let status = response.status();
809        if !status.is_success() {
810            let error_text = response.text().await?;
811            return Err(anyhow!("Azure OpenAI API error {}: {}", status, error_text));
812        }
813
814        let response_body: OpenAiResponse = response.json().await?;
815
816        let choice = response_body
817            .choices
818            .first()
819            .ok_or_else(|| anyhow!("No response choices"))?;
820
821        Ok(LlmResponse {
822            content: choice.message.content.clone(),
823            usage: TokenUsage {
824                input_tokens: response_body.usage.prompt_tokens,
825                output_tokens: response_body.usage.completion_tokens,
826            },
827            model: response_body.model,
828            response_id: response_body.id,
829        })
830    }
831
832    async fn stream_azure_openai(
833        &self,
834        messages: Vec<Message>,
835        config: &LlmRequestConfig,
836    ) -> Result<MessageStream> {
837        let deployment = &config.model;
838        let url = format!(
839            "{}/openai/deployments/{}/chat/completions?api-version=2024-02-15-preview",
840            self.base_url, deployment
841        );
842
843        let mut azure_messages: Vec<OpenAiMessage> = Vec::new();
844        if let Some(ref system) = config.system_prompt {
845            azure_messages.push(OpenAiMessage {
846                role: "system",
847                content: system.clone(),
848            });
849        }
850        for m in messages {
851            azure_messages.push(OpenAiMessage {
852                role: match m.role {
853                    MessageRole::User => "user",
854                    MessageRole::Assistant => "assistant",
855                    MessageRole::System => "system",
856                },
857                content: m.content,
858            });
859        }
860
861        let request_body = OpenAiStreamRequest {
862            model: deployment.clone(),
863            messages: azure_messages,
864            max_tokens: Some(config.max_tokens),
865            temperature: Some(config.temperature),
866            stream: true,
867        };
868
869        let response = self
870            .client
871            .post(&url)
872            .header("api-key", &self.api_key)
873            .header("Accept", "text/event-stream")
874            .json(&request_body)
875            .send()
876            .await?;
877
878        let status = response.status();
879        if !status.is_success() {
880            let error_text = response.text().await?;
881            return Err(anyhow!("Azure OpenAI API error {}: {}", status, error_text));
882        }
883
884        Ok(MessageStream::new(
885            response,
886            StreamProvider::AzureOpenAI,
887            config.model.clone(),
888        ))
889    }
890
891    // ========================================================================
892    // AWS Bedrock 实现
893    // ========================================================================
894
895    async fn send_bedrock(
896        &self,
897        messages: Vec<Message>,
898        config: &LlmRequestConfig,
899    ) -> Result<LlmResponse> {
900        // Bedrock URL: {base_url}/model/{model_id}/invoke
901        // 需要 AWS SigV4 签名认证(简化版使用 API key 作为临时方案)
902        let model_id = &config.model;
903        let url = format!("{}/model/{}/invoke", self.base_url, model_id);
904
905        // Bedrock 请求格式因模型不同而异,这里使用通用的 Converse API 格式
906        let mut bedrock_messages: Vec<BedrockMessage> = Vec::new();
907        for m in messages {
908            bedrock_messages.push(BedrockMessage {
909                role: match m.role {
910                    MessageRole::User => "user",
911                    MessageRole::Assistant => "assistant",
912                    MessageRole::System => "system",
913                },
914                content: vec![BedrockContent { text: m.content }],
915            });
916        }
917
918        let request_body = BedrockRequest {
919            messages: bedrock_messages,
920            system: config.system_prompt.clone(),
921            inference_config: Some(BedrockInferenceConfig {
922                max_tokens: config.max_tokens,
923                temperature: config.temperature,
924                top_p: None,
925                stop_sequences: if config.stop_sequences.is_empty() {
926                    None
927                } else {
928                    Some(config.stop_sequences.clone())
929                },
930            }),
931        };
932
933        let response = self
934            .client
935            .post(&url)
936            .header("Authorization", format!("Bearer {}", self.api_key))
937            .header("Content-Type", "application/json")
938            .json(&request_body)
939            .send()
940            .await?;
941
942        let status = response.status();
943        if !status.is_success() {
944            let error_text = response.text().await?;
945            return Err(anyhow!("Bedrock API error {}: {}", status, error_text));
946        }
947
948        let response_body: BedrockResponse = response.json().await?;
949
950        let content = response_body
951            .output
952            .message
953            .content
954            .first()
955            .map(|c| c.text.clone())
956            .unwrap_or_default();
957
958        Ok(LlmResponse {
959            content,
960            usage: TokenUsage {
961                input_tokens: response_body.usage.input_tokens,
962                output_tokens: response_body.usage.output_tokens,
963            },
964            model: config.model.clone(),
965            response_id: response_body.request_id.unwrap_or_default(),
966        })
967    }
968
969    async fn stream_bedrock(
970        &self,
971        messages: Vec<Message>,
972        config: &LlmRequestConfig,
973    ) -> Result<MessageStream> {
974        let model_id = &config.model;
975        let url = format!(
976            "{}/model/{}/invoke-with-response-stream",
977            self.base_url, model_id
978        );
979
980        let mut bedrock_messages: Vec<BedrockMessage> = Vec::new();
981        for m in messages {
982            bedrock_messages.push(BedrockMessage {
983                role: match m.role {
984                    MessageRole::User => "user",
985                    MessageRole::Assistant => "assistant",
986                    MessageRole::System => "system",
987                },
988                content: vec![BedrockContent { text: m.content }],
989            });
990        }
991
992        let request_body = BedrockRequest {
993            messages: bedrock_messages,
994            system: config.system_prompt.clone(),
995            inference_config: Some(BedrockInferenceConfig {
996                max_tokens: config.max_tokens,
997                temperature: config.temperature,
998                top_p: None,
999                stop_sequences: if config.stop_sequences.is_empty() {
1000                    None
1001                } else {
1002                    Some(config.stop_sequences.clone())
1003                },
1004            }),
1005        };
1006
1007        let response = self
1008            .client
1009            .post(&url)
1010            .header("Authorization", format!("Bearer {}", self.api_key))
1011            .header("Accept", "text/event-stream")
1012            .header("Content-Type", "application/json")
1013            .json(&request_body)
1014            .send()
1015            .await?;
1016
1017        let status = response.status();
1018        if !status.is_success() {
1019            let error_text = response.text().await?;
1020            return Err(anyhow!("Bedrock API error {}: {}", status, error_text));
1021        }
1022
1023        Ok(MessageStream::new(
1024            response,
1025            StreamProvider::Bedrock,
1026            config.model.clone(),
1027        ))
1028    }
1029
1030    // ========================================================================
1031    // Ollama (本地) 实现
1032    // ========================================================================
1033
1034    async fn send_ollama(
1035        &self,
1036        messages: Vec<Message>,
1037        config: &LlmRequestConfig,
1038    ) -> Result<LlmResponse> {
1039        // Ollama API: POST /api/chat 或 /api/generate
1040        let url = format!("{}/api/chat", self.base_url);
1041
1042        let mut ollama_messages: Vec<OllamaMessage> = Vec::new();
1043        if let Some(ref system) = config.system_prompt {
1044            ollama_messages.push(OllamaMessage {
1045                role: "system",
1046                content: system.clone(),
1047            });
1048        }
1049        for m in messages {
1050            ollama_messages.push(OllamaMessage {
1051                role: match m.role {
1052                    MessageRole::User => "user",
1053                    MessageRole::Assistant => "assistant",
1054                    MessageRole::System => "system",
1055                },
1056                content: m.content,
1057            });
1058        }
1059
1060        let request_body = OllamaChatRequest {
1061            model: config.model.clone(),
1062            messages: ollama_messages,
1063            stream: false,
1064            options: Some(OllamaOptions {
1065                num_predict: config.max_tokens as i32,
1066                temperature: config.temperature,
1067                stop: if config.stop_sequences.is_empty() {
1068                    None
1069                } else {
1070                    Some(config.stop_sequences.clone())
1071                },
1072            }),
1073        };
1074
1075        // Ollama 本地运行,通常无需 API key
1076        let response = self
1077            .client
1078            .post(&url)
1079            .header("Content-Type", "application/json")
1080            .json(&request_body)
1081            .send()
1082            .await?;
1083
1084        let status = response.status();
1085        if !status.is_success() {
1086            let error_text = response.text().await?;
1087            return Err(anyhow!("Ollama API error {}: {}", status, error_text));
1088        }
1089
1090        let response_body: OllamaChatResponse = response.json().await?;
1091
1092        Ok(LlmResponse {
1093            content: response_body.message.content,
1094            usage: TokenUsage {
1095                input_tokens: response_body.prompt_eval_count.unwrap_or(0),
1096                output_tokens: response_body.eval_count.unwrap_or(0),
1097            },
1098            model: response_body.model,
1099            response_id: "".to_string(),
1100        })
1101    }
1102
1103    async fn stream_ollama(
1104        &self,
1105        messages: Vec<Message>,
1106        config: &LlmRequestConfig,
1107    ) -> Result<MessageStream> {
1108        let url = format!("{}/api/chat", self.base_url);
1109
1110        let mut ollama_messages: Vec<OllamaMessage> = Vec::new();
1111        if let Some(ref system) = config.system_prompt {
1112            ollama_messages.push(OllamaMessage {
1113                role: "system",
1114                content: system.clone(),
1115            });
1116        }
1117        for m in messages {
1118            ollama_messages.push(OllamaMessage {
1119                role: match m.role {
1120                    MessageRole::User => "user",
1121                    MessageRole::Assistant => "assistant",
1122                    MessageRole::System => "system",
1123                },
1124                content: m.content,
1125            });
1126        }
1127
1128        let request_body = OllamaChatRequest {
1129            model: config.model.clone(),
1130            messages: ollama_messages,
1131            stream: true,
1132            options: Some(OllamaOptions {
1133                num_predict: config.max_tokens as i32,
1134                temperature: config.temperature,
1135                stop: if config.stop_sequences.is_empty() {
1136                    None
1137                } else {
1138                    Some(config.stop_sequences.clone())
1139                },
1140            }),
1141        };
1142
1143        let response = self
1144            .client
1145            .post(&url)
1146            .header("Accept", "application/json")
1147            .header("Content-Type", "application/json")
1148            .json(&request_body)
1149            .send()
1150            .await?;
1151
1152        let status = response.status();
1153        if !status.is_success() {
1154            let error_text = response.text().await?;
1155            return Err(anyhow!("Ollama API error {}: {}", status, error_text));
1156        }
1157
1158        Ok(MessageStream::new(
1159            response,
1160            StreamProvider::Ollama,
1161            config.model.clone(),
1162        ))
1163    }
1164}
1165
1166// Anthropic API 结构
1167#[derive(Serialize)]
1168struct AnthropicRequest {
1169    model: String,
1170    max_tokens: u32,
1171    messages: Vec<AnthropicMessage>,
1172    system: Option<String>,
1173    temperature: f32,
1174}
1175
1176#[derive(Serialize)]
1177struct AnthropicStreamRequest {
1178    model: String,
1179    max_tokens: u32,
1180    messages: Vec<AnthropicMessage>,
1181    system: Option<String>,
1182    temperature: f32,
1183    stream: bool,
1184}
1185
1186#[derive(Serialize)]
1187struct AnthropicMessage {
1188    role: &'static str,
1189    content: AnthropicContent,
1190}
1191
1192#[derive(Serialize)]
1193#[serde(untagged)]
1194#[allow(dead_code)]
1195enum AnthropicContent {
1196    Text(String),
1197    Blocks(Vec<AnthropicContentBlock>),
1198}
1199
1200#[derive(Serialize)]
1201struct AnthropicContentBlock {
1202    #[serde(rename = "type")]
1203    content_type: String,
1204    text: String,
1205}
1206
1207#[derive(Deserialize)]
1208#[allow(dead_code)]
1209struct AnthropicResponse {
1210    #[serde(default)]
1211    id: String,
1212    #[serde(default)]
1213    model: String,
1214    #[serde(default)]
1215    content: Vec<AnthropicContentResponse>,
1216    #[serde(default)]
1217    usage: AnthropicUsage,
1218    #[serde(default)]
1219    #[serde(rename = "type")]
1220    response_type: Option<String>,
1221    #[serde(default)]
1222    role: Option<String>,
1223    #[serde(default)]
1224    stop_reason: Option<String>,
1225}
1226
1227#[derive(Deserialize)]
1228#[allow(dead_code)]
1229struct AnthropicContentResponse {
1230    #[serde(rename = "type", default)]
1231    content_type: String,
1232    #[serde(default)]
1233    text: String,
1234}
1235
1236#[derive(Deserialize, Default)]
1237struct AnthropicUsage {
1238    #[serde(default)]
1239    input_tokens: u32,
1240    #[serde(default)]
1241    output_tokens: u32,
1242}
1243
1244// OpenAI API 结构
1245#[derive(Serialize)]
1246struct OpenAiRequest {
1247    model: String,
1248    messages: Vec<OpenAiMessage>,
1249    #[serde(skip_serializing_if = "Option::is_none")]
1250    max_tokens: Option<u32>,
1251    #[serde(skip_serializing_if = "Option::is_none")]
1252    temperature: Option<f32>,
1253    #[serde(skip_serializing_if = "Option::is_none")]
1254    stop: Option<Vec<String>>,
1255}
1256
1257#[derive(Serialize)]
1258struct OpenAiStreamRequest {
1259    model: String,
1260    messages: Vec<OpenAiMessage>,
1261    #[serde(skip_serializing_if = "Option::is_none")]
1262    max_tokens: Option<u32>,
1263    #[serde(skip_serializing_if = "Option::is_none")]
1264    temperature: Option<f32>,
1265    stream: bool,
1266}
1267
1268#[derive(Serialize)]
1269struct OpenAiMessage {
1270    role: &'static str,
1271    content: String,
1272}
1273
1274#[derive(Deserialize)]
1275struct OpenAiResponse {
1276    id: String,
1277    model: String,
1278    choices: Vec<OpenAiChoice>,
1279    usage: OpenAiUsage,
1280}
1281
1282#[derive(Deserialize)]
1283#[allow(dead_code)]
1284struct OpenAiChoice {
1285    message: OpenAiResponseMessage,
1286    finish_reason: String,
1287}
1288
1289#[derive(Deserialize)]
1290#[allow(dead_code)]
1291struct OpenAiResponseMessage {
1292    role: String,
1293    content: String,
1294}
1295
1296#[derive(Deserialize)]
1297#[allow(dead_code)]
1298struct OpenAiUsage {
1299    prompt_tokens: u32,
1300    completion_tokens: u32,
1301    total_tokens: u32,
1302}
1303
1304// Gemini API 结构
1305#[derive(Serialize)]
1306struct GeminiRequest {
1307    contents: Vec<GeminiContent>,
1308    #[serde(skip_serializing_if = "Option::is_none")]
1309    generation_config: Option<GeminiGenerationConfig>,
1310    #[serde(skip_serializing_if = "Option::is_none")]
1311    system_instruction: Option<GeminiSystemInstruction>,
1312}
1313
1314#[derive(Serialize)]
1315struct GeminiContent {
1316    role: String,
1317    parts: Vec<GeminiPart>,
1318}
1319
1320#[derive(Serialize)]
1321struct GeminiPart {
1322    text: String,
1323}
1324
1325#[derive(Serialize)]
1326struct GeminiGenerationConfig {
1327    #[serde(skip_serializing_if = "Option::is_none")]
1328    max_output_tokens: Option<u32>,
1329    #[serde(skip_serializing_if = "Option::is_none")]
1330    temperature: Option<f32>,
1331    #[serde(skip_serializing_if = "Option::is_none")]
1332    stop_sequences: Option<Vec<String>>,
1333}
1334
1335#[derive(Serialize)]
1336struct GeminiSystemInstruction {
1337    parts: Vec<GeminiPart>,
1338}
1339
1340#[derive(Deserialize)]
1341struct GeminiResponse {
1342    candidates: Vec<GeminiCandidate>,
1343    usage_metadata: GeminiUsageMetadata,
1344}
1345
1346#[derive(Deserialize)]
1347#[allow(dead_code)]
1348struct GeminiCandidate {
1349    content: GeminiContentResponse,
1350    finish_reason: String,
1351}
1352
1353#[derive(Deserialize)]
1354#[allow(dead_code)]
1355struct GeminiContentResponse {
1356    parts: Vec<GeminiPartResponse>,
1357    role: String,
1358}
1359
1360#[derive(Deserialize)]
1361struct GeminiPartResponse {
1362    text: String,
1363}
1364
1365#[derive(Deserialize)]
1366#[allow(dead_code)]
1367struct GeminiUsageMetadata {
1368    prompt_token_count: Option<u32>,
1369    candidates_token_count: Option<u32>,
1370    total_token_count: Option<u32>,
1371}
1372
1373// ========================================================================
1374// AWS Bedrock API 结构
1375// ========================================================================
1376
1377#[derive(Serialize)]
1378struct BedrockRequest {
1379    messages: Vec<BedrockMessage>,
1380    #[serde(skip_serializing_if = "Option::is_none")]
1381    system: Option<String>,
1382    #[serde(skip_serializing_if = "Option::is_none")]
1383    inference_config: Option<BedrockInferenceConfig>,
1384}
1385
1386#[derive(Serialize)]
1387struct BedrockMessage {
1388    role: &'static str,
1389    content: Vec<BedrockContent>,
1390}
1391
1392#[derive(Serialize)]
1393struct BedrockContent {
1394    text: String,
1395}
1396
1397#[derive(Serialize)]
1398struct BedrockInferenceConfig {
1399    #[serde(rename = "maxTokens")]
1400    max_tokens: u32,
1401    temperature: f32,
1402    #[serde(skip_serializing_if = "Option::is_none")]
1403    top_p: Option<f32>,
1404    #[serde(skip_serializing_if = "Option::is_none")]
1405    stop_sequences: Option<Vec<String>>,
1406}
1407
1408#[derive(Deserialize)]
1409#[allow(dead_code)]
1410struct BedrockResponse {
1411    output: BedrockOutput,
1412    usage: BedrockUsage,
1413    #[serde(default)]
1414    request_id: Option<String>,
1415}
1416
1417#[derive(Deserialize)]
1418struct BedrockOutput {
1419    message: BedrockResponseMessage,
1420}
1421
1422#[derive(Deserialize)]
1423struct BedrockResponseMessage {
1424    content: Vec<BedrockResponseContent>,
1425}
1426
1427#[derive(Deserialize)]
1428struct BedrockResponseContent {
1429    text: String,
1430}
1431
1432#[derive(Deserialize)]
1433struct BedrockUsage {
1434    #[serde(default)]
1435    input_tokens: u32,
1436    #[serde(default)]
1437    output_tokens: u32,
1438}
1439
1440// ========================================================================
1441// Ollama API 结构
1442// ========================================================================
1443
1444#[derive(Serialize)]
1445struct OllamaChatRequest {
1446    model: String,
1447    messages: Vec<OllamaMessage>,
1448    stream: bool,
1449    #[serde(skip_serializing_if = "Option::is_none")]
1450    options: Option<OllamaOptions>,
1451}
1452
1453#[derive(Serialize)]
1454struct OllamaMessage {
1455    role: &'static str,
1456    content: String,
1457}
1458
1459#[derive(Serialize)]
1460struct OllamaOptions {
1461    num_predict: i32,
1462    temperature: f32,
1463    #[serde(skip_serializing_if = "Option::is_none")]
1464    stop: Option<Vec<String>>,
1465}
1466
1467#[derive(Deserialize)]
1468struct OllamaChatResponse {
1469    model: String,
1470    message: OllamaResponseMessage,
1471    #[serde(default)]
1472    prompt_eval_count: Option<u32>,
1473    #[serde(default)]
1474    eval_count: Option<u32>,
1475}
1476
1477#[derive(Deserialize)]
1478struct OllamaResponseMessage {
1479    content: String,
1480}
1481
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every field of `LlmRequestConfig::default()`.
    #[test]
    fn test_default_config() {
        let config = LlmRequestConfig::default();
        assert_eq!(config.model, "claude-sonnet-4-6");
        assert_eq!(config.max_tokens, 4096);
        assert!((config.temperature - 0.7).abs() < f32::EPSILON);
        assert!(config.system_prompt.is_none());
        assert_eq!(config.stop_sequences, vec!["\n\n\n".to_string()]);
    }

    #[test]
    fn test_client_creation() {
        let client = LlmClient::new(LlmProvider::Anthropic, "test_key".to_string());
        assert_eq!(client.base_url, "https://api.anthropic.com/v1");
    }

    #[test]
    fn test_openai_client_creation() {
        let client = LlmClient::new(LlmProvider::OpenAI, "test_key".to_string());
        assert_eq!(client.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_gemini_client_creation() {
        let client = LlmClient::new(LlmProvider::Gemini, "test_key".to_string());
        assert_eq!(
            client.base_url,
            "https://generativelanguage.googleapis.com/v1"
        );
    }

    #[test]
    fn test_custom_provider() {
        let client = LlmClient::new(
            LlmProvider::Custom("https://custom.api.com/v1".to_string()),
            "test_key".to_string(),
        );
        assert_eq!(client.base_url, "https://custom.api.com/v1");
    }

    #[test]
    fn test_openai_compatible_provider() {
        let client = LlmClient::new(
            LlmProvider::OpenAICompatible {
                base_url: "https://api.deepseek.com/v1".to_string(),
            },
            "test_key".to_string(),
        );
        assert_eq!(client.base_url, "https://api.deepseek.com/v1");
    }

    #[test]
    fn test_azure_openai_client_creation() {
        let client = LlmClient::new(LlmProvider::AzureOpenAI, "test_key".to_string());
        assert!(client.base_url.contains("openai.azure.com"));
    }

    #[test]
    fn test_bedrock_client_creation() {
        let client = LlmClient::new(LlmProvider::Bedrock, "test_key".to_string());
        assert!(client.base_url.contains("bedrock-runtime"));
    }

    #[test]
    fn test_ollama_client_creation() {
        let client = LlmClient::new(LlmProvider::Ollama, "".to_string());
        assert_eq!(client.base_url, "http://localhost:11434");
    }

    #[test]
    fn test_azure_openai_with_custom_url() {
        let client = LlmClient::new(LlmProvider::AzureOpenAI, "test_key".to_string())
            .with_base_url("https://myresource.openai.azure.com".to_string());
        assert_eq!(client.base_url, "https://myresource.openai.azure.com");
    }

    #[test]
    fn test_ollama_with_custom_url() {
        let client = LlmClient::new(LlmProvider::Ollama, "".to_string())
            .with_base_url("http://192.168.1.100:11434".to_string());
        assert_eq!(client.base_url, "http://192.168.1.100:11434");
    }

    #[test]
    fn test_message_creation() {
        let message = Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
        };
        assert_eq!(message.content, "Hello");
    }

    #[test]
    fn test_config_with_system_prompt() {
        let config = LlmRequestConfig {
            model: "gpt-4".to_string(),
            max_tokens: 8192,
            temperature: 0.5,
            system_prompt: Some("You are a helpful assistant".to_string()),
            stop_sequences: vec![],
        };
        assert_eq!(config.model, "gpt-4");
        assert_eq!(
            config.system_prompt.as_deref(),
            Some("You are a helpful assistant")
        );
    }

    #[test]
    fn test_llm_response_creation() {
        let response = LlmResponse {
            content: "Hello".to_string(),
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 5,
            },
            model: "gpt-4".to_string(),
            response_id: "resp_123".to_string(),
        };
        assert_eq!(response.content, "Hello");
        assert_eq!(response.usage.input_tokens, 10);
        assert_eq!(response.usage.output_tokens, 5);
    }

    /// Unit variants use serde's default external tagging: a bare JSON string.
    #[test]
    fn test_provider_serialization() {
        let provider = LlmProvider::Anthropic;
        let json = serde_json::to_string(&provider).unwrap();
        assert_eq!(json, "\"Anthropic\"");
    }

    /// Struct variants serialize as `{"Variant": {fields}}` and must round-trip.
    #[test]
    fn test_openai_compatible_provider_roundtrip() {
        let provider = LlmProvider::OpenAICompatible {
            base_url: "https://api.deepseek.com/v1".to_string(),
        };
        let json = serde_json::to_string(&provider).unwrap();
        assert_eq!(
            json,
            r#"{"OpenAICompatible":{"base_url":"https://api.deepseek.com/v1"}}"#
        );

        let decoded: LlmProvider = serde_json::from_str(&json).unwrap();
        match decoded {
            LlmProvider::OpenAICompatible { base_url } => {
                assert_eq!(base_url, "https://api.deepseek.com/v1")
            }
            other => panic!("unexpected provider after round-trip: {:?}", other),
        }
    }

    #[test]
    fn test_message_role_serialization() {
        let role = MessageRole::User;
        let json = serde_json::to_string(&role).unwrap();
        assert!(json.contains("User"));
    }
}