rust_genai/
client.rs

1//! Client configuration and transport layer.
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Duration;
7
8use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
9use reqwest::{Client as HttpClient, Proxy};
10use tokio::sync::OnceCell;
11
12use crate::auth::OAuthTokenProvider;
13use crate::error::{Error, Result};
14use google_cloud_auth::credentials::{
15    Builder as AuthBuilder, CacheableResource, Credentials as GoogleCredentials,
16};
17use http::Extensions;
18
/// Gemini client.
///
/// All state lives behind an [`Arc`], so cloning a `Client` copies a pointer
/// rather than the underlying HTTP client or configuration.
#[derive(Clone)]
pub struct Client {
    /// Shared state handed to every API accessor created from this client.
    inner: Arc<ClientInner>,
}
24
/// Crate-internal state shared by a [`Client`] and the API handles it creates.
pub(crate) struct ClientInner {
    /// HTTP client used to execute requests.
    pub http: HttpClient,
    /// Configuration the client was built with.
    pub config: ClientConfig,
    /// Resolved base URL and API version.
    pub api_client: ApiClient,
    /// Source of per-request auth headers for OAuth/ADC credentials;
    /// `None` when the client authenticates with an API key.
    pub(crate) auth_provider: Option<AuthProvider>,
}
31
32/// 客户端配置。
33#[derive(Debug, Clone)]
34pub struct ClientConfig {
35    /// API 密钥(Gemini API)。
36    pub api_key: Option<String>,
37    /// 后端选择。
38    pub backend: Backend,
39    /// Vertex AI 配置。
40    pub vertex_config: Option<VertexConfig>,
41    /// HTTP 配置。
42    pub http_options: HttpOptions,
43    /// 认证信息。
44    pub credentials: Credentials,
45    /// OAuth scopes(服务账号/ADC 使用)。
46    pub auth_scopes: Vec<String>,
47}
48
/// Backend selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// The Gemini API; authenticates with an API key or OAuth/ADC credentials.
    GeminiApi,
    /// Vertex AI; requires a project and a location.
    VertexAi,
}
55
/// Authentication method.
///
/// `Debug` is implemented by hand so that an API key is never printed; the
/// other variants carry no secret material and are shown in full.
#[derive(Clone)]
pub enum Credentials {
    /// API key (Gemini API).
    ApiKey(String),
    /// OAuth user credentials.
    OAuth {
        /// Path to the OAuth client-secret file.
        client_secret_path: PathBuf,
        /// Optional path of the token cache file.
        token_cache_path: Option<PathBuf>,
    },
    /// Application Default Credentials (ADC).
    ApplicationDefault,
}

impl std::fmt::Debug for Credentials {
    /// Formats the credentials, replacing the API key with a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The key is a secret: show only that one is present.
            Self::ApiKey(_) => f.write_str("ApiKey([REDACTED])"),
            Self::OAuth {
                client_secret_path,
                token_cache_path,
            } => f
                .debug_struct("OAuth")
                .field("client_secret_path", client_secret_path)
                .field("token_cache_path", token_cache_path)
                .finish(),
            Self::ApplicationDefault => f.write_str("ApplicationDefault"),
        }
    }
}
69
/// Vertex AI configuration.
#[derive(Debug, Clone)]
pub struct VertexConfig {
    /// Google Cloud project ID.
    pub project: String,
    /// Vertex AI location; used to derive the regional endpoint host.
    pub location: String,
    /// Optional Vertex-specific credentials (the builder in this file leaves this `None`).
    pub credentials: Option<VertexCredentials>,
}
77
/// Placeholder for Vertex AI authentication data.
///
/// `Debug` is implemented by hand so that the access token is never printed.
#[derive(Clone)]
pub struct VertexCredentials {
    /// Optional access token.
    pub access_token: Option<String>,
}

impl std::fmt::Debug for VertexCredentials {
    /// Formats the struct, replacing a present access token with a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None distinction visible without revealing the token.
        let access_token = self.access_token.as_ref().map(|_| "[REDACTED]");
        f.debug_struct("VertexCredentials")
            .field("access_token", &access_token)
            .finish()
    }
}
83
/// HTTP options.
#[derive(Debug, Clone, Default)]
pub struct HttpOptions {
    /// Request timeout in seconds.
    pub timeout: Option<u64>,
    /// Proxy URL; the builder applies it to all traffic.
    pub proxy: Option<String>,
    /// Extra headers installed as default headers on the HTTP client.
    pub headers: HashMap<String, String>,
    /// Overrides the backend's default base URL.
    pub base_url: Option<String>,
    /// Overrides the backend's default API version.
    pub api_version: Option<String>,
}
93
94impl Client {
95    /// 创建新客户端(Gemini API)。
96    pub fn new(api_key: impl Into<String>) -> Result<Self> {
97        Self::builder()
98            .api_key(api_key)
99            .backend(Backend::GeminiApi)
100            .build()
101    }
102
103    /// 从环境变量创建客户端。
104    pub fn from_env() -> Result<Self> {
105        let api_key = std::env::var("GEMINI_API_KEY")
106            .or_else(|_| std::env::var("GOOGLE_API_KEY"))
107            .map_err(|_| Error::InvalidConfig {
108                message: "GEMINI_API_KEY or GOOGLE_API_KEY not found".into(),
109            })?;
110        let mut builder = Self::builder().api_key(api_key);
111        if let Ok(base_url) =
112            std::env::var("GENAI_BASE_URL").or_else(|_| std::env::var("GEMINI_BASE_URL"))
113        {
114            if !base_url.trim().is_empty() {
115                builder = builder.base_url(base_url);
116            }
117        }
118        if let Ok(api_version) = std::env::var("GENAI_API_VERSION") {
119            if !api_version.trim().is_empty() {
120                builder = builder.api_version(api_version);
121            }
122        }
123        builder.build()
124    }
125
126    /// 创建 Vertex AI 客户端。
127    pub fn new_vertex(project: impl Into<String>, location: impl Into<String>) -> Result<Self> {
128        Self::builder()
129            .backend(Backend::VertexAi)
130            .vertex_project(project)
131            .vertex_location(location)
132            .build()
133    }
134
135    /// 使用 OAuth 凭据创建客户端(默认读取 token.json)。
136    pub fn with_oauth(client_secret_path: impl AsRef<Path>) -> Result<Self> {
137        Self::builder()
138            .credentials(Credentials::OAuth {
139                client_secret_path: client_secret_path.as_ref().to_path_buf(),
140                token_cache_path: None,
141            })
142            .build()
143    }
144
145    /// 使用 Application Default Credentials 创建客户端。
146    pub fn with_adc() -> Result<Self> {
147        Self::builder()
148            .credentials(Credentials::ApplicationDefault)
149            .build()
150    }
151
152    /// 创建 Builder。
153    pub fn builder() -> ClientBuilder {
154        ClientBuilder::default()
155    }
156
157    /// 访问 Models API。
158    pub fn models(&self) -> crate::models::Models {
159        crate::models::Models::new(self.inner.clone())
160    }
161
162    /// 访问 Chats API。
163    pub fn chats(&self) -> crate::chats::Chats {
164        crate::chats::Chats::new(self.inner.clone())
165    }
166
167    /// 访问 Files API。
168    pub fn files(&self) -> crate::files::Files {
169        crate::files::Files::new(self.inner.clone())
170    }
171
172    /// 访问 FileSearchStores API。
173    pub fn file_search_stores(&self) -> crate::file_search_stores::FileSearchStores {
174        crate::file_search_stores::FileSearchStores::new(self.inner.clone())
175    }
176
177    /// 访问 Documents API。
178    pub fn documents(&self) -> crate::documents::Documents {
179        crate::documents::Documents::new(self.inner.clone())
180    }
181
182    /// 访问 Live API。
183    pub fn live(&self) -> crate::live::Live {
184        crate::live::Live::new(self.inner.clone())
185    }
186
187    /// 访问 Live Music API。
188    pub fn live_music(&self) -> crate::live_music::LiveMusic {
189        crate::live_music::LiveMusic::new(self.inner.clone())
190    }
191
192    /// 访问 Caches API。
193    pub fn caches(&self) -> crate::caches::Caches {
194        crate::caches::Caches::new(self.inner.clone())
195    }
196
197    /// 访问 Batches API。
198    pub fn batches(&self) -> crate::batches::Batches {
199        crate::batches::Batches::new(self.inner.clone())
200    }
201
202    /// 访问 Tunings API。
203    pub fn tunings(&self) -> crate::tunings::Tunings {
204        crate::tunings::Tunings::new(self.inner.clone())
205    }
206
207    /// 访问 Operations API。
208    pub fn operations(&self) -> crate::operations::Operations {
209        crate::operations::Operations::new(self.inner.clone())
210    }
211
212    /// 访问 AuthTokens API(Ephemeral Tokens)。
213    pub fn auth_tokens(&self) -> crate::tokens::AuthTokens {
214        crate::tokens::AuthTokens::new(self.inner.clone())
215    }
216
217    /// 访问 Interactions API。
218    pub fn interactions(&self) -> crate::interactions::Interactions {
219        crate::interactions::Interactions::new(self.inner.clone())
220    }
221
222    /// 访问 Deep Research。
223    pub fn deep_research(&self) -> crate::deep_research::DeepResearch {
224        crate::deep_research::DeepResearch::new(self.inner.clone())
225    }
226}
227
/// Builder for [`Client`].
///
/// Every field starts out unset; `build` validates the combination and fills
/// in defaults.
#[derive(Default)]
pub struct ClientBuilder {
    /// API key (Gemini API).
    api_key: Option<String>,
    /// Explicit authentication method (API key, OAuth or ADC).
    credentials: Option<Credentials>,
    /// Explicit backend; when unset it is inferred from the Vertex settings.
    backend: Option<Backend>,
    /// Vertex AI project ID.
    vertex_project: Option<String>,
    /// Vertex AI location.
    vertex_location: Option<String>,
    /// HTTP options (timeout, proxy, headers, base URL, API version).
    http_options: HttpOptions,
    /// OAuth scopes; when unset the backend's default scopes are used.
    auth_scopes: Option<Vec<String>>,
}
239
240impl ClientBuilder {
241    /// 设置 API Key(Gemini API)。
242    pub fn api_key(mut self, key: impl Into<String>) -> Self {
243        self.api_key = Some(key.into());
244        self
245    }
246
247    /// 设置认证方式(OAuth/ADC/API Key)。
248    pub fn credentials(mut self, credentials: Credentials) -> Self {
249        self.credentials = Some(credentials);
250        self
251    }
252
253    /// 设置后端(Gemini API 或 Vertex AI)。
254    pub fn backend(mut self, backend: Backend) -> Self {
255        self.backend = Some(backend);
256        self
257    }
258
259    /// 设置 Vertex AI 项目 ID。
260    pub fn vertex_project(mut self, project: impl Into<String>) -> Self {
261        self.vertex_project = Some(project.into());
262        self
263    }
264
265    /// 设置 Vertex AI 区域。
266    pub fn vertex_location(mut self, location: impl Into<String>) -> Self {
267        self.vertex_location = Some(location.into());
268        self
269    }
270
271    /// 设置请求超时(秒)。
272    pub fn timeout(mut self, secs: u64) -> Self {
273        self.http_options.timeout = Some(secs);
274        self
275    }
276
277    /// 设置代理。
278    pub fn proxy(mut self, url: impl Into<String>) -> Self {
279        self.http_options.proxy = Some(url.into());
280        self
281    }
282
283    /// 增加默认 HTTP 头。
284    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
285        self.http_options.headers.insert(key.into(), value.into());
286        self
287    }
288
289    /// 设置自定义基础 URL。
290    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
291        self.http_options.base_url = Some(base_url.into());
292        self
293    }
294
295    /// 设置 API 版本。
296    pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
297        self.http_options.api_version = Some(api_version.into());
298        self
299    }
300
301    /// 设置 OAuth scopes。
302    pub fn auth_scopes(mut self, scopes: Vec<String>) -> Self {
303        self.auth_scopes = Some(scopes);
304        self
305    }
306
307    /// 构建客户端。
308    pub fn build(self) -> Result<Client> {
309        let backend = self.backend.unwrap_or_else(|| {
310            if self.vertex_project.is_some() || self.vertex_location.is_some() {
311                Backend::VertexAi
312            } else {
313                Backend::GeminiApi
314            }
315        });
316
317        if backend == Backend::VertexAi
318            && (self.vertex_project.is_none() || self.vertex_location.is_none())
319        {
320            return Err(Error::InvalidConfig {
321                message: "Project and location required for Vertex AI".into(),
322            });
323        }
324
325        if self.credentials.is_some()
326            && self.api_key.is_some()
327            && !matches!(self.credentials, Some(Credentials::ApiKey(_)))
328        {
329            return Err(Error::InvalidConfig {
330                message: "API key cannot be combined with OAuth/ADC credentials".into(),
331            });
332        }
333
334        let credentials = match self.credentials {
335            Some(credentials) => credentials,
336            None => {
337                if let Some(api_key) = self.api_key.clone() {
338                    Credentials::ApiKey(api_key)
339                } else if backend == Backend::VertexAi {
340                    Credentials::ApplicationDefault
341                } else {
342                    return Err(Error::InvalidConfig {
343                        message: "API key or OAuth credentials required for Gemini API".into(),
344                    });
345                }
346            }
347        };
348
349        if backend == Backend::VertexAi && matches!(credentials, Credentials::ApiKey(_)) {
350            return Err(Error::InvalidConfig {
351                message: "Vertex AI does not support API key authentication".into(),
352            });
353        }
354
355        let mut headers = HeaderMap::new();
356        for (key, value) in &self.http_options.headers {
357            let name =
358                HeaderName::from_bytes(key.as_bytes()).map_err(|_| Error::InvalidConfig {
359                    message: format!("Invalid header name: {key}"),
360                })?;
361            let value = HeaderValue::from_str(value).map_err(|_| Error::InvalidConfig {
362                message: format!("Invalid header value for {key}"),
363            })?;
364            headers.insert(name, value);
365        }
366
367        if backend == Backend::GeminiApi {
368            let api_key = match &credentials {
369                Credentials::ApiKey(key) => key.as_str(),
370                _ => "",
371            };
372            let header_name = HeaderName::from_static("x-goog-api-key");
373            if !api_key.is_empty() && !headers.contains_key(&header_name) {
374                let mut header_value =
375                    HeaderValue::from_str(api_key).map_err(|_| Error::InvalidConfig {
376                        message: "Invalid API key value".into(),
377                    })?;
378                header_value.set_sensitive(true);
379                headers.insert(header_name, header_value);
380            }
381        }
382
383        let mut http_builder = HttpClient::builder();
384        if let Some(timeout) = self.http_options.timeout {
385            http_builder = http_builder.timeout(Duration::from_secs(timeout));
386        }
387
388        if let Some(proxy_url) = &self.http_options.proxy {
389            let proxy = Proxy::all(proxy_url).map_err(|e| Error::InvalidConfig {
390                message: format!("Invalid proxy: {e}"),
391            })?;
392            http_builder = http_builder.proxy(proxy);
393        }
394
395        if !headers.is_empty() {
396            http_builder = http_builder.default_headers(headers);
397        }
398
399        let http = http_builder.build()?;
400
401        let auth_scopes = self
402            .auth_scopes
403            .unwrap_or_else(|| default_auth_scopes(backend));
404        let api_key = match &credentials {
405            Credentials::ApiKey(key) => Some(key.clone()),
406            _ => None,
407        };
408        let config = ClientConfig {
409            api_key,
410            backend,
411            vertex_config: if backend == Backend::VertexAi {
412                Some(VertexConfig {
413                    project: self.vertex_project.unwrap(),
414                    location: self.vertex_location.unwrap(),
415                    credentials: None,
416                })
417            } else {
418                None
419            },
420            http_options: self.http_options,
421            credentials: credentials.clone(),
422            auth_scopes,
423        };
424
425        let auth_provider = match &credentials {
426            Credentials::ApiKey(_) => None,
427            Credentials::OAuth {
428                client_secret_path,
429                token_cache_path,
430            } => Some(AuthProvider::OAuth(Arc::new(
431                OAuthTokenProvider::from_paths(
432                    client_secret_path.clone(),
433                    token_cache_path.clone(),
434                )?,
435            ))),
436            Credentials::ApplicationDefault => {
437                Some(AuthProvider::ApplicationDefault(Arc::new(OnceCell::new())))
438            }
439        };
440
441        let api_client = ApiClient::new(&config)?;
442
443        Ok(Client {
444            inner: Arc::new(ClientInner {
445                http,
446                config,
447                api_client,
448                auth_provider,
449            }),
450        })
451    }
452}
453
/// Source of per-request authentication headers for non-API-key credentials.
#[derive(Clone)]
pub(crate) enum AuthProvider {
    /// OAuth user credentials; the provider supplies the access token.
    OAuth(Arc<OAuthTokenProvider>),
    /// Application Default Credentials, created on first use and kept in the cell.
    ApplicationDefault(Arc<OnceCell<Arc<GoogleCredentials>>>),
}
459
460impl AuthProvider {
461    async fn headers(&self, scopes: &[&str]) -> Result<HeaderMap> {
462        match self {
463            AuthProvider::OAuth(provider) => {
464                let token = provider.token().await?;
465                let mut header =
466                    HeaderValue::from_str(&format!("Bearer {token}")).map_err(|_| Error::Auth {
467                        message: "Invalid OAuth access token".into(),
468                    })?;
469                header.set_sensitive(true);
470                let mut headers = HeaderMap::new();
471                headers.insert(AUTHORIZATION, header);
472                Ok(headers)
473            }
474            AuthProvider::ApplicationDefault(cell) => {
475                let credentials = cell
476                    .get_or_try_init(|| async {
477                        AuthBuilder::default()
478                            .with_scopes(scopes.iter().copied())
479                            .build()
480                            .map(Arc::new)
481                            .map_err(|err| Error::Auth {
482                                message: format!("ADC init failed: {err}"),
483                            })
484                    })
485                    .await?;
486                let headers = credentials
487                    .headers(Extensions::new())
488                    .await
489                    .map_err(|err| Error::Auth {
490                        message: format!("ADC header fetch failed: {err}"),
491                    })?;
492                match headers {
493                    CacheableResource::New { data, .. } => Ok(data),
494                    CacheableResource::NotModified => Err(Error::Auth {
495                        message: "ADC header fetch returned NotModified without cached headers"
496                            .into(),
497                    }),
498                }
499            }
500        }
501    }
502}
503
504impl ClientInner {
505    /// 发送请求并自动注入鉴权头。
506    pub async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
507        let mut request = request.build()?;
508        if let Some(headers) = self.auth_headers().await? {
509            for (name, value) in headers.iter() {
510                if request.headers().contains_key(name) {
511                    continue;
512                }
513                let mut value = value.clone();
514                if name == AUTHORIZATION {
515                    value.set_sensitive(true);
516                }
517                request.headers_mut().insert(name.clone(), value);
518            }
519        }
520        maybe_add_mcp_usage_header(request.headers_mut())?;
521        Ok(self.http.execute(request).await?)
522    }
523
524    async fn auth_headers(&self) -> Result<Option<HeaderMap>> {
525        let provider = match &self.auth_provider {
526            Some(provider) => provider,
527            None => return Ok(None),
528        };
529
530        let scopes: Vec<&str> = self.config.auth_scopes.iter().map(|s| s.as_str()).collect();
531        let headers = provider.headers(&scopes).await?;
532        Ok(Some(headers))
533    }
534}
535
536fn default_auth_scopes(backend: Backend) -> Vec<String> {
537    match backend {
538        Backend::VertexAi => vec!["https://www.googleapis.com/auth/cloud-platform".into()],
539        Backend::GeminiApi => vec![
540            "https://www.googleapis.com/auth/generative-language".into(),
541            "https://www.googleapis.com/auth/generative-language.retriever".into(),
542        ],
543    }
544}
545
/// Resolved endpoint settings for API requests.
pub(crate) struct ApiClient {
    /// Base URL; the values produced in this file always end with `/`.
    pub base_url: String,
    /// API version, e.g. `v1beta` (Gemini API default) or `v1beta1` (Vertex AI default).
    pub api_version: String,
}
550
551impl ApiClient {
552    /// 创建 API 客户端配置。
553    pub fn new(config: &ClientConfig) -> Result<Self> {
554        let base_url = if let Some(base_url) = &config.http_options.base_url {
555            normalize_base_url(base_url)
556        } else {
557            match config.backend {
558                Backend::VertexAi => {
559                    let location = config
560                        .vertex_config
561                        .as_ref()
562                        .map(|cfg| cfg.location.as_str())
563                        .unwrap_or("");
564                    if location.is_empty() {
565                        "https://aiplatform.googleapis.com/".to_string()
566                    } else {
567                        format!("https://{location}-aiplatform.googleapis.com/")
568                    }
569                }
570                Backend::GeminiApi => "https://generativelanguage.googleapis.com/".to_string(),
571            }
572        };
573
574        let api_version =
575            config
576                .http_options
577                .api_version
578                .clone()
579                .unwrap_or_else(|| match config.backend {
580                    Backend::VertexAi => "v1beta1".to_string(),
581                    Backend::GeminiApi => "v1beta".to_string(),
582                });
583
584        Ok(Self {
585            base_url,
586            api_version,
587        })
588    }
589}
590
/// Trims surrounding whitespace from `base_url` and guarantees a trailing `/`.
///
/// Existing trailing slashes are kept as they are; blank input yields `"/"`.
fn normalize_base_url(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.ends_with('/') {
        trimmed.to_owned()
    } else {
        format!("{trimmed}/")
    }
}
598
/// Adds the MCP usage header to `headers` by delegating to
/// `crate::mcp::append_mcp_usage_header`.
///
/// Compiled only when the `mcp` feature is enabled.
#[cfg(feature = "mcp")]
fn maybe_add_mcp_usage_header(headers: &mut HeaderMap) -> Result<()> {
    crate::mcp::append_mcp_usage_header(headers)
}
603
/// No-op counterpart used when the `mcp` feature is disabled: leaves the
/// headers untouched and always succeeds.
#[cfg(not(feature = "mcp"))]
fn maybe_add_mcp_usage_header(_headers: &mut HeaderMap) -> Result<()> {
    Ok(())
}
608
#[cfg(test)]
mod tests {
    use super::*;

    /// `Client::new` selects the Gemini API backend.
    #[test]
    fn test_client_from_api_key() {
        let client = Client::new("test-api-key").expect("client should build from an API key");
        let backend = client.inner.config.backend;
        assert_eq!(backend, Backend::GeminiApi);
    }

    /// A builder given an API key and a timeout keeps the key in the config.
    #[test]
    fn test_client_builder() {
        let builder = Client::builder().api_key("test-key").timeout(30);
        let client = builder.build().expect("builder with an API key should succeed");
        assert!(client.inner.config.api_key.is_some());
    }

    /// `Client::new_vertex` selects Vertex AI and derives the regional endpoint.
    #[test]
    fn test_vertex_ai_config() {
        let client = Client::new_vertex("my-project", "us-central1")
            .expect("vertex client should build with project and location");
        let inner = &client.inner;
        assert_eq!(inner.config.backend, Backend::VertexAi);
        assert_eq!(
            inner.api_client.base_url,
            "https://us-central1-aiplatform.googleapis.com/"
        );
    }

    /// A custom base URL without a trailing slash gets one appended.
    #[test]
    fn test_base_url_normalization() {
        let builder = Client::builder()
            .api_key("test-key")
            .base_url("https://example.com");
        let client = builder.build().expect("builder with a base URL should succeed");
        assert_eq!(client.inner.api_client.base_url, "https://example.com/");
    }
}