1use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
8use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
9use reqwest::{Client as HttpClient, Proxy};
10use tokio::sync::OnceCell;
11
12use crate::auth::OAuthTokenProvider;
13use crate::error::{Error, Result};
14use google_cloud_auth::credentials::{
15 Builder as AuthBuilder, CacheableResource, Credentials as GoogleCredentials,
16};
17use http::Extensions;
18use rust_genai_types::http::HttpRetryOptions;
19
/// Entry point of the SDK: a handle that hands out the per-service API wrappers
/// (models, chats, files, caches, batches, ...).
///
/// Cloning is cheap: all clones share one reference-counted `ClientInner`.
#[derive(Clone)]
pub struct Client {
    inner: Arc<ClientInner>,
}
25
/// State shared by every `Client` clone and every service wrapper it creates.
pub(crate) struct ClientInner {
    /// reqwest client preconfigured with the timeout, proxy and default headers
    /// from `ClientBuilder::build`.
    pub http: HttpClient,
    /// The resolved configuration the client was built from.
    pub config: ClientConfig,
    /// Resolved base URL and API version used to form request URLs.
    pub api_client: ApiClient,
    /// Source of per-request auth headers; `None` for API-key credentials, whose key
    /// is already sent as a default header.
    pub(crate) auth_provider: Option<AuthProvider>,
}
32
33#[derive(Debug, Clone)]
35pub struct ClientConfig {
36 pub api_key: Option<String>,
38 pub backend: Backend,
40 pub vertex_config: Option<VertexConfig>,
42 pub http_options: HttpOptions,
44 pub credentials: Credentials,
46 pub auth_scopes: Vec<String>,
48}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq)]
52pub enum Backend {
53 GeminiApi,
54 VertexAi,
55}
56
57#[derive(Debug, Clone)]
59pub enum Credentials {
60 ApiKey(String),
62 OAuth {
64 client_secret_path: PathBuf,
65 token_cache_path: Option<PathBuf>,
66 },
67 ApplicationDefault,
69}
70
71#[derive(Debug, Clone)]
73pub struct VertexConfig {
74 pub project: String,
75 pub location: String,
76 pub credentials: Option<VertexCredentials>,
77}
78
79#[derive(Debug, Clone)]
81pub struct VertexCredentials {
82 pub access_token: Option<String>,
83}
84
85#[derive(Debug, Clone, Default)]
87pub struct HttpOptions {
88 pub timeout: Option<u64>,
89 pub proxy: Option<String>,
90 pub headers: HashMap<String, String>,
91 pub base_url: Option<String>,
92 pub api_version: Option<String>,
93 pub retry_options: Option<HttpRetryOptions>,
94}
95
96impl Client {
97 pub fn new(api_key: impl Into<String>) -> Result<Self> {
102 Self::builder()
103 .api_key(api_key)
104 .backend(Backend::GeminiApi)
105 .build()
106 }
107
108 pub fn from_env() -> Result<Self> {
113 let api_key = std::env::var("GEMINI_API_KEY")
114 .or_else(|_| std::env::var("GOOGLE_API_KEY"))
115 .map_err(|_| Error::InvalidConfig {
116 message: "GEMINI_API_KEY or GOOGLE_API_KEY not found".into(),
117 })?;
118 let mut builder = Self::builder().api_key(api_key);
119 if let Ok(base_url) =
120 std::env::var("GENAI_BASE_URL").or_else(|_| std::env::var("GEMINI_BASE_URL"))
121 {
122 if !base_url.trim().is_empty() {
123 builder = builder.base_url(base_url);
124 }
125 }
126 if let Ok(api_version) = std::env::var("GENAI_API_VERSION") {
127 if !api_version.trim().is_empty() {
128 builder = builder.api_version(api_version);
129 }
130 }
131 builder.build()
132 }
133
134 pub fn new_vertex(project: impl Into<String>, location: impl Into<String>) -> Result<Self> {
139 Self::builder()
140 .backend(Backend::VertexAi)
141 .vertex_project(project)
142 .vertex_location(location)
143 .build()
144 }
145
146 pub fn with_oauth(client_secret_path: impl AsRef<Path>) -> Result<Self> {
151 Self::builder()
152 .credentials(Credentials::OAuth {
153 client_secret_path: client_secret_path.as_ref().to_path_buf(),
154 token_cache_path: None,
155 })
156 .build()
157 }
158
159 pub fn with_adc() -> Result<Self> {
164 Self::builder()
165 .credentials(Credentials::ApplicationDefault)
166 .build()
167 }
168
169 #[must_use]
171 pub fn builder() -> ClientBuilder {
172 ClientBuilder::default()
173 }
174
175 #[must_use]
177 pub fn models(&self) -> crate::models::Models {
178 crate::models::Models::new(self.inner.clone())
179 }
180
181 #[must_use]
183 pub fn chats(&self) -> crate::chats::Chats {
184 crate::chats::Chats::new(self.inner.clone())
185 }
186
187 #[must_use]
189 pub fn files(&self) -> crate::files::Files {
190 crate::files::Files::new(self.inner.clone())
191 }
192
193 #[must_use]
195 pub fn file_search_stores(&self) -> crate::file_search_stores::FileSearchStores {
196 crate::file_search_stores::FileSearchStores::new(self.inner.clone())
197 }
198
199 #[must_use]
201 pub fn documents(&self) -> crate::documents::Documents {
202 crate::documents::Documents::new(self.inner.clone())
203 }
204
205 #[must_use]
207 pub fn live(&self) -> crate::live::Live {
208 crate::live::Live::new(self.inner.clone())
209 }
210
211 #[must_use]
213 pub fn live_music(&self) -> crate::live_music::LiveMusic {
214 crate::live_music::LiveMusic::new(self.inner.clone())
215 }
216
217 #[must_use]
219 pub fn caches(&self) -> crate::caches::Caches {
220 crate::caches::Caches::new(self.inner.clone())
221 }
222
223 #[must_use]
225 pub fn batches(&self) -> crate::batches::Batches {
226 crate::batches::Batches::new(self.inner.clone())
227 }
228
229 #[must_use]
231 pub fn tunings(&self) -> crate::tunings::Tunings {
232 crate::tunings::Tunings::new(self.inner.clone())
233 }
234
235 #[must_use]
237 pub fn operations(&self) -> crate::operations::Operations {
238 crate::operations::Operations::new(self.inner.clone())
239 }
240
241 #[must_use]
243 pub fn auth_tokens(&self) -> crate::tokens::AuthTokens {
244 crate::tokens::AuthTokens::new(self.inner.clone())
245 }
246
247 #[must_use]
251 pub fn tokens(&self) -> crate::tokens::Tokens {
252 self.auth_tokens()
253 }
254
255 #[must_use]
257 pub fn interactions(&self) -> crate::interactions::Interactions {
258 crate::interactions::Interactions::new(self.inner.clone())
259 }
260
261 #[must_use]
263 pub fn deep_research(&self) -> crate::deep_research::DeepResearch {
264 crate::deep_research::DeepResearch::new(self.inner.clone())
265 }
266}
267
/// Fluent builder for [`Client`]; obtain one with `Client::builder()`.
///
/// Setters only record values; all validation happens in [`ClientBuilder::build`].
#[derive(Default)]
pub struct ClientBuilder {
    /// API key recorded by `api_key`.
    api_key: Option<String>,
    /// Explicit credentials; when `None`, `build` derives them from the API key or
    /// the backend.
    credentials: Option<Credentials>,
    /// Explicit backend; when `None`, `build` picks Vertex AI if a project or
    /// location was set, and the Gemini API otherwise.
    backend: Option<Backend>,
    /// Vertex AI project ID.
    vertex_project: Option<String>,
    /// Vertex AI location.
    vertex_location: Option<String>,
    /// Transport options (timeout, proxy, headers, base URL, API version, retries).
    http_options: HttpOptions,
    /// Scope override; when `None`, backend-specific defaults are used.
    auth_scopes: Option<Vec<String>>,
}
279
280impl ClientBuilder {
281 #[must_use]
283 pub fn api_key(mut self, key: impl Into<String>) -> Self {
284 self.api_key = Some(key.into());
285 self
286 }
287
288 #[must_use]
290 pub fn credentials(mut self, credentials: Credentials) -> Self {
291 self.credentials = Some(credentials);
292 self
293 }
294
295 #[must_use]
297 pub const fn backend(mut self, backend: Backend) -> Self {
298 self.backend = Some(backend);
299 self
300 }
301
302 #[must_use]
304 pub fn vertex_project(mut self, project: impl Into<String>) -> Self {
305 self.vertex_project = Some(project.into());
306 self
307 }
308
309 #[must_use]
311 pub fn vertex_location(mut self, location: impl Into<String>) -> Self {
312 self.vertex_location = Some(location.into());
313 self
314 }
315
316 #[must_use]
318 pub const fn timeout(mut self, secs: u64) -> Self {
319 self.http_options.timeout = Some(secs);
320 self
321 }
322
323 #[must_use]
325 pub fn proxy(mut self, url: impl Into<String>) -> Self {
326 self.http_options.proxy = Some(url.into());
327 self
328 }
329
330 #[must_use]
332 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
333 self.http_options.headers.insert(key.into(), value.into());
334 self
335 }
336
337 #[must_use]
339 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
340 self.http_options.base_url = Some(base_url.into());
341 self
342 }
343
344 #[must_use]
346 pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
347 self.http_options.api_version = Some(api_version.into());
348 self
349 }
350
351 #[must_use]
353 pub fn retry_options(mut self, retry_options: HttpRetryOptions) -> Self {
354 self.http_options.retry_options = Some(retry_options);
355 self
356 }
357
358 #[must_use]
360 pub fn auth_scopes(mut self, scopes: Vec<String>) -> Self {
361 self.auth_scopes = Some(scopes);
362 self
363 }
364
365 pub fn build(self) -> Result<Client> {
370 let Self {
371 api_key,
372 credentials,
373 backend,
374 vertex_project,
375 vertex_location,
376 http_options,
377 auth_scopes,
378 } = self;
379
380 let backend = Self::resolve_backend(
381 backend,
382 vertex_project.as_deref(),
383 vertex_location.as_deref(),
384 );
385 Self::validate_vertex_config(
386 backend,
387 vertex_project.as_deref(),
388 vertex_location.as_deref(),
389 )?;
390 let credentials = Self::resolve_credentials(backend, api_key.as_deref(), credentials)?;
391 let headers = Self::build_headers(&http_options, backend, &credentials)?;
392 let http = Self::build_http_client(&http_options, headers)?;
393
394 let auth_scopes = auth_scopes.unwrap_or_else(|| default_auth_scopes(backend));
395 let api_key = match &credentials {
396 Credentials::ApiKey(key) => Some(key.clone()),
397 _ => None,
398 };
399 let vertex_config = Self::build_vertex_config(backend, vertex_project, vertex_location)?;
400 let config = ClientConfig {
401 api_key,
402 backend,
403 vertex_config,
404 http_options,
405 credentials: credentials.clone(),
406 auth_scopes,
407 };
408
409 let auth_provider = build_auth_provider(&credentials)?;
410 let api_client = ApiClient::new(&config);
411
412 Ok(Client {
413 inner: Arc::new(ClientInner {
414 http,
415 config,
416 api_client,
417 auth_provider,
418 }),
419 })
420 }
421
422 fn resolve_backend(
423 backend: Option<Backend>,
424 vertex_project: Option<&str>,
425 vertex_location: Option<&str>,
426 ) -> Backend {
427 backend.unwrap_or_else(|| {
428 if vertex_project.is_some() || vertex_location.is_some() {
429 Backend::VertexAi
430 } else {
431 Backend::GeminiApi
432 }
433 })
434 }
435
436 fn validate_vertex_config(
437 backend: Backend,
438 vertex_project: Option<&str>,
439 vertex_location: Option<&str>,
440 ) -> Result<()> {
441 if backend == Backend::VertexAi && (vertex_project.is_none() || vertex_location.is_none()) {
442 return Err(Error::InvalidConfig {
443 message: "Project and location required for Vertex AI".into(),
444 });
445 }
446 Ok(())
447 }
448
449 fn resolve_credentials(
450 backend: Backend,
451 api_key: Option<&str>,
452 credentials: Option<Credentials>,
453 ) -> Result<Credentials> {
454 if credentials.is_some()
455 && api_key.is_some()
456 && !matches!(credentials, Some(Credentials::ApiKey(_)))
457 {
458 return Err(Error::InvalidConfig {
459 message: "API key cannot be combined with OAuth/ADC credentials".into(),
460 });
461 }
462
463 let credentials = match credentials {
464 Some(credentials) => credentials,
465 None => {
466 if let Some(api_key) = api_key {
467 Credentials::ApiKey(api_key.to_string())
468 } else if backend == Backend::VertexAi {
469 Credentials::ApplicationDefault
470 } else {
471 return Err(Error::InvalidConfig {
472 message: "API key or OAuth credentials required for Gemini API".into(),
473 });
474 }
475 }
476 };
477
478 if backend == Backend::VertexAi && matches!(credentials, Credentials::ApiKey(_)) {
479 return Err(Error::InvalidConfig {
480 message: "Vertex AI does not support API key authentication".into(),
481 });
482 }
483
484 Ok(credentials)
485 }
486
487 fn build_headers(
488 http_options: &HttpOptions,
489 backend: Backend,
490 credentials: &Credentials,
491 ) -> Result<HeaderMap> {
492 let mut headers = HeaderMap::new();
493 for (key, value) in &http_options.headers {
494 let name =
495 HeaderName::from_bytes(key.as_bytes()).map_err(|_| Error::InvalidConfig {
496 message: format!("Invalid header name: {key}"),
497 })?;
498 let value = HeaderValue::from_str(value).map_err(|_| Error::InvalidConfig {
499 message: format!("Invalid header value for {key}"),
500 })?;
501 headers.insert(name, value);
502 }
503
504 if backend == Backend::GeminiApi {
505 let api_key = match credentials {
506 Credentials::ApiKey(key) => key.as_str(),
507 _ => "",
508 };
509 let header_name = HeaderName::from_static("x-goog-api-key");
510 if !api_key.is_empty() && !headers.contains_key(&header_name) {
511 let mut header_value =
512 HeaderValue::from_str(api_key).map_err(|_| Error::InvalidConfig {
513 message: "Invalid API key value".into(),
514 })?;
515 header_value.set_sensitive(true);
516 headers.insert(header_name, header_value);
517 }
518 }
519
520 Ok(headers)
521 }
522
523 fn build_http_client(http_options: &HttpOptions, headers: HeaderMap) -> Result<HttpClient> {
524 let mut http_builder = HttpClient::builder();
525 if let Some(timeout) = http_options.timeout {
526 http_builder = http_builder.timeout(Duration::from_secs(timeout));
527 }
528
529 if let Some(proxy_url) = &http_options.proxy {
530 let proxy = Proxy::all(proxy_url).map_err(|e| Error::InvalidConfig {
531 message: format!("Invalid proxy: {e}"),
532 })?;
533 http_builder = http_builder.proxy(proxy);
534 }
535
536 if !headers.is_empty() {
537 http_builder = http_builder.default_headers(headers);
538 }
539
540 Ok(http_builder.build()?)
541 }
542
543 fn build_vertex_config(
544 backend: Backend,
545 vertex_project: Option<String>,
546 vertex_location: Option<String>,
547 ) -> Result<Option<VertexConfig>> {
548 if backend != Backend::VertexAi {
549 return Ok(None);
550 }
551 let project = vertex_project.ok_or_else(|| Error::InvalidConfig {
552 message: "Project and location required for Vertex AI".into(),
553 })?;
554 let location = vertex_location.ok_or_else(|| Error::InvalidConfig {
555 message: "Project and location required for Vertex AI".into(),
556 })?;
557 Ok(Some(VertexConfig {
558 project,
559 location,
560 credentials: None,
561 }))
562 }
563}
564
565fn build_auth_provider(credentials: &Credentials) -> Result<Option<AuthProvider>> {
566 match credentials {
567 Credentials::ApiKey(_) => Ok(None),
568 Credentials::OAuth {
569 client_secret_path,
570 token_cache_path,
571 } => Ok(Some(AuthProvider::OAuth(Arc::new(
572 OAuthTokenProvider::from_paths(client_secret_path.clone(), token_cache_path.clone())?,
573 )))),
574 Credentials::ApplicationDefault => Ok(Some(AuthProvider::ApplicationDefault(Arc::new(
575 OnceCell::new(),
576 )))),
577 }
578}
579
/// Source of per-request authentication headers for OAuth and ADC credentials.
///
/// Clones share state through `Arc`, so the ADC cache is shared too.
#[derive(Clone)]
pub(crate) enum AuthProvider {
    /// Bearer tokens obtained from an `OAuthTokenProvider`.
    OAuth(Arc<OAuthTokenProvider>),
    /// Application Default Credentials, created on first use and cached in the cell.
    ApplicationDefault(Arc<OnceCell<Arc<GoogleCredentials>>>),
}
585
impl AuthProvider {
    /// Produces the authentication headers to attach to an outgoing request.
    ///
    /// - `OAuth`: fetches a token from the provider and returns a single
    ///   `Authorization: Bearer <token>` header marked sensitive.
    /// - `ApplicationDefault`: on the first call, builds ADC with `scopes` and caches
    ///   them in the shared `OnceCell`. Every call then asks those credentials for their
    ///   current headers.
    ///
    /// `scopes` only matters for the call that initializes the ADC cell; the OAuth
    /// variant ignores it.
    ///
    /// # Errors
    ///
    /// Returns `Error::Auth` in these cases:
    /// - the OAuth token is not a valid header value;
    /// - ADC initialization fails;
    /// - fetching the ADC headers fails;
    /// - the ADC fetch reports `NotModified`.
    ///
    /// Errors from the OAuth provider's `token()` are propagated with `?`.
    async fn headers(&self, scopes: &[&str]) -> Result<HeaderMap> {
        match self {
            Self::OAuth(provider) => {
                let token = provider.token().await?;
                // The token may contain bytes that are illegal in a header value.
                let mut header =
                    HeaderValue::from_str(&format!("Bearer {token}")).map_err(|_| Error::Auth {
                        message: "Invalid OAuth access token".into(),
                    })?;
                // Keeps the token out of `Debug` output of the header map.
                header.set_sensitive(true);
                let mut headers = HeaderMap::new();
                headers.insert(AUTHORIZATION, header);
                Ok(headers)
            }
            Self::ApplicationDefault(cell) => {
                // Initialize the ADC credentials once; concurrent first callers wait on
                // the same initialization.
                let credentials = cell
                    .get_or_try_init(|| async {
                        AuthBuilder::default()
                            .with_scopes(scopes.iter().copied())
                            .build()
                            .map(Arc::new)
                            .map_err(|err| Error::Auth {
                                message: format!("ADC init failed: {err}"),
                            })
                    })
                    .await?;
                // No entity tag is passed in the extensions.
                // NOTE(review): presumably this means the credentials always answer with
                // `New`; confirm against google-cloud-auth.
                let headers = credentials
                    .headers(Extensions::new())
                    .await
                    .map_err(|err| Error::Auth {
                        message: format!("ADC header fetch failed: {err}"),
                    })?;
                match headers {
                    CacheableResource::New { data, .. } => Ok(data),
                    CacheableResource::NotModified => Err(Error::Auth {
                        message: "ADC header fetch returned NotModified without cached headers"
                            .into(),
                    }),
                }
            }
        }
    }
}
629
/// Total attempts, the first request included, when `HttpRetryOptions::attempts`
/// is unset.
const DEFAULT_RETRY_ATTEMPTS: u32 = 5;
/// Back-off before the first retry, in seconds, when `initial_delay` is unset.
const DEFAULT_RETRY_INITIAL_DELAY_SECS: f64 = 1.0;
/// Cap on a single back-off, in seconds, when `max_delay` is unset.
const DEFAULT_RETRY_MAX_DELAY_SECS: f64 = 60.0;
/// Growth factor of the back-off per retry when `exp_base` is unset.
const DEFAULT_RETRY_EXP_BASE: f64 = 2.0;
/// Upper bound of the random jitter, in seconds, when `jitter` is unset.
const DEFAULT_RETRY_JITTER: f64 = 1.0;
/// Statuses retried when `http_status_codes` is unset:
/// - 408 Request Timeout;
/// - 429 Too Many Requests;
/// - the server errors 500, 502, 503 and 504.
const DEFAULT_RETRY_HTTP_STATUS_CODES: [u16; 6] = [408, 429, 500, 502, 503, 504];
636
637impl ClientInner {
638 pub async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
643 self.send_with_http_options(request, None).await
644 }
645
646 pub async fn send_with_http_options(
651 &self,
652 request: reqwest::RequestBuilder,
653 request_http_options: Option<&rust_genai_types::http::HttpOptions>,
654 ) -> Result<reqwest::Response> {
655 let retry_options = request_http_options
656 .and_then(|options| options.retry_options.as_ref())
657 .or(self.config.http_options.retry_options.as_ref());
658
659 let request_template = request.build()?;
660 if let Some(options) = retry_options {
661 self.execute_with_retry(request_template, options).await
662 } else {
663 self.execute_once(request_template).await
664 }
665 }
666
667 async fn execute_once(&self, mut request: reqwest::Request) -> Result<reqwest::Response> {
668 self.prepare_request(&mut request).await?;
669 Ok(self.http.execute(request).await?)
670 }
671
672 async fn execute_with_retry(
673 &self,
674 request_template: reqwest::Request,
675 retry_options: &HttpRetryOptions,
676 ) -> Result<reqwest::Response> {
677 let attempts = retry_options.attempts.unwrap_or(DEFAULT_RETRY_ATTEMPTS);
678 if attempts <= 1 {
679 return self.execute_once(request_template).await;
680 }
681
682 if request_template.try_clone().is_none() {
684 return self.execute_once(request_template).await;
685 }
686
687 let retryable_codes: &[u16] = retry_options
688 .http_status_codes
689 .as_deref()
690 .unwrap_or(&DEFAULT_RETRY_HTTP_STATUS_CODES);
691
692 for attempt in 0..attempts {
693 let request = request_template
694 .try_clone()
695 .expect("request_template is cloneable");
696 let response = self.execute_once(request).await?;
697
698 if response.status().is_success() {
699 return Ok(response);
700 }
701
702 let status = response.status().as_u16();
703 let should_retry = retryable_codes.contains(&status);
704 let is_last_attempt = attempt + 1 >= attempts;
705 if !should_retry || is_last_attempt {
706 return Ok(response);
707 }
708
709 drop(response);
711
712 let delay = retry_delay_secs(retry_options, attempt);
713 if delay > 0.0 {
714 tokio::time::sleep(Duration::from_secs_f64(delay)).await;
715 }
716 }
717
718 unreachable!("retry loop must return a response");
720 }
721
722 async fn prepare_request(&self, request: &mut reqwest::Request) -> Result<()> {
723 if let Some(headers) = self.auth_headers().await? {
724 for (name, value) in &headers {
725 if request.headers().contains_key(name) {
726 continue;
727 }
728 let mut value = value.clone();
729 if name == AUTHORIZATION {
730 value.set_sensitive(true);
731 }
732 request.headers_mut().insert(name.clone(), value);
733 }
734 }
735 #[cfg(feature = "mcp")]
736 crate::mcp::append_mcp_usage_header(request.headers_mut())?;
737 Ok(())
738 }
739
740 async fn auth_headers(&self) -> Result<Option<HeaderMap>> {
741 let Some(provider) = &self.auth_provider else {
742 return Ok(None);
743 };
744
745 let scopes: Vec<&str> = self.config.auth_scopes.iter().map(String::as_str).collect();
746 let headers = provider.headers(&scopes).await?;
747 Ok(Some(headers))
748 }
749}
750
751fn retry_delay_secs(options: &HttpRetryOptions, retry_index: u32) -> f64 {
752 let initial = options
753 .initial_delay
754 .unwrap_or(DEFAULT_RETRY_INITIAL_DELAY_SECS)
755 .max(0.0);
756 let max_delay = options
757 .max_delay
758 .unwrap_or(DEFAULT_RETRY_MAX_DELAY_SECS)
759 .max(0.0);
760 let exp_base = options.exp_base.unwrap_or(DEFAULT_RETRY_EXP_BASE).max(0.0);
761 let jitter = options.jitter.unwrap_or(DEFAULT_RETRY_JITTER).max(0.0);
762
763 let exp_delay = if exp_base == 0.0 {
764 0.0
765 } else {
766 initial * exp_base.powf(retry_index as f64)
767 };
768 let base_delay = if max_delay > 0.0 {
769 exp_delay.min(max_delay)
770 } else {
771 exp_delay
772 };
773
774 let jitter_delay = if jitter > 0.0 {
775 let nanos = SystemTime::now()
777 .duration_since(UNIX_EPOCH)
778 .unwrap_or_default()
779 .subsec_nanos() as f64;
780 let frac = (nanos / 1_000_000_000.0).clamp(0.0, 1.0);
781 frac * jitter
782 } else {
783 0.0
784 };
785
786 let delay = base_delay + jitter_delay;
787 if max_delay > 0.0 {
788 delay.min(max_delay)
789 } else {
790 delay
791 }
792}
793
794fn default_auth_scopes(backend: Backend) -> Vec<String> {
795 match backend {
796 Backend::VertexAi => vec!["https://www.googleapis.com/auth/cloud-platform".into()],
797 Backend::GeminiApi => vec![
798 "https://www.googleapis.com/auth/generative-language".into(),
799 "https://www.googleapis.com/auth/generative-language.retriever".into(),
800 ],
801 }
802}
803
804pub(crate) struct ApiClient {
805 pub base_url: String,
806 pub api_version: String,
807}
808
809impl ApiClient {
810 pub fn new(config: &ClientConfig) -> Self {
812 let base_url = config.http_options.base_url.as_deref().map_or_else(
813 || match config.backend {
814 Backend::VertexAi => {
815 let location = config
816 .vertex_config
817 .as_ref()
818 .map_or("", |cfg| cfg.location.as_str());
819 if location.is_empty() {
820 "https://aiplatform.googleapis.com/".to_string()
821 } else {
822 format!("https://{location}-aiplatform.googleapis.com/")
823 }
824 }
825 Backend::GeminiApi => "https://generativelanguage.googleapis.com/".to_string(),
826 },
827 normalize_base_url,
828 );
829
830 let api_version =
831 config
832 .http_options
833 .api_version
834 .clone()
835 .unwrap_or_else(|| match config.backend {
836 Backend::VertexAi => "v1beta1".to_string(),
837 Backend::GeminiApi => "v1beta".to_string(),
838 });
839
840 Self {
841 base_url,
842 api_version,
843 }
844 }
845}
846
847fn normalize_base_url(base_url: &str) -> String {
848 let mut value = base_url.trim().to_string();
849 if !value.ends_with('/') {
850 value.push('/');
851 }
852 value
853}
854
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::with_env;
    use std::path::PathBuf;
    use tempfile::tempdir;

    // The env-based tests below pin every variable `Client::from_env` may read:
    // GEMINI_API_KEY, GOOGLE_API_KEY, GENAI_BASE_URL, GEMINI_BASE_URL, GENAI_API_VERSION.
    // This stops ambient values from the developer's shell or CI from leaking in.

    #[test]
    fn test_client_from_api_key() {
        let client = Client::new("test-api-key").unwrap();
        assert_eq!(client.inner.config.backend, Backend::GeminiApi);
    }

    #[test]
    fn test_client_builder() {
        let client = Client::builder()
            .api_key("test-key")
            .timeout(30)
            .build()
            .unwrap();
        assert!(client.inner.config.api_key.is_some());
    }

    #[test]
    fn test_gemini_defaults() {
        let client = Client::new("test-key").unwrap();
        assert_eq!(
            client.inner.api_client.base_url,
            "https://generativelanguage.googleapis.com/"
        );
        assert_eq!(client.inner.api_client.api_version, "v1beta");
    }

    #[test]
    fn test_vertex_ai_config() {
        let client = Client::new_vertex("my-project", "us-central1").unwrap();
        assert_eq!(client.inner.config.backend, Backend::VertexAi);
        assert_eq!(
            client.inner.api_client.base_url,
            "https://us-central1-aiplatform.googleapis.com/"
        );
        assert_eq!(client.inner.api_client.api_version, "v1beta1");
    }

    #[test]
    fn test_base_url_normalization() {
        let client = Client::builder()
            .api_key("test-key")
            .base_url("https://example.com")
            .build()
            .unwrap();
        assert_eq!(client.inner.api_client.base_url, "https://example.com/");
    }

    #[test]
    fn test_normalize_base_url_trims_and_appends_slash() {
        assert_eq!(
            normalize_base_url("  https://example.com/api  "),
            "https://example.com/api/"
        );
        assert_eq!(normalize_base_url("https://example.com/"), "https://example.com/");
    }

    #[test]
    fn test_from_env_reads_overrides() {
        with_env(
            &[
                ("GEMINI_API_KEY", Some("env-key")),
                ("GOOGLE_API_KEY", None),
                ("GENAI_BASE_URL", Some("https://env.example.com")),
                ("GEMINI_BASE_URL", None),
                ("GENAI_API_VERSION", Some("v99")),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(client.inner.api_client.base_url, "https://env.example.com/");
                assert_eq!(client.inner.api_client.api_version, "v99");
            },
        );
    }

    #[test]
    fn test_from_env_ignores_empty_overrides() {
        with_env(
            &[
                ("GEMINI_API_KEY", Some("env-key")),
                ("GOOGLE_API_KEY", None),
                ("GENAI_BASE_URL", Some("  ")),
                ("GEMINI_BASE_URL", None),
                ("GENAI_API_VERSION", Some("")),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(
                    client.inner.api_client.base_url,
                    "https://generativelanguage.googleapis.com/"
                );
                assert_eq!(client.inner.api_client.api_version, "v1beta");
            },
        );
    }

    #[test]
    fn test_from_env_missing_key_errors() {
        with_env(
            &[
                ("GEMINI_API_KEY", None),
                ("GOOGLE_API_KEY", None),
                ("GENAI_BASE_URL", None),
                ("GEMINI_BASE_URL", None),
                ("GENAI_API_VERSION", None),
            ],
            || {
                let result = Client::from_env();
                assert!(result.is_err());
            },
        );
    }

    #[test]
    fn test_from_env_google_api_key_fallback() {
        with_env(
            &[
                ("GEMINI_API_KEY", None),
                ("GOOGLE_API_KEY", Some("google-key")),
                ("GENAI_BASE_URL", None),
                ("GEMINI_BASE_URL", None),
                ("GENAI_API_VERSION", None),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(client.inner.config.api_key.as_deref(), Some("google-key"));
            },
        );
    }

    #[test]
    fn test_with_oauth_missing_client_secret_errors() {
        let dir = tempdir().unwrap();
        let secret_path = dir.path().join("missing_client_secret.json");
        let err = Client::with_oauth(&secret_path).err().unwrap();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[test]
    fn test_with_adc_builds_client() {
        let client = Client::with_adc().unwrap();
        assert!(matches!(
            client.inner.config.credentials,
            Credentials::ApplicationDefault
        ));
    }

    #[test]
    fn test_builder_defaults_to_vertex_when_project_set() {
        let client = Client::builder()
            .vertex_project("proj")
            .vertex_location("loc")
            .build()
            .unwrap();
        assert_eq!(client.inner.config.backend, Backend::VertexAi);
        assert!(matches!(
            client.inner.config.credentials,
            Credentials::ApplicationDefault
        ));
    }

    #[test]
    fn test_valid_proxy_is_accepted() {
        let client = Client::builder()
            .api_key("test-key")
            .proxy("http://127.0.0.1:8888")
            .build();
        assert!(client.is_ok());
    }

    #[test]
    fn test_vertex_requires_project_and_location() {
        let result = Client::builder().backend(Backend::VertexAi).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_api_key_with_oauth_is_invalid() {
        let result = Client::builder()
            .api_key("test-key")
            .credentials(Credentials::OAuth {
                client_secret_path: PathBuf::from("client_secret.json"),
                token_cache_path: None,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_api_key_with_adc_is_invalid() {
        let result = Client::builder()
            .api_key("test-key")
            .credentials(Credentials::ApplicationDefault)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_explicit_api_key_credentials_take_precedence() {
        let client = Client::builder()
            .api_key("builder-key")
            .credentials(Credentials::ApiKey("explicit-key".into()))
            .build()
            .unwrap();
        assert_eq!(client.inner.config.api_key.as_deref(), Some("explicit-key"));
    }

    #[test]
    fn test_missing_api_key_for_gemini_errors() {
        let result = Client::builder().backend(Backend::GeminiApi).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_header_name_is_rejected() {
        let result = Client::builder()
            .api_key("test-key")
            .header("bad header", "value")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_header_value_is_rejected() {
        let result = Client::builder()
            .api_key("test-key")
            .header("x-test", "bad\nvalue")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_api_key_value_is_rejected() {
        let err = Client::builder().api_key("bad\nkey").build().err().unwrap();
        assert!(
            matches!(err, Error::InvalidConfig { message } if message.contains("Invalid API key value"))
        );
    }

    #[test]
    fn test_invalid_proxy_is_rejected() {
        let result = Client::builder()
            .api_key("test-key")
            .proxy("not a url")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_vertex_api_key_is_rejected() {
        let result = Client::builder()
            .backend(Backend::VertexAi)
            .vertex_project("proj")
            .vertex_location("loc")
            .credentials(Credentials::ApiKey("key".into()))
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_default_auth_scopes() {
        let gemini = default_auth_scopes(Backend::GeminiApi);
        assert!(gemini.iter().any(|s| s.contains("generative-language")));

        let vertex = default_auth_scopes(Backend::VertexAi);
        assert!(vertex.iter().any(|s| s.contains("cloud-platform")));
    }

    #[test]
    fn test_custom_auth_scopes_override_default() {
        let client = Client::builder()
            .api_key("test-key")
            .auth_scopes(vec!["scope-1".to_string()])
            .build()
            .unwrap();
        assert_eq!(client.inner.config.auth_scopes, vec!["scope-1".to_string()]);
    }
}