Skip to main content

tt_provider_gemini/
lib.rs

//! Google Gemini provider adapter.
//!
//! Implements [`tt_shared::Provider`] for Google Gemini's
//! `generateContent` and `streamGenerateContent` endpoints.
//! Both non-streaming and streaming (SSE via `?alt=sse`) requests are fully supported.
//! Embeddings use separate Gemini embedding models and are not wired here;
//! they return [`ProviderError::Unsupported`].
//!
//! # Usage
//!
//! ```rust,no_run
//! use tt_provider_gemini::{GeminiProvider, ClientConfig};
//!
//! let provider = GeminiProvider::new(ClientConfig::default());
//! ```
//!
//! # API differences from OpenAI
//!
//! - The model is in the URL path, not the request body.
//! - Auth uses the `x-goog-api-key` request header (NOT a URL `?key=` query
//!   param — keys in URLs leak via logs/proxies; see review §5.2).
//! - System messages map to `systemInstruction`.
//! - Tools use `functionDeclarations` inside a single `tools` object.
//! - Streaming uses the SSE format, requested with `?alt=sse`.
25
26pub mod client;
27pub mod errors;
28pub mod pricing;
29pub mod stream;
30pub mod translate;
31
32use async_trait::async_trait;
33use futures::stream::BoxStream;
34use reqwest::Client;
35use tracing::instrument;
36use tt_shared::{
37    filter_extra_headers, validate_provider_url, ChatCompletionChunk, ChatCompletionRequest,
38    ChatCompletionResponse, EmbeddingsRequest, EmbeddingsResponse, ModelInfo, ModelPricing,
39    Provider, ProviderError, RequestContext,
40};
41
42pub use client::ClientConfig;
43
/// Host used for every request whose credentials carry no `base_url` override.
///
/// This compiled-in value is treated as trusted: the SSRF guard only runs on
/// customer-supplied overrides, never on this default. It has no trailing
/// slash, so endpoint paths are appended as `"{base}/v1beta/..."`.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com";
46
47/// Stateless Gemini adapter. Holds an HTTP client and the static pricing table.
48///
49/// Create once with [`GeminiProvider::new`] and share across requests.
50pub struct GeminiProvider {
51    client: Client,
52    /// When `true`, skip SSRF URL validation for private/loopback addresses.
53    /// Always `false` in production; set to `true` only in tests that target
54    /// a local mock server.
55    allow_local: bool,
56}
57
58impl GeminiProvider {
59    /// Create a new [`GeminiProvider`] from the given client configuration.
60    ///
61    /// # Panics
62    ///
63    /// Panics if the underlying [`reqwest::Client`] cannot be constructed (very
64    /// rare — only happens with invalid TLS configuration).
65    pub fn new(cfg: ClientConfig) -> Self {
66        let client =
67            client::build_client(&cfg).expect("failed to build reqwest::Client for Gemini adapter");
68        Self {
69            client,
70            allow_local: false,
71        }
72    }
73
74    /// Create an adapter that skips SSRF URL validation for tests targeting a
75    /// local mock server.
76    ///
77    /// # Warning
78    ///
79    /// Do not use in production code. This bypasses the SSRF guard.
80    #[doc(hidden)]
81    pub fn new_allow_local(cfg: ClientConfig) -> Self {
82        let client =
83            client::build_client(&cfg).expect("failed to build reqwest::Client for Gemini adapter");
84        Self {
85            client,
86            allow_local: true,
87        }
88    }
89
90    /// Resolve the base URL from credentials or fall back to the default.
91    fn base_url<'a>(&self, ctx: &'a RequestContext) -> &'a str {
92        ctx.credentials
93            .base_url
94            .as_deref()
95            .unwrap_or(DEFAULT_BASE_URL)
96    }
97}
98
99#[async_trait]
100impl Provider for GeminiProvider {
101    fn id(&self) -> &'static str {
102        "gemini"
103    }
104
105    fn models(&self) -> Vec<ModelInfo> {
106        pricing::all_models()
107    }
108
109    fn pricing(&self, model: &str) -> Option<ModelPricing> {
110        pricing::pricing_for(model)
111    }
112
113    fn dropped_params(&self, req: &tt_shared::ChatCompletionRequest) -> Vec<String> {
114        // Mirror translate.rs: Gemini drops these; response_format is translated.
115        let mut out = Vec::new();
116        if req.n.is_some() {
117            out.push("n".to_string());
118        }
119        if req.seed.is_some() {
120            out.push("seed".to_string());
121        }
122        if req.presence_penalty.is_some() {
123            out.push("presence_penalty".to_string());
124        }
125        if req.frequency_penalty.is_some() {
126            out.push("frequency_penalty".to_string());
127        }
128        if req.user.is_some() {
129            out.push("user".to_string());
130        }
131        out
132    }
133
134    /// Non-streaming chat completion via
135    /// `POST /v1beta/models/{model}:generateContent` (key in `x-goog-api-key` header).
136    ///
137    /// Translates the canonical request to Gemini's wire format, sends it,
138    /// and maps errors to [`ProviderError`].
139    #[instrument(skip(self, ctx), fields(provider = "gemini", model = %req.model))]
140    async fn chat_completion(
141        &self,
142        req: ChatCompletionRequest,
143        ctx: &RequestContext,
144    ) -> Result<ChatCompletionResponse, ProviderError> {
145        let base_url = self.base_url(ctx);
146        // Validate customer-supplied base_url overrides; skip when using the
147        // compiled-in default (always safe) or when allow_local is set (tests).
148        if ctx.credentials.base_url.is_some() {
149            validate_provider_url(base_url, self.allow_local)
150                .map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
151        }
152
153        let api_key = ctx.credentials.api_key.expose().to_string();
154        let model = req.model.clone();
155
156        translate::validate_model_id(&model)?;
157        let url = format!("{base_url}/v1beta/models/{model}:generateContent");
158
159        let body = translate::translate_request(req)?;
160
161        let mut request_builder = self
162            .client
163            .post(&url)
164            .header("Content-Type", "application/json")
165            .header("x-goog-api-key", &api_key)
166            .json(&body);
167        // Forward customer-supplied extra headers (denylist-filtered), matching
168        // the OpenAI/Anthropic/compat adapters.
169        for (name, value) in &filter_extra_headers(&ctx.credentials.extra_headers) {
170            request_builder = request_builder.header(name, value);
171        }
172        let response = request_builder
173            .send()
174            .await
175            .map_err(errors::map_reqwest_error)?;
176
177        let status = response.status().as_u16();
178        let retry_after = response
179            .headers()
180            .get("retry-after")
181            .and_then(|v| v.to_str().ok())
182            .map(|s| s.to_string());
183
184        let response_text = response.text().await.map_err(errors::map_reqwest_error)?;
185
186        if status >= 400 {
187            return Err(errors::map_response_error(
188                status,
189                &response_text,
190                retry_after.as_deref(),
191                &model,
192            ));
193        }
194
195        translate::deserialize_response(&response_text, &model)
196    }
197
198    /// Streaming chat completion via
199    /// `POST /v1beta/models/{model}:streamGenerateContent?alt=sse` (key in `x-goog-api-key` header).
200    ///
201    /// Returns [`ProviderError`] before yielding any chunk if the server
202    /// responds with HTTP ≥ 400. Otherwise returns a `BoxStream` that parses
203    /// Gemini SSE events and yields [`ChatCompletionChunk`] values.
204    #[instrument(skip(self, ctx), fields(provider = "gemini", model = %req.model))]
205    async fn chat_completion_stream(
206        &self,
207        req: ChatCompletionRequest,
208        ctx: &RequestContext,
209    ) -> Result<BoxStream<'static, Result<ChatCompletionChunk, ProviderError>>, ProviderError> {
210        let base_url = self.base_url(ctx);
211        if ctx.credentials.base_url.is_some() {
212            validate_provider_url(base_url, self.allow_local)
213                .map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
214        }
215
216        let base_url = base_url.to_string();
217        let client = self.client.clone();
218        stream::stream_chat_completion(client, &base_url, req, ctx).await
219    }
220
221    /// Embeddings are not supported by this adapter.
222    ///
223    /// Gemini uses separate embedding models (e.g. `text-embedding-004`) via a
224    /// different endpoint. Those are wired as a separate task.
225    ///
226    /// Always returns [`ProviderError::Unsupported`].
227    async fn embeddings(
228        &self,
229        _req: EmbeddingsRequest,
230        _ctx: &RequestContext,
231    ) -> Result<EmbeddingsResponse, ProviderError> {
232        Err(ProviderError::Unsupported(
233            "Gemini embedding models use a separate endpoint; use a dedicated embedding adapter"
234                .to_string(),
235        ))
236    }
237}