1use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Duration;
7
8use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
9use reqwest::{Client as HttpClient, Proxy};
10use tokio::sync::OnceCell;
11
12use crate::auth::OAuthTokenProvider;
13use crate::error::{Error, Result};
14use google_cloud_auth::credentials::{
15 Builder as AuthBuilder, CacheableResource, Credentials as GoogleCredentials,
16};
17use http::Extensions;
18
19#[derive(Clone)]
21pub struct Client {
22 inner: Arc<ClientInner>,
23}
24
25pub(crate) struct ClientInner {
26 pub http: HttpClient,
27 pub config: ClientConfig,
28 pub api_client: ApiClient,
29 pub(crate) auth_provider: Option<AuthProvider>,
30}
31
32#[derive(Debug, Clone)]
34pub struct ClientConfig {
35 pub api_key: Option<String>,
37 pub backend: Backend,
39 pub vertex_config: Option<VertexConfig>,
41 pub http_options: HttpOptions,
43 pub credentials: Credentials,
45 pub auth_scopes: Vec<String>,
47}
48
/// Which Google service the client talks to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// The Gemini Developer API (`generativelanguage.googleapis.com`).
    GeminiApi,
    /// Vertex AI (`*aiplatform.googleapis.com`); requires a project and location.
    VertexAi,
}
55
/// How the client authenticates its requests.
///
/// `Debug` is implemented by hand so an API key renders as `ApiKey(<redacted>)` instead
/// of leaking into logs or panic messages; the other variants print as a derive would.
#[derive(Clone)]
pub enum Credentials {
    /// Static API key (Gemini API only; `ClientBuilder::build` rejects it for Vertex AI).
    ApiKey(String),
    /// OAuth driven by a client-secret file, with an optional token cache location.
    OAuth {
        client_secret_path: PathBuf,
        token_cache_path: Option<PathBuf>,
    },
    /// Google Application Default Credentials (the Vertex AI fallback).
    ApplicationDefault,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Placeholder printed instead of a secret value.
        struct Redacted;

        impl std::fmt::Debug for Redacted {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("<redacted>")
            }
        }

        match self {
            Credentials::ApiKey(_) => f.debug_tuple("ApiKey").field(&Redacted).finish(),
            Credentials::OAuth {
                client_secret_path,
                token_cache_path,
            } => f
                .debug_struct("OAuth")
                .field("client_secret_path", client_secret_path)
                .field("token_cache_path", token_cache_path)
                .finish(),
            Credentials::ApplicationDefault => f.write_str("ApplicationDefault"),
        }
    }
}
69
70#[derive(Debug, Clone)]
72pub struct VertexConfig {
73 pub project: String,
74 pub location: String,
75 pub credentials: Option<VertexCredentials>,
76}
77
/// Explicit Vertex AI credentials.
///
/// `Debug` is implemented by hand so the access token renders as `Some(<redacted>)`
/// instead of leaking into logs or panic messages.
#[derive(Clone)]
pub struct VertexCredentials {
    /// Pre-obtained access token, if any.
    pub access_token: Option<String>,
}

impl std::fmt::Debug for VertexCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Placeholder printed instead of a secret value.
        struct Redacted;

        impl std::fmt::Debug for Redacted {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("<redacted>")
            }
        }

        f.debug_struct("VertexCredentials")
            .field("access_token", &self.access_token.as_ref().map(|_| Redacted))
            .finish()
    }
}
83
/// Transport-level knobs applied when the HTTP client is built.
///
/// `Default` leaves every option unset and the header map empty.
#[derive(Debug, Clone, Default)]
pub struct HttpOptions {
    /// Whole-request timeout in seconds.
    pub timeout: Option<u64>,
    /// Proxy URL applied to all traffic.
    pub proxy: Option<String>,
    /// Extra default headers sent on every request; invalid names/values fail `build`.
    pub headers: HashMap<String, String>,
    /// Overrides the backend's default base URL (normalized to end with `/`).
    pub base_url: Option<String>,
    /// Overrides the backend's default API version.
    pub api_version: Option<String>,
}
93
94impl Client {
95 pub fn new(api_key: impl Into<String>) -> Result<Self> {
97 Self::builder()
98 .api_key(api_key)
99 .backend(Backend::GeminiApi)
100 .build()
101 }
102
103 pub fn from_env() -> Result<Self> {
105 let api_key = std::env::var("GEMINI_API_KEY")
106 .or_else(|_| std::env::var("GOOGLE_API_KEY"))
107 .map_err(|_| Error::InvalidConfig {
108 message: "GEMINI_API_KEY or GOOGLE_API_KEY not found".into(),
109 })?;
110 let mut builder = Self::builder().api_key(api_key);
111 if let Ok(base_url) =
112 std::env::var("GENAI_BASE_URL").or_else(|_| std::env::var("GEMINI_BASE_URL"))
113 {
114 if !base_url.trim().is_empty() {
115 builder = builder.base_url(base_url);
116 }
117 }
118 if let Ok(api_version) = std::env::var("GENAI_API_VERSION") {
119 if !api_version.trim().is_empty() {
120 builder = builder.api_version(api_version);
121 }
122 }
123 builder.build()
124 }
125
126 pub fn new_vertex(project: impl Into<String>, location: impl Into<String>) -> Result<Self> {
128 Self::builder()
129 .backend(Backend::VertexAi)
130 .vertex_project(project)
131 .vertex_location(location)
132 .build()
133 }
134
135 pub fn with_oauth(client_secret_path: impl AsRef<Path>) -> Result<Self> {
137 Self::builder()
138 .credentials(Credentials::OAuth {
139 client_secret_path: client_secret_path.as_ref().to_path_buf(),
140 token_cache_path: None,
141 })
142 .build()
143 }
144
145 pub fn with_adc() -> Result<Self> {
147 Self::builder()
148 .credentials(Credentials::ApplicationDefault)
149 .build()
150 }
151
152 pub fn builder() -> ClientBuilder {
154 ClientBuilder::default()
155 }
156
157 pub fn models(&self) -> crate::models::Models {
159 crate::models::Models::new(self.inner.clone())
160 }
161
162 pub fn chats(&self) -> crate::chats::Chats {
164 crate::chats::Chats::new(self.inner.clone())
165 }
166
167 pub fn files(&self) -> crate::files::Files {
169 crate::files::Files::new(self.inner.clone())
170 }
171
172 pub fn file_search_stores(&self) -> crate::file_search_stores::FileSearchStores {
174 crate::file_search_stores::FileSearchStores::new(self.inner.clone())
175 }
176
177 pub fn documents(&self) -> crate::documents::Documents {
179 crate::documents::Documents::new(self.inner.clone())
180 }
181
182 pub fn live(&self) -> crate::live::Live {
184 crate::live::Live::new(self.inner.clone())
185 }
186
187 pub fn live_music(&self) -> crate::live_music::LiveMusic {
189 crate::live_music::LiveMusic::new(self.inner.clone())
190 }
191
192 pub fn caches(&self) -> crate::caches::Caches {
194 crate::caches::Caches::new(self.inner.clone())
195 }
196
197 pub fn batches(&self) -> crate::batches::Batches {
199 crate::batches::Batches::new(self.inner.clone())
200 }
201
202 pub fn tunings(&self) -> crate::tunings::Tunings {
204 crate::tunings::Tunings::new(self.inner.clone())
205 }
206
207 pub fn operations(&self) -> crate::operations::Operations {
209 crate::operations::Operations::new(self.inner.clone())
210 }
211
212 pub fn auth_tokens(&self) -> crate::tokens::AuthTokens {
214 crate::tokens::AuthTokens::new(self.inner.clone())
215 }
216
217 pub fn interactions(&self) -> crate::interactions::Interactions {
219 crate::interactions::Interactions::new(self.inner.clone())
220 }
221
222 pub fn deep_research(&self) -> crate::deep_research::DeepResearch {
224 crate::deep_research::DeepResearch::new(self.inner.clone())
225 }
226}
227
228#[derive(Default)]
230pub struct ClientBuilder {
231 api_key: Option<String>,
232 credentials: Option<Credentials>,
233 backend: Option<Backend>,
234 vertex_project: Option<String>,
235 vertex_location: Option<String>,
236 http_options: HttpOptions,
237 auth_scopes: Option<Vec<String>>,
238}
239
240impl ClientBuilder {
241 pub fn api_key(mut self, key: impl Into<String>) -> Self {
243 self.api_key = Some(key.into());
244 self
245 }
246
247 pub fn credentials(mut self, credentials: Credentials) -> Self {
249 self.credentials = Some(credentials);
250 self
251 }
252
253 pub fn backend(mut self, backend: Backend) -> Self {
255 self.backend = Some(backend);
256 self
257 }
258
259 pub fn vertex_project(mut self, project: impl Into<String>) -> Self {
261 self.vertex_project = Some(project.into());
262 self
263 }
264
265 pub fn vertex_location(mut self, location: impl Into<String>) -> Self {
267 self.vertex_location = Some(location.into());
268 self
269 }
270
271 pub fn timeout(mut self, secs: u64) -> Self {
273 self.http_options.timeout = Some(secs);
274 self
275 }
276
277 pub fn proxy(mut self, url: impl Into<String>) -> Self {
279 self.http_options.proxy = Some(url.into());
280 self
281 }
282
283 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
285 self.http_options.headers.insert(key.into(), value.into());
286 self
287 }
288
289 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
291 self.http_options.base_url = Some(base_url.into());
292 self
293 }
294
295 pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
297 self.http_options.api_version = Some(api_version.into());
298 self
299 }
300
301 pub fn auth_scopes(mut self, scopes: Vec<String>) -> Self {
303 self.auth_scopes = Some(scopes);
304 self
305 }
306
307 pub fn build(self) -> Result<Client> {
309 let backend = self.backend.unwrap_or_else(|| {
310 if self.vertex_project.is_some() || self.vertex_location.is_some() {
311 Backend::VertexAi
312 } else {
313 Backend::GeminiApi
314 }
315 });
316
317 if backend == Backend::VertexAi
318 && (self.vertex_project.is_none() || self.vertex_location.is_none())
319 {
320 return Err(Error::InvalidConfig {
321 message: "Project and location required for Vertex AI".into(),
322 });
323 }
324
325 if self.credentials.is_some()
326 && self.api_key.is_some()
327 && !matches!(self.credentials, Some(Credentials::ApiKey(_)))
328 {
329 return Err(Error::InvalidConfig {
330 message: "API key cannot be combined with OAuth/ADC credentials".into(),
331 });
332 }
333
334 let credentials = match self.credentials {
335 Some(credentials) => credentials,
336 None => {
337 if let Some(api_key) = self.api_key.clone() {
338 Credentials::ApiKey(api_key)
339 } else if backend == Backend::VertexAi {
340 Credentials::ApplicationDefault
341 } else {
342 return Err(Error::InvalidConfig {
343 message: "API key or OAuth credentials required for Gemini API".into(),
344 });
345 }
346 }
347 };
348
349 if backend == Backend::VertexAi && matches!(credentials, Credentials::ApiKey(_)) {
350 return Err(Error::InvalidConfig {
351 message: "Vertex AI does not support API key authentication".into(),
352 });
353 }
354
355 let mut headers = HeaderMap::new();
356 for (key, value) in &self.http_options.headers {
357 let name =
358 HeaderName::from_bytes(key.as_bytes()).map_err(|_| Error::InvalidConfig {
359 message: format!("Invalid header name: {key}"),
360 })?;
361 let value = HeaderValue::from_str(value).map_err(|_| Error::InvalidConfig {
362 message: format!("Invalid header value for {key}"),
363 })?;
364 headers.insert(name, value);
365 }
366
367 if backend == Backend::GeminiApi {
368 let api_key = match &credentials {
369 Credentials::ApiKey(key) => key.as_str(),
370 _ => "",
371 };
372 let header_name = HeaderName::from_static("x-goog-api-key");
373 if !api_key.is_empty() && !headers.contains_key(&header_name) {
374 let mut header_value =
375 HeaderValue::from_str(api_key).map_err(|_| Error::InvalidConfig {
376 message: "Invalid API key value".into(),
377 })?;
378 header_value.set_sensitive(true);
379 headers.insert(header_name, header_value);
380 }
381 }
382
383 let mut http_builder = HttpClient::builder();
384 if let Some(timeout) = self.http_options.timeout {
385 http_builder = http_builder.timeout(Duration::from_secs(timeout));
386 }
387
388 if let Some(proxy_url) = &self.http_options.proxy {
389 let proxy = Proxy::all(proxy_url).map_err(|e| Error::InvalidConfig {
390 message: format!("Invalid proxy: {e}"),
391 })?;
392 http_builder = http_builder.proxy(proxy);
393 }
394
395 if !headers.is_empty() {
396 http_builder = http_builder.default_headers(headers);
397 }
398
399 let http = http_builder.build()?;
400
401 let auth_scopes = self
402 .auth_scopes
403 .unwrap_or_else(|| default_auth_scopes(backend));
404 let api_key = match &credentials {
405 Credentials::ApiKey(key) => Some(key.clone()),
406 _ => None,
407 };
408 let config = ClientConfig {
409 api_key,
410 backend,
411 vertex_config: if backend == Backend::VertexAi {
412 Some(VertexConfig {
413 project: self.vertex_project.unwrap(),
414 location: self.vertex_location.unwrap(),
415 credentials: None,
416 })
417 } else {
418 None
419 },
420 http_options: self.http_options,
421 credentials: credentials.clone(),
422 auth_scopes,
423 };
424
425 let auth_provider = match &credentials {
426 Credentials::ApiKey(_) => None,
427 Credentials::OAuth {
428 client_secret_path,
429 token_cache_path,
430 } => Some(AuthProvider::OAuth(Arc::new(
431 OAuthTokenProvider::from_paths(
432 client_secret_path.clone(),
433 token_cache_path.clone(),
434 )?,
435 ))),
436 Credentials::ApplicationDefault => {
437 Some(AuthProvider::ApplicationDefault(Arc::new(OnceCell::new())))
438 }
439 };
440
441 let api_client = ApiClient::new(&config)?;
442
443 Ok(Client {
444 inner: Arc::new(ClientInner {
445 http,
446 config,
447 api_client,
448 auth_provider,
449 }),
450 })
451 }
452}
453
454#[derive(Clone)]
455pub(crate) enum AuthProvider {
456 OAuth(Arc<OAuthTokenProvider>),
457 ApplicationDefault(Arc<OnceCell<Arc<GoogleCredentials>>>),
458}
459
460impl AuthProvider {
461 async fn headers(&self, scopes: &[&str]) -> Result<HeaderMap> {
462 match self {
463 AuthProvider::OAuth(provider) => {
464 let token = provider.token().await?;
465 let mut header =
466 HeaderValue::from_str(&format!("Bearer {token}")).map_err(|_| Error::Auth {
467 message: "Invalid OAuth access token".into(),
468 })?;
469 header.set_sensitive(true);
470 let mut headers = HeaderMap::new();
471 headers.insert(AUTHORIZATION, header);
472 Ok(headers)
473 }
474 AuthProvider::ApplicationDefault(cell) => {
475 let credentials = cell
476 .get_or_try_init(|| async {
477 AuthBuilder::default()
478 .with_scopes(scopes.iter().copied())
479 .build()
480 .map(Arc::new)
481 .map_err(|err| Error::Auth {
482 message: format!("ADC init failed: {err}"),
483 })
484 })
485 .await?;
486 let headers = credentials
487 .headers(Extensions::new())
488 .await
489 .map_err(|err| Error::Auth {
490 message: format!("ADC header fetch failed: {err}"),
491 })?;
492 match headers {
493 CacheableResource::New { data, .. } => Ok(data),
494 CacheableResource::NotModified => Err(Error::Auth {
495 message: "ADC header fetch returned NotModified without cached headers"
496 .into(),
497 }),
498 }
499 }
500 }
501 }
502}
503
504impl ClientInner {
505 pub async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
507 let mut request = request.build()?;
508 if let Some(headers) = self.auth_headers().await? {
509 for (name, value) in headers.iter() {
510 if request.headers().contains_key(name) {
511 continue;
512 }
513 let mut value = value.clone();
514 if name == AUTHORIZATION {
515 value.set_sensitive(true);
516 }
517 request.headers_mut().insert(name.clone(), value);
518 }
519 }
520 maybe_add_mcp_usage_header(request.headers_mut())?;
521 Ok(self.http.execute(request).await?)
522 }
523
524 async fn auth_headers(&self) -> Result<Option<HeaderMap>> {
525 let provider = match &self.auth_provider {
526 Some(provider) => provider,
527 None => return Ok(None),
528 };
529
530 let scopes: Vec<&str> = self.config.auth_scopes.iter().map(|s| s.as_str()).collect();
531 let headers = provider.headers(&scopes).await?;
532 Ok(Some(headers))
533 }
534}
535
536fn default_auth_scopes(backend: Backend) -> Vec<String> {
537 match backend {
538 Backend::VertexAi => vec!["https://www.googleapis.com/auth/cloud-platform".into()],
539 Backend::GeminiApi => vec![
540 "https://www.googleapis.com/auth/generative-language".into(),
541 "https://www.googleapis.com/auth/generative-language.retriever".into(),
542 ],
543 }
544}
545
/// Endpoint addressing resolved once at client construction.
pub(crate) struct ApiClient {
    /// Base URL, always ending in `/`.
    pub base_url: String,
    /// API version path segment (e.g. `v1beta`).
    pub api_version: String,
}
550
551impl ApiClient {
552 pub fn new(config: &ClientConfig) -> Result<Self> {
554 let base_url = if let Some(base_url) = &config.http_options.base_url {
555 normalize_base_url(base_url)
556 } else {
557 match config.backend {
558 Backend::VertexAi => {
559 let location = config
560 .vertex_config
561 .as_ref()
562 .map(|cfg| cfg.location.as_str())
563 .unwrap_or("");
564 if location.is_empty() {
565 "https://aiplatform.googleapis.com/".to_string()
566 } else {
567 format!("https://{location}-aiplatform.googleapis.com/")
568 }
569 }
570 Backend::GeminiApi => "https://generativelanguage.googleapis.com/".to_string(),
571 }
572 };
573
574 let api_version =
575 config
576 .http_options
577 .api_version
578 .clone()
579 .unwrap_or_else(|| match config.backend {
580 Backend::VertexAi => "v1beta1".to_string(),
581 Backend::GeminiApi => "v1beta".to_string(),
582 });
583
584 Ok(Self {
585 base_url,
586 api_version,
587 })
588 }
589}
590
/// Trims surrounding whitespace and guarantees a single trailing-slash terminator.
///
/// Existing trailing slashes are kept as-is; an empty input yields `"/"`.
fn normalize_base_url(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.ends_with('/') {
        trimmed.to_owned()
    } else {
        format!("{trimmed}/")
    }
}
598
599#[cfg(feature = "mcp")]
600fn maybe_add_mcp_usage_header(headers: &mut HeaderMap) -> Result<()> {
601 crate::mcp::append_mcp_usage_header(headers)
602}
603
604#[cfg(not(feature = "mcp"))]
605fn maybe_add_mcp_usage_header(_headers: &mut HeaderMap) -> Result<()> {
606 Ok(())
607}
608
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_from_api_key() {
        let client = Client::new("test-api-key").expect("API-key client should build");
        let backend = client.inner.config.backend;
        assert_eq!(backend, Backend::GeminiApi);
    }

    #[test]
    fn test_client_builder() {
        let built = Client::builder().api_key("test-key").timeout(30).build();
        let client = built.expect("builder with API key and timeout should succeed");
        assert!(client.inner.config.api_key.is_some());
    }

    #[test]
    fn test_vertex_ai_config() {
        let client =
            Client::new_vertex("my-project", "us-central1").expect("Vertex client should build");
        let inner = &client.inner;
        assert_eq!(inner.config.backend, Backend::VertexAi);
        assert_eq!(
            inner.api_client.base_url,
            "https://us-central1-aiplatform.googleapis.com/"
        );
    }

    #[test]
    fn test_base_url_normalization() {
        let client = Client::builder()
            .api_key("test-key")
            .base_url("https://example.com")
            .build()
            .expect("custom base URL should be accepted");
        let base_url = client.inner.api_client.base_url.as_str();
        assert_eq!(base_url, "https://example.com/");
    }
}