stygian_graph/adapters/ai/
gemini.rs1use std::time::Duration;
21
22use async_trait::async_trait;
23use futures::stream::{self, BoxStream};
24use reqwest::Client;
25use serde_json::{Value, json};
26
27use crate::domain::error::{ProviderError, Result, StygianError};
28use crate::ports::{AIProvider, ProviderCapabilities};
29
/// Model used when the caller does not choose one via [`GeminiConfig::with_model`].
const DEFAULT_MODEL: &str = "gemini-2.0-flash";

/// Base URL of the Gemini REST API (`v1beta`); `/{model}:generateContent` is appended to it.
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";
35
/// Connection and generation settings for [`GeminiProvider`].
///
/// `Debug` is implemented by hand so that the API key is never printed by
/// `{:?}` formatting (logs, panic messages, tracing fields).
#[derive(Clone)]
pub struct GeminiConfig {
    /// Gemini API key. Redacted from `Debug` output.
    pub api_key: String,
    /// Model identifier placed in the request URL, e.g. `gemini-2.0-flash`.
    pub model: String,
    /// Value sent as `generationConfig.maxOutputTokens`.
    pub max_tokens: u32,
    /// Request timeout applied to the HTTP client built by the provider.
    pub timeout: Duration,
}

impl std::fmt::Debug for GeminiConfig {
    /// Formats every field except `api_key`, which is shown as `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}
48
49impl GeminiConfig {
50 #[must_use]
52 pub fn new(api_key: String) -> Self {
53 Self {
54 api_key,
55 model: DEFAULT_MODEL.to_string(),
56 max_tokens: 8192,
57 timeout: Duration::from_mins(2),
58 }
59 }
60
61 #[must_use]
63 pub fn with_model(mut self, model: impl Into<String>) -> Self {
64 self.model = model.into();
65 self
66 }
67}
68
69pub struct GeminiProvider {
71 config: GeminiConfig,
72 client: Client,
73}
74
75impl GeminiProvider {
76 #[must_use]
85 pub fn new(api_key: String) -> Self {
86 Self::with_config(GeminiConfig::new(api_key))
87 }
88
89 #[must_use]
105 pub fn with_config(config: GeminiConfig) -> Self {
106 #[allow(clippy::expect_used)]
108 let client = Client::builder()
109 .timeout(config.timeout)
110 .build()
111 .expect("Failed to build HTTP client");
112 Self { config, client }
113 }
114
115 fn api_url(&self) -> String {
116 format!(
117 "{}/{}:generateContent?key={}",
118 API_BASE, self.config.model, self.config.api_key
119 )
120 }
121
122 fn build_body(&self, content: &str, schema: &Value) -> Value {
123 let prompt = format!(
124 "Extract structured data from the following content according to this JSON schema.\n\
125 Return ONLY valid JSON matching the schema.\n\
126 Schema: {}\n\nContent:\n{}",
127 serde_json::to_string(schema).unwrap_or_default(),
128 content
129 );
130
131 json!({
132 "contents": [{"parts": [{"text": prompt}]}],
133 "generationConfig": {
134 "maxOutputTokens": self.config.max_tokens,
135 "responseMimeType": "application/json",
136 "responseSchema": schema
137 }
138 })
139 }
140
141 fn parse_response(response: &Value) -> Result<Value> {
142 let text = response
143 .pointer("/candidates/0/content/parts/0/text")
144 .and_then(Value::as_str)
145 .ok_or_else(|| {
146 StygianError::Provider(ProviderError::ApiError(
147 "No text in Gemini response".to_string(),
148 ))
149 })?;
150
151 serde_json::from_str(text).map_err(|e| {
152 StygianError::Provider(ProviderError::ApiError(format!(
153 "Failed to parse Gemini JSON response: {e}"
154 )))
155 })
156 }
157
158 fn map_http_error(status: u16, body: &str) -> StygianError {
159 match status {
160 400 if body.contains("API_KEY") => {
161 StygianError::Provider(ProviderError::InvalidCredentials)
162 }
163 429 => StygianError::Provider(ProviderError::ApiError(format!(
164 "Gemini rate limited: {body}"
165 ))),
166 _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
167 }
168 }
169}
170
171#[async_trait]
172impl AIProvider for GeminiProvider {
173 async fn extract(&self, content: String, schema: Value) -> Result<Value> {
174 let body = self.build_body(&content, &schema);
175 let url = self.api_url();
176
177 let response = self
178 .client
179 .post(&url)
180 .header("Content-Type", "application/json")
181 .json(&body)
182 .send()
183 .await
184 .map_err(|e| {
185 StygianError::Provider(ProviderError::ApiError(format!(
186 "Gemini request failed: {e}"
187 )))
188 })?;
189
190 let status = response.status().as_u16();
191 let text = response
192 .text()
193 .await
194 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
195
196 if status != 200 {
197 return Err(Self::map_http_error(status, &text));
198 }
199
200 let json_val: Value = serde_json::from_str(&text)
201 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
202
203 Self::parse_response(&json_val)
204 }
205
206 async fn stream_extract(
207 &self,
208 content: String,
209 schema: Value,
210 ) -> Result<BoxStream<'static, Result<Value>>> {
211 let result = self.extract(content, schema).await;
212 Ok(Box::pin(stream::once(async move { result })))
213 }
214
215 fn capabilities(&self) -> ProviderCapabilities {
216 ProviderCapabilities {
217 streaming: true,
218 vision: true,
219 tool_use: false,
220 json_mode: true,
221 }
222 }
223
224 fn name(&self) -> &'static str {
225 "gemini"
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Provider with a throwaway key; no network access happens in these tests.
    fn make_provider(key: &str) -> GeminiProvider {
        GeminiProvider::new(key.to_string())
    }

    /// Wraps `text` in the minimal `candidates[0].content.parts[0]` envelope.
    fn candidate_with_text(text: &str) -> Value {
        json!({
            "candidates": [{
                "content": {"parts": [{"text": text}]}
            }]
        })
    }

    /// Returns the message of a `ProviderError::ApiError`, failing the test otherwise.
    fn api_error_text(err: StygianError) -> String {
        match err {
            StygianError::Provider(ProviderError::ApiError(msg)) => msg,
            other => panic!("expected ProviderError::ApiError, got {other:?}"),
        }
    }

    #[test]
    fn test_name() {
        let provider = make_provider("k");
        assert_eq!(provider.name(), "gemini");
    }

    #[test]
    fn test_capabilities() {
        let ProviderCapabilities {
            json_mode, vision, ..
        } = make_provider("k").capabilities();
        assert!(json_mode);
        assert!(vision);
    }

    #[test]
    fn test_api_url_contains_model_and_key() {
        let url = make_provider("my-key").api_url();
        for needle in [DEFAULT_MODEL, "my-key"] {
            assert!(url.contains(needle), "url should contain {needle}");
        }
    }

    #[test]
    fn test_build_body_has_response_mime() {
        let schema = json!({"type": "object"});
        let body = make_provider("k").build_body("content", &schema);
        let mime = body
            .pointer("/generationConfig/responseMimeType")
            .and_then(Value::as_str);
        assert_eq!(mime, Some("application/json"));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let envelope = candidate_with_text(r#"{"name": "Alice"}"#);
        let parsed = GeminiProvider::parse_response(&envelope)?;
        assert_eq!(parsed.pointer("/name").and_then(Value::as_str), Some("Alice"));
        Ok(())
    }

    #[test]
    fn test_parse_response_no_candidates() {
        let outcome = GeminiProvider::parse_response(&json!({"promptFeedback": {}}));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_response_invalid_json_text() {
        let envelope = candidate_with_text("not json at all");
        assert!(GeminiProvider::parse_response(&envelope).is_err());
    }

    #[test]
    fn test_map_http_error_api_key() {
        let err = GeminiProvider::map_http_error(400, "Invalid API_KEY provided");
        assert!(matches!(
            err,
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_429() {
        let msg = api_error_text(GeminiProvider::map_http_error(429, "quota exceeded"));
        assert!(msg.contains("rate limited"));
    }

    #[test]
    fn test_map_http_error_server_error() {
        let msg = api_error_text(GeminiProvider::map_http_error(503, "unavailable"));
        assert!(msg.contains("503"));
    }

    #[test]
    fn test_config_with_model() {
        let cfg = GeminiConfig::new("AIza".to_string()).with_model("gemini-1.5-pro");
        assert_eq!(cfg.model.as_str(), "gemini-1.5-pro");
    }
}