Skip to main content

rust_genai/
client.rs

1//! Client configuration and transport layer.
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
8use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
9use reqwest::{Client as HttpClient, Proxy};
10use tokio::sync::OnceCell;
11
12use crate::auth::OAuthTokenProvider;
13use crate::error::{Error, Result};
14use google_cloud_auth::credentials::{
15    Builder as AuthBuilder, CacheableResource, Credentials as GoogleCredentials,
16};
17use http::Extensions;
18use rust_genai_types::http::HttpRetryOptions;
19
20/// Gemini 客户端。
21#[derive(Clone)]
22pub struct Client {
23    inner: Arc<ClientInner>,
24}
25
26pub(crate) struct ClientInner {
27    pub http: HttpClient,
28    pub config: ClientConfig,
29    pub api_client: ApiClient,
30    pub(crate) auth_provider: Option<AuthProvider>,
31}
32
33/// 客户端配置。
34#[derive(Debug, Clone)]
35pub struct ClientConfig {
36    /// API 密钥(Gemini API)。
37    pub api_key: Option<String>,
38    /// 后端选择。
39    pub backend: Backend,
40    /// Vertex AI 配置。
41    pub vertex_config: Option<VertexConfig>,
42    /// HTTP 配置。
43    pub http_options: HttpOptions,
44    /// 认证信息。
45    pub credentials: Credentials,
46    /// OAuth scopes(服务账号/ADC 使用)。
47    pub auth_scopes: Vec<String>,
48}
49
/// Which Google API backend the client talks to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// Gemini Developer API (default host `generativelanguage.googleapis.com`).
    GeminiApi,
    /// Vertex AI (default host derived from the configured location).
    VertexAi,
}
56
/// Authentication method.
///
/// `Debug` is implemented by hand so that an API key is never printed; OAuth file paths are
/// not secret and are shown as-is.
#[derive(Clone)]
pub enum Credentials {
    /// API key (Gemini API).
    ApiKey(String),
    /// OAuth user credentials loaded from a client-secret file, with an optional token cache.
    OAuth {
        client_secret_path: PathBuf,
        token_cache_path: Option<PathBuf>,
    },
    /// Application Default Credentials (ADC).
    ApplicationDefault,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&"<redacted>").finish(),
            Self::OAuth {
                client_secret_path,
                token_cache_path,
            } => f
                .debug_struct("OAuth")
                .field("client_secret_path", client_secret_path)
                .field("token_cache_path", token_cache_path)
                .finish(),
            Self::ApplicationDefault => f.write_str("ApplicationDefault"),
        }
    }
}
70
71/// Vertex AI 配置。
72#[derive(Debug, Clone)]
73pub struct VertexConfig {
74    pub project: String,
75    pub location: String,
76    pub credentials: Option<VertexCredentials>,
77}
78
/// Placeholder for explicit Vertex AI credentials.
///
/// `Debug` is implemented by hand so the access token is never printed.
#[derive(Clone)]
pub struct VertexCredentials {
    /// Pre-fetched access token, if any.
    pub access_token: Option<String>,
}

impl std::fmt::Debug for VertexCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VertexCredentials")
            // Only reveal whether a token is present, never its value.
            .field(
                "access_token",
                &self.access_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
84
85/// HTTP 配置。
86#[derive(Debug, Clone, Default)]
87pub struct HttpOptions {
88    pub timeout: Option<u64>,
89    pub proxy: Option<String>,
90    pub headers: HashMap<String, String>,
91    pub base_url: Option<String>,
92    pub api_version: Option<String>,
93    pub retry_options: Option<HttpRetryOptions>,
94}
95
96impl Client {
97    /// 创建新客户端(Gemini API)。
98    ///
99    /// # Errors
100    /// 当配置无效或构建客户端失败时返回错误。
101    pub fn new(api_key: impl Into<String>) -> Result<Self> {
102        Self::builder()
103            .api_key(api_key)
104            .backend(Backend::GeminiApi)
105            .build()
106    }
107
108    /// 从环境变量创建客户端。
109    ///
110    /// # Errors
111    /// 当环境变量缺失或构建客户端失败时返回错误。
112    pub fn from_env() -> Result<Self> {
113        let api_key = std::env::var("GEMINI_API_KEY")
114            .or_else(|_| std::env::var("GOOGLE_API_KEY"))
115            .map_err(|_| Error::InvalidConfig {
116                message: "GEMINI_API_KEY or GOOGLE_API_KEY not found".into(),
117            })?;
118        let mut builder = Self::builder().api_key(api_key);
119        if let Ok(base_url) =
120            std::env::var("GENAI_BASE_URL").or_else(|_| std::env::var("GEMINI_BASE_URL"))
121        {
122            if !base_url.trim().is_empty() {
123                builder = builder.base_url(base_url);
124            }
125        }
126        if let Ok(api_version) = std::env::var("GENAI_API_VERSION") {
127            if !api_version.trim().is_empty() {
128                builder = builder.api_version(api_version);
129            }
130        }
131        builder.build()
132    }
133
134    /// 创建 Vertex AI 客户端。
135    ///
136    /// # Errors
137    /// 当配置无效或构建客户端失败时返回错误。
138    pub fn new_vertex(project: impl Into<String>, location: impl Into<String>) -> Result<Self> {
139        Self::builder()
140            .backend(Backend::VertexAi)
141            .vertex_project(project)
142            .vertex_location(location)
143            .build()
144    }
145
146    /// 使用 OAuth 凭据创建客户端(默认读取 token.json)。
147    ///
148    /// # Errors
149    /// 当凭据路径无效或构建客户端失败时返回错误。
150    pub fn with_oauth(client_secret_path: impl AsRef<Path>) -> Result<Self> {
151        Self::builder()
152            .credentials(Credentials::OAuth {
153                client_secret_path: client_secret_path.as_ref().to_path_buf(),
154                token_cache_path: None,
155            })
156            .build()
157    }
158
159    /// 使用 Application Default Credentials 创建客户端。
160    ///
161    /// # Errors
162    /// 当构建客户端失败时返回错误。
163    pub fn with_adc() -> Result<Self> {
164        Self::builder()
165            .credentials(Credentials::ApplicationDefault)
166            .build()
167    }
168
169    /// 创建 Builder。
170    #[must_use]
171    pub fn builder() -> ClientBuilder {
172        ClientBuilder::default()
173    }
174
175    /// 访问 Models API。
176    #[must_use]
177    pub fn models(&self) -> crate::models::Models {
178        crate::models::Models::new(self.inner.clone())
179    }
180
181    /// 访问 Chats API。
182    #[must_use]
183    pub fn chats(&self) -> crate::chats::Chats {
184        crate::chats::Chats::new(self.inner.clone())
185    }
186
187    /// 访问 Files API。
188    #[must_use]
189    pub fn files(&self) -> crate::files::Files {
190        crate::files::Files::new(self.inner.clone())
191    }
192
193    /// 访问 `FileSearchStores` API。
194    #[must_use]
195    pub fn file_search_stores(&self) -> crate::file_search_stores::FileSearchStores {
196        crate::file_search_stores::FileSearchStores::new(self.inner.clone())
197    }
198
199    /// 访问 Documents API。
200    #[must_use]
201    pub fn documents(&self) -> crate::documents::Documents {
202        crate::documents::Documents::new(self.inner.clone())
203    }
204
205    /// 访问 Live API。
206    #[must_use]
207    pub fn live(&self) -> crate::live::Live {
208        crate::live::Live::new(self.inner.clone())
209    }
210
211    /// 访问 Live Music API。
212    #[must_use]
213    pub fn live_music(&self) -> crate::live_music::LiveMusic {
214        crate::live_music::LiveMusic::new(self.inner.clone())
215    }
216
217    /// 访问 Caches API。
218    #[must_use]
219    pub fn caches(&self) -> crate::caches::Caches {
220        crate::caches::Caches::new(self.inner.clone())
221    }
222
223    /// 访问 Batches API。
224    #[must_use]
225    pub fn batches(&self) -> crate::batches::Batches {
226        crate::batches::Batches::new(self.inner.clone())
227    }
228
229    /// 访问 Tunings API。
230    #[must_use]
231    pub fn tunings(&self) -> crate::tunings::Tunings {
232        crate::tunings::Tunings::new(self.inner.clone())
233    }
234
235    /// 访问 Operations API。
236    #[must_use]
237    pub fn operations(&self) -> crate::operations::Operations {
238        crate::operations::Operations::new(self.inner.clone())
239    }
240
241    /// 访问 `AuthTokens` API(Ephemeral Tokens)。
242    #[must_use]
243    pub fn auth_tokens(&self) -> crate::tokens::AuthTokens {
244        crate::tokens::AuthTokens::new(self.inner.clone())
245    }
246
247    /// 访问 Tokens API(Ephemeral Tokens)。
248    ///
249    /// 与官方 SDK 的 `tokens` 命名保持一致(等价于 `auth_tokens()`)。
250    #[must_use]
251    pub fn tokens(&self) -> crate::tokens::Tokens {
252        self.auth_tokens()
253    }
254
255    /// 访问 Interactions API。
256    #[must_use]
257    pub fn interactions(&self) -> crate::interactions::Interactions {
258        crate::interactions::Interactions::new(self.inner.clone())
259    }
260
261    /// 访问 Deep Research。
262    #[must_use]
263    pub fn deep_research(&self) -> crate::deep_research::DeepResearch {
264        crate::deep_research::DeepResearch::new(self.inner.clone())
265    }
266}
267
268/// 客户端 Builder。
269#[derive(Default)]
270pub struct ClientBuilder {
271    api_key: Option<String>,
272    credentials: Option<Credentials>,
273    backend: Option<Backend>,
274    vertex_project: Option<String>,
275    vertex_location: Option<String>,
276    http_options: HttpOptions,
277    auth_scopes: Option<Vec<String>>,
278}
279
280impl ClientBuilder {
281    /// 设置 API Key(Gemini API)。
282    #[must_use]
283    pub fn api_key(mut self, key: impl Into<String>) -> Self {
284        self.api_key = Some(key.into());
285        self
286    }
287
288    /// 设置认证方式(OAuth/ADC/API Key)。
289    #[must_use]
290    pub fn credentials(mut self, credentials: Credentials) -> Self {
291        self.credentials = Some(credentials);
292        self
293    }
294
295    /// 设置后端(Gemini API 或 Vertex AI)。
296    #[must_use]
297    pub const fn backend(mut self, backend: Backend) -> Self {
298        self.backend = Some(backend);
299        self
300    }
301
302    /// 设置 Vertex AI 项目 ID。
303    #[must_use]
304    pub fn vertex_project(mut self, project: impl Into<String>) -> Self {
305        self.vertex_project = Some(project.into());
306        self
307    }
308
309    /// 设置 Vertex AI 区域。
310    #[must_use]
311    pub fn vertex_location(mut self, location: impl Into<String>) -> Self {
312        self.vertex_location = Some(location.into());
313        self
314    }
315
316    /// 设置请求超时(秒)。
317    #[must_use]
318    pub const fn timeout(mut self, secs: u64) -> Self {
319        self.http_options.timeout = Some(secs);
320        self
321    }
322
323    /// 设置代理。
324    #[must_use]
325    pub fn proxy(mut self, url: impl Into<String>) -> Self {
326        self.http_options.proxy = Some(url.into());
327        self
328    }
329
330    /// 增加默认 HTTP 头。
331    #[must_use]
332    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
333        self.http_options.headers.insert(key.into(), value.into());
334        self
335    }
336
337    /// 设置自定义基础 URL。
338    #[must_use]
339    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
340        self.http_options.base_url = Some(base_url.into());
341        self
342    }
343
344    /// 设置 API 版本。
345    #[must_use]
346    pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
347        self.http_options.api_version = Some(api_version.into());
348        self
349    }
350
351    /// 设置 HTTP 重试选项。
352    #[must_use]
353    pub fn retry_options(mut self, retry_options: HttpRetryOptions) -> Self {
354        self.http_options.retry_options = Some(retry_options);
355        self
356    }
357
358    /// 设置 OAuth scopes。
359    #[must_use]
360    pub fn auth_scopes(mut self, scopes: Vec<String>) -> Self {
361        self.auth_scopes = Some(scopes);
362        self
363    }
364
365    /// 构建客户端。
366    ///
367    /// # Errors
368    /// 当配置不完整、参数无效或构建 HTTP 客户端失败时返回错误。
369    pub fn build(self) -> Result<Client> {
370        let Self {
371            api_key,
372            credentials,
373            backend,
374            vertex_project,
375            vertex_location,
376            http_options,
377            auth_scopes,
378        } = self;
379
380        let backend = Self::resolve_backend(
381            backend,
382            vertex_project.as_deref(),
383            vertex_location.as_deref(),
384        );
385        Self::validate_vertex_config(
386            backend,
387            vertex_project.as_deref(),
388            vertex_location.as_deref(),
389        )?;
390        let credentials = Self::resolve_credentials(backend, api_key.as_deref(), credentials)?;
391        let headers = Self::build_headers(&http_options, backend, &credentials)?;
392        let http = Self::build_http_client(&http_options, headers)?;
393
394        let auth_scopes = auth_scopes.unwrap_or_else(|| default_auth_scopes(backend));
395        let api_key = match &credentials {
396            Credentials::ApiKey(key) => Some(key.clone()),
397            _ => None,
398        };
399        let vertex_config = Self::build_vertex_config(backend, vertex_project, vertex_location)?;
400        let config = ClientConfig {
401            api_key,
402            backend,
403            vertex_config,
404            http_options,
405            credentials: credentials.clone(),
406            auth_scopes,
407        };
408
409        let auth_provider = build_auth_provider(&credentials)?;
410        let api_client = ApiClient::new(&config);
411
412        Ok(Client {
413            inner: Arc::new(ClientInner {
414                http,
415                config,
416                api_client,
417                auth_provider,
418            }),
419        })
420    }
421
422    fn resolve_backend(
423        backend: Option<Backend>,
424        vertex_project: Option<&str>,
425        vertex_location: Option<&str>,
426    ) -> Backend {
427        backend.unwrap_or_else(|| {
428            if vertex_project.is_some() || vertex_location.is_some() {
429                Backend::VertexAi
430            } else {
431                Backend::GeminiApi
432            }
433        })
434    }
435
436    fn validate_vertex_config(
437        backend: Backend,
438        vertex_project: Option<&str>,
439        vertex_location: Option<&str>,
440    ) -> Result<()> {
441        if backend == Backend::VertexAi && (vertex_project.is_none() || vertex_location.is_none()) {
442            return Err(Error::InvalidConfig {
443                message: "Project and location required for Vertex AI".into(),
444            });
445        }
446        Ok(())
447    }
448
449    fn resolve_credentials(
450        backend: Backend,
451        api_key: Option<&str>,
452        credentials: Option<Credentials>,
453    ) -> Result<Credentials> {
454        if credentials.is_some()
455            && api_key.is_some()
456            && !matches!(credentials, Some(Credentials::ApiKey(_)))
457        {
458            return Err(Error::InvalidConfig {
459                message: "API key cannot be combined with OAuth/ADC credentials".into(),
460            });
461        }
462
463        let credentials = match credentials {
464            Some(credentials) => credentials,
465            None => {
466                if let Some(api_key) = api_key {
467                    Credentials::ApiKey(api_key.to_string())
468                } else if backend == Backend::VertexAi {
469                    Credentials::ApplicationDefault
470                } else {
471                    return Err(Error::InvalidConfig {
472                        message: "API key or OAuth credentials required for Gemini API".into(),
473                    });
474                }
475            }
476        };
477
478        if backend == Backend::VertexAi && matches!(credentials, Credentials::ApiKey(_)) {
479            return Err(Error::InvalidConfig {
480                message: "Vertex AI does not support API key authentication".into(),
481            });
482        }
483
484        Ok(credentials)
485    }
486
487    fn build_headers(
488        http_options: &HttpOptions,
489        backend: Backend,
490        credentials: &Credentials,
491    ) -> Result<HeaderMap> {
492        let mut headers = HeaderMap::new();
493        for (key, value) in &http_options.headers {
494            let name =
495                HeaderName::from_bytes(key.as_bytes()).map_err(|_| Error::InvalidConfig {
496                    message: format!("Invalid header name: {key}"),
497                })?;
498            let value = HeaderValue::from_str(value).map_err(|_| Error::InvalidConfig {
499                message: format!("Invalid header value for {key}"),
500            })?;
501            headers.insert(name, value);
502        }
503
504        if backend == Backend::GeminiApi {
505            let api_key = match credentials {
506                Credentials::ApiKey(key) => key.as_str(),
507                _ => "",
508            };
509            let header_name = HeaderName::from_static("x-goog-api-key");
510            if !api_key.is_empty() && !headers.contains_key(&header_name) {
511                let mut header_value =
512                    HeaderValue::from_str(api_key).map_err(|_| Error::InvalidConfig {
513                        message: "Invalid API key value".into(),
514                    })?;
515                header_value.set_sensitive(true);
516                headers.insert(header_name, header_value);
517            }
518        }
519
520        Ok(headers)
521    }
522
523    fn build_http_client(http_options: &HttpOptions, headers: HeaderMap) -> Result<HttpClient> {
524        let mut http_builder = HttpClient::builder();
525        if let Some(timeout) = http_options.timeout {
526            http_builder = http_builder.timeout(Duration::from_secs(timeout));
527        }
528
529        if let Some(proxy_url) = &http_options.proxy {
530            let proxy = Proxy::all(proxy_url).map_err(|e| Error::InvalidConfig {
531                message: format!("Invalid proxy: {e}"),
532            })?;
533            http_builder = http_builder.proxy(proxy);
534        }
535
536        if !headers.is_empty() {
537            http_builder = http_builder.default_headers(headers);
538        }
539
540        Ok(http_builder.build()?)
541    }
542
543    fn build_vertex_config(
544        backend: Backend,
545        vertex_project: Option<String>,
546        vertex_location: Option<String>,
547    ) -> Result<Option<VertexConfig>> {
548        if backend != Backend::VertexAi {
549            return Ok(None);
550        }
551        let project = vertex_project.ok_or_else(|| Error::InvalidConfig {
552            message: "Project and location required for Vertex AI".into(),
553        })?;
554        let location = vertex_location.ok_or_else(|| Error::InvalidConfig {
555            message: "Project and location required for Vertex AI".into(),
556        })?;
557        Ok(Some(VertexConfig {
558            project,
559            location,
560            credentials: None,
561        }))
562    }
563}
564
565fn build_auth_provider(credentials: &Credentials) -> Result<Option<AuthProvider>> {
566    match credentials {
567        Credentials::ApiKey(_) => Ok(None),
568        Credentials::OAuth {
569            client_secret_path,
570            token_cache_path,
571        } => Ok(Some(AuthProvider::OAuth(Arc::new(
572            OAuthTokenProvider::from_paths(client_secret_path.clone(), token_cache_path.clone())?,
573        )))),
574        Credentials::ApplicationDefault => Ok(Some(AuthProvider::ApplicationDefault(Arc::new(
575            OnceCell::new(),
576        )))),
577    }
578}
579
580#[derive(Clone)]
581pub(crate) enum AuthProvider {
582    OAuth(Arc<OAuthTokenProvider>),
583    ApplicationDefault(Arc<OnceCell<Arc<GoogleCredentials>>>),
584}
585
586impl AuthProvider {
587    async fn headers(&self, scopes: &[&str]) -> Result<HeaderMap> {
588        match self {
589            Self::OAuth(provider) => {
590                let token = provider.token().await?;
591                let mut header =
592                    HeaderValue::from_str(&format!("Bearer {token}")).map_err(|_| Error::Auth {
593                        message: "Invalid OAuth access token".into(),
594                    })?;
595                header.set_sensitive(true);
596                let mut headers = HeaderMap::new();
597                headers.insert(AUTHORIZATION, header);
598                Ok(headers)
599            }
600            Self::ApplicationDefault(cell) => {
601                let credentials = cell
602                    .get_or_try_init(|| async {
603                        AuthBuilder::default()
604                            .with_scopes(scopes.iter().copied())
605                            .build()
606                            .map(Arc::new)
607                            .map_err(|err| Error::Auth {
608                                message: format!("ADC init failed: {err}"),
609                            })
610                    })
611                    .await?;
612                let headers = credentials
613                    .headers(Extensions::new())
614                    .await
615                    .map_err(|err| Error::Auth {
616                        message: format!("ADC header fetch failed: {err}"),
617                    })?;
618                match headers {
619                    CacheableResource::New { data, .. } => Ok(data),
620                    CacheableResource::NotModified => Err(Error::Auth {
621                        message: "ADC header fetch returned NotModified without cached headers"
622                            .into(),
623                    }),
624                }
625            }
626        }
627    }
628}
629
/// Default total number of attempts, counting the initial call.
const DEFAULT_RETRY_ATTEMPTS: u32 = 5;
/// Default delay before the first retry, in seconds.
const DEFAULT_RETRY_INITIAL_DELAY_SECS: f64 = 1.0;
/// Default cap on a single retry delay, in seconds.
const DEFAULT_RETRY_MAX_DELAY_SECS: f64 = 60.0;
/// Default growth factor of the delay between consecutive retries.
const DEFAULT_RETRY_EXP_BASE: f64 = 2.0;
/// Default upper bound of the jitter added to each delay, in seconds.
const DEFAULT_RETRY_JITTER: f64 = 1.0;
/// Status codes retried by default: 408, 429, 500, 502, 503 and 504.
const DEFAULT_RETRY_HTTP_STATUS_CODES: [u16; 6] = [408, 429, 500, 502, 503, 504];
636
637impl ClientInner {
638    /// 发送请求并自动注入鉴权头。
639    ///
640    /// # Errors
641    /// 当请求构建、鉴权头获取或网络请求失败时返回错误。
642    pub async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
643        self.send_with_http_options(request, None).await
644    }
645
646    /// 发送请求(支持 per-request HTTP options,例如 retry_options)。
647    ///
648    /// # Errors
649    /// 当请求构建、鉴权头获取或网络请求失败时返回错误。
650    pub async fn send_with_http_options(
651        &self,
652        request: reqwest::RequestBuilder,
653        request_http_options: Option<&rust_genai_types::http::HttpOptions>,
654    ) -> Result<reqwest::Response> {
655        let retry_options = request_http_options
656            .and_then(|options| options.retry_options.as_ref())
657            .or(self.config.http_options.retry_options.as_ref());
658
659        let request_template = request.build()?;
660        if let Some(options) = retry_options {
661            self.execute_with_retry(request_template, options).await
662        } else {
663            self.execute_once(request_template).await
664        }
665    }
666
667    async fn execute_once(&self, mut request: reqwest::Request) -> Result<reqwest::Response> {
668        self.prepare_request(&mut request).await?;
669        Ok(self.http.execute(request).await?)
670    }
671
672    async fn execute_with_retry(
673        &self,
674        request_template: reqwest::Request,
675        retry_options: &HttpRetryOptions,
676    ) -> Result<reqwest::Response> {
677        let attempts = retry_options.attempts.unwrap_or(DEFAULT_RETRY_ATTEMPTS);
678        if attempts <= 1 {
679            return self.execute_once(request_template).await;
680        }
681
682        // If the request body can't be cloned, we can't safely retry.
683        if request_template.try_clone().is_none() {
684            return self.execute_once(request_template).await;
685        }
686
687        let retryable_codes: &[u16] = retry_options
688            .http_status_codes
689            .as_deref()
690            .unwrap_or(&DEFAULT_RETRY_HTTP_STATUS_CODES);
691
692        for attempt in 0..attempts {
693            let request = request_template
694                .try_clone()
695                .expect("request_template is cloneable");
696            let response = self.execute_once(request).await?;
697
698            if response.status().is_success() {
699                return Ok(response);
700            }
701
702            let status = response.status().as_u16();
703            let should_retry = retryable_codes.contains(&status);
704            let is_last_attempt = attempt + 1 >= attempts;
705            if !should_retry || is_last_attempt {
706                return Ok(response);
707            }
708
709            // Drop the response before retrying to release the connection back to the pool.
710            drop(response);
711
712            let delay = retry_delay_secs(retry_options, attempt);
713            if delay > 0.0 {
714                tokio::time::sleep(Duration::from_secs_f64(delay)).await;
715            }
716        }
717
718        // Loop always returns on success or final attempt.
719        unreachable!("retry loop must return a response");
720    }
721
722    async fn prepare_request(&self, request: &mut reqwest::Request) -> Result<()> {
723        if let Some(headers) = self.auth_headers().await? {
724            for (name, value) in &headers {
725                if request.headers().contains_key(name) {
726                    continue;
727                }
728                let mut value = value.clone();
729                if name == AUTHORIZATION {
730                    value.set_sensitive(true);
731                }
732                request.headers_mut().insert(name.clone(), value);
733            }
734        }
735        #[cfg(feature = "mcp")]
736        crate::mcp::append_mcp_usage_header(request.headers_mut())?;
737        Ok(())
738    }
739
740    async fn auth_headers(&self) -> Result<Option<HeaderMap>> {
741        let Some(provider) = &self.auth_provider else {
742            return Ok(None);
743        };
744
745        let scopes: Vec<&str> = self.config.auth_scopes.iter().map(String::as_str).collect();
746        let headers = provider.headers(&scopes).await?;
747        Ok(Some(headers))
748    }
749}
750
751fn retry_delay_secs(options: &HttpRetryOptions, retry_index: u32) -> f64 {
752    let initial = options
753        .initial_delay
754        .unwrap_or(DEFAULT_RETRY_INITIAL_DELAY_SECS)
755        .max(0.0);
756    let max_delay = options
757        .max_delay
758        .unwrap_or(DEFAULT_RETRY_MAX_DELAY_SECS)
759        .max(0.0);
760    let exp_base = options.exp_base.unwrap_or(DEFAULT_RETRY_EXP_BASE).max(0.0);
761    let jitter = options.jitter.unwrap_or(DEFAULT_RETRY_JITTER).max(0.0);
762
763    let exp_delay = if exp_base == 0.0 {
764        0.0
765    } else {
766        initial * exp_base.powf(retry_index as f64)
767    };
768    let base_delay = if max_delay > 0.0 {
769        exp_delay.min(max_delay)
770    } else {
771        exp_delay
772    };
773
774    let jitter_delay = if jitter > 0.0 {
775        // Basic pseudo-random jitter without adding a new RNG dependency.
776        let nanos = SystemTime::now()
777            .duration_since(UNIX_EPOCH)
778            .unwrap_or_default()
779            .subsec_nanos() as f64;
780        let frac = (nanos / 1_000_000_000.0).clamp(0.0, 1.0);
781        frac * jitter
782    } else {
783        0.0
784    };
785
786    let delay = base_delay + jitter_delay;
787    if max_delay > 0.0 {
788        delay.min(max_delay)
789    } else {
790        delay
791    }
792}
793
794fn default_auth_scopes(backend: Backend) -> Vec<String> {
795    match backend {
796        Backend::VertexAi => vec!["https://www.googleapis.com/auth/cloud-platform".into()],
797        Backend::GeminiApi => vec![
798            "https://www.googleapis.com/auth/generative-language".into(),
799            "https://www.googleapis.com/auth/generative-language.retriever".into(),
800        ],
801    }
802}
803
/// Endpoint settings used to build request URLs.
pub(crate) struct ApiClient {
    /// Base URL, always ending with `/`.
    pub base_url: String,
    /// API version path segment (e.g. `v1beta`).
    pub api_version: String,
}
808
809impl ApiClient {
810    /// 创建 API 客户端配置。
811    pub fn new(config: &ClientConfig) -> Self {
812        let base_url = config.http_options.base_url.as_deref().map_or_else(
813            || match config.backend {
814                Backend::VertexAi => {
815                    let location = config
816                        .vertex_config
817                        .as_ref()
818                        .map_or("", |cfg| cfg.location.as_str());
819                    if location.is_empty() {
820                        "https://aiplatform.googleapis.com/".to_string()
821                    } else {
822                        format!("https://{location}-aiplatform.googleapis.com/")
823                    }
824                }
825                Backend::GeminiApi => "https://generativelanguage.googleapis.com/".to_string(),
826            },
827            normalize_base_url,
828        );
829
830        let api_version =
831            config
832                .http_options
833                .api_version
834                .clone()
835                .unwrap_or_else(|| match config.backend {
836                    Backend::VertexAi => "v1beta1".to_string(),
837                    Backend::GeminiApi => "v1beta".to_string(),
838                });
839
840        Self {
841            base_url,
842            api_version,
843        }
844    }
845}
846
/// Trims surrounding whitespace and guarantees exactly the original trailing slash state
/// plus a `/` if none was present, so paths can be appended directly.
fn normalize_base_url(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.ends_with('/') {
        trimmed.to_owned()
    } else {
        format!("{trimmed}/")
    }
}
854
855#[cfg(test)]
856mod tests {
857    use super::*;
858    use crate::test_support::with_env;
859    use std::path::PathBuf;
860    use tempfile::tempdir;
861
862    #[test]
863    fn test_client_from_api_key() {
864        let client = Client::new("test-api-key").unwrap();
865        assert_eq!(client.inner.config.backend, Backend::GeminiApi);
866    }
867
868    #[test]
869    fn test_client_builder() {
870        let client = Client::builder()
871            .api_key("test-key")
872            .timeout(30)
873            .build()
874            .unwrap();
875        assert!(client.inner.config.api_key.is_some());
876    }
877
878    #[test]
879    fn test_vertex_ai_config() {
880        let client = Client::new_vertex("my-project", "us-central1").unwrap();
881        assert_eq!(client.inner.config.backend, Backend::VertexAi);
882        assert_eq!(
883            client.inner.api_client.base_url,
884            "https://us-central1-aiplatform.googleapis.com/"
885        );
886    }
887
888    #[test]
889    fn test_base_url_normalization() {
890        let client = Client::builder()
891            .api_key("test-key")
892            .base_url("https://example.com")
893            .build()
894            .unwrap();
895        assert_eq!(client.inner.api_client.base_url, "https://example.com/");
896    }
897
898    #[test]
899    fn test_from_env_reads_overrides() {
900        with_env(
901            &[
902                ("GEMINI_API_KEY", Some("env-key")),
903                ("GENAI_BASE_URL", Some("https://env.example.com")),
904                ("GENAI_API_VERSION", Some("v99")),
905                ("GOOGLE_API_KEY", None),
906            ],
907            || {
908                let client = Client::from_env().unwrap();
909                assert_eq!(client.inner.api_client.base_url, "https://env.example.com/");
910                assert_eq!(client.inner.api_client.api_version, "v99");
911            },
912        );
913    }
914
915    #[test]
916    fn test_from_env_ignores_empty_overrides() {
917        with_env(
918            &[
919                ("GEMINI_API_KEY", Some("env-key")),
920                ("GENAI_BASE_URL", Some("   ")),
921                ("GENAI_API_VERSION", Some("")),
922                ("GOOGLE_API_KEY", None),
923            ],
924            || {
925                let client = Client::from_env().unwrap();
926                assert_eq!(
927                    client.inner.api_client.base_url,
928                    "https://generativelanguage.googleapis.com/"
929                );
930                assert_eq!(client.inner.api_client.api_version, "v1beta");
931            },
932        );
933    }
934
935    #[test]
936    fn test_from_env_missing_key_errors() {
937        with_env(
938            &[
939                ("GEMINI_API_KEY", None),
940                ("GOOGLE_API_KEY", None),
941                ("GENAI_BASE_URL", None),
942            ],
943            || {
944                let result = Client::from_env();
945                assert!(result.is_err());
946            },
947        );
948    }
949
950    #[test]
951    fn test_from_env_google_api_key_fallback() {
952        with_env(
953            &[
954                ("GEMINI_API_KEY", None),
955                ("GOOGLE_API_KEY", Some("google-key")),
956            ],
957            || {
958                let client = Client::from_env().unwrap();
959                assert_eq!(client.inner.config.api_key.as_deref(), Some("google-key"));
960            },
961        );
962    }
963
964    #[test]
965    fn test_with_oauth_missing_client_secret_errors() {
966        let dir = tempdir().unwrap();
967        let secret_path = dir.path().join("missing_client_secret.json");
968        let err = Client::with_oauth(&secret_path).err().unwrap();
969        assert!(matches!(err, Error::InvalidConfig { .. }));
970    }
971
972    #[test]
973    fn test_with_adc_builds_client() {
974        let client = Client::with_adc().unwrap();
975        assert!(matches!(
976            client.inner.config.credentials,
977            Credentials::ApplicationDefault
978        ));
979    }
980
981    #[test]
982    fn test_builder_defaults_to_vertex_when_project_set() {
983        let client = Client::builder()
984            .vertex_project("proj")
985            .vertex_location("loc")
986            .build()
987            .unwrap();
988        assert_eq!(client.inner.config.backend, Backend::VertexAi);
989        assert!(matches!(
990            client.inner.config.credentials,
991            Credentials::ApplicationDefault
992        ));
993    }
994
995    #[test]
996    fn test_valid_proxy_is_accepted() {
997        let client = Client::builder()
998            .api_key("test-key")
999            .proxy("http://127.0.0.1:8888")
1000            .build();
1001        assert!(client.is_ok());
1002    }
1003
1004    #[test]
1005    fn test_vertex_requires_project_and_location() {
1006        let result = Client::builder().backend(Backend::VertexAi).build();
1007        assert!(result.is_err());
1008    }
1009
1010    #[test]
1011    fn test_api_key_with_oauth_is_invalid() {
1012        let result = Client::builder()
1013            .api_key("test-key")
1014            .credentials(Credentials::OAuth {
1015                client_secret_path: PathBuf::from("client_secret.json"),
1016                token_cache_path: None,
1017            })
1018            .build();
1019        assert!(result.is_err());
1020    }
1021
1022    #[test]
1023    fn test_missing_api_key_for_gemini_errors() {
1024        let result = Client::builder().backend(Backend::GeminiApi).build();
1025        assert!(result.is_err());
1026    }
1027
1028    #[test]
1029    fn test_invalid_header_name_is_rejected() {
1030        let result = Client::builder()
1031            .api_key("test-key")
1032            .header("bad header", "value")
1033            .build();
1034        assert!(result.is_err());
1035    }
1036
1037    #[test]
1038    fn test_invalid_header_value_is_rejected() {
1039        let result = Client::builder()
1040            .api_key("test-key")
1041            .header("x-test", "bad\nvalue")
1042            .build();
1043        assert!(result.is_err());
1044    }
1045
1046    #[test]
1047    fn test_invalid_api_key_value_is_rejected() {
1048        let err = Client::builder().api_key("bad\nkey").build().err().unwrap();
1049        assert!(
1050            matches!(err, Error::InvalidConfig { message } if message.contains("Invalid API key value"))
1051        );
1052    }
1053
1054    #[test]
1055    fn test_invalid_proxy_is_rejected() {
1056        let result = Client::builder()
1057            .api_key("test-key")
1058            .proxy("not a url")
1059            .build();
1060        assert!(result.is_err());
1061    }
1062
1063    #[test]
1064    fn test_vertex_api_key_is_rejected() {
1065        let result = Client::builder()
1066            .backend(Backend::VertexAi)
1067            .vertex_project("proj")
1068            .vertex_location("loc")
1069            .credentials(Credentials::ApiKey("key".into()))
1070            .build();
1071        assert!(result.is_err());
1072    }
1073
1074    #[test]
1075    fn test_default_auth_scopes() {
1076        let gemini = default_auth_scopes(Backend::GeminiApi);
1077        assert!(gemini.iter().any(|s| s.contains("generative-language")));
1078
1079        let vertex = default_auth_scopes(Backend::VertexAi);
1080        assert!(vertex.iter().any(|s| s.contains("cloud-platform")));
1081    }
1082
1083    #[test]
1084    fn test_custom_auth_scopes_override_default() {
1085        let client = Client::builder()
1086            .api_key("test-key")
1087            .auth_scopes(vec!["scope-1".to_string()])
1088            .build()
1089            .unwrap();
1090        assert_eq!(client.inner.config.auth_scopes, vec!["scope-1".to_string()]);
1091    }
1092}