1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6 AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapseError,
7 TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9
10use crate::backend::{ProviderBackend, ProviderRequest, ProviderResponse};
11
/// Connection and sampling settings used by `GeminiChatModel` to build requests.
///
/// Start from [`GeminiConfig::new`] and refine with the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct GeminiConfig {
    /// API key placed in the request URL's `key` query parameter.
    pub api_key: String,
    /// Model name inserted into the `/v1beta/models/{model}:...` request path.
    pub model: String,
    /// URL prefix that the `/v1beta/models/...` path is appended to.
    pub base_url: String,
    /// Optional nucleus-sampling value, sent as `generationConfig.topP`.
    pub top_p: Option<f64>,
    /// Optional stop sequences, sent as `generationConfig.stopSequences`.
    pub stop: Option<Vec<String>>,
}

impl GeminiConfig {
    /// Creates a config for `model` authenticated with `api_key`.
    ///
    /// The base URL defaults to `https://generativelanguage.googleapis.com`, and no
    /// sampling overrides are set.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        let base_url = String::from("https://generativelanguage.googleapis.com");
        Self {
            api_key,
            model,
            base_url,
            top_p: None,
            stop: None,
        }
    }

    /// Returns the config with `base_url` replaced (e.g. to target a proxy or mock).
    pub fn with_base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: url.into(),
            ..self
        }
    }

    /// Returns the config with `top_p` set to `Some(top_p)`.
    pub fn with_top_p(self, top_p: f64) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Returns the config with the given stop sequences.
    pub fn with_stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }
}
47
48pub struct GeminiChatModel {
49 config: GeminiConfig,
50 backend: Arc<dyn ProviderBackend>,
51}
52
53impl GeminiChatModel {
54 pub fn new(config: GeminiConfig, backend: Arc<dyn ProviderBackend>) -> Self {
55 Self { config, backend }
56 }
57
58 fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
59 let mut system_text: Option<String> = None;
60 let mut contents: Vec<Value> = Vec::new();
61
62 for msg in &request.messages {
63 match msg {
64 Message::System { content, .. } => {
65 system_text = Some(content.clone());
66 }
67 Message::Human { content, .. } => {
68 contents.push(json!({
69 "role": "user",
70 "parts": [{"text": content}],
71 }));
72 }
73 Message::AI {
74 content,
75 tool_calls,
76 ..
77 } => {
78 let mut parts: Vec<Value> = Vec::new();
79 if !content.is_empty() {
80 parts.push(json!({"text": content}));
81 }
82 for tc in tool_calls {
83 parts.push(json!({
84 "functionCall": {
85 "name": tc.name,
86 "args": tc.arguments,
87 }
88 }));
89 }
90 contents.push(json!({
91 "role": "model",
92 "parts": parts,
93 }));
94 }
95 Message::Tool {
96 content,
97 tool_call_id: _,
98 ..
99 } => {
100 let result: Value =
102 serde_json::from_str(content).unwrap_or(json!({"result": content}));
103 contents.push(json!({
104 "role": "user",
105 "parts": [{
106 "functionResponse": {
107 "name": "tool",
108 "response": result,
109 }
110 }],
111 }));
112 }
113 Message::Chat { content, .. } => {
114 contents.push(json!({
115 "role": "user",
116 "parts": [{"text": content}],
117 }));
118 }
119 Message::Remove { .. } => { }
120 }
121 }
122
123 let mut body = json!({
124 "contents": contents,
125 });
126
127 if let Some(system) = system_text {
128 body["system_instruction"] = json!({
129 "parts": [{"text": system}],
130 });
131 }
132
133 {
134 let mut gen_config = json!({});
135 let mut has_gen_config = false;
136 if let Some(top_p) = self.config.top_p {
137 gen_config["topP"] = json!(top_p);
138 has_gen_config = true;
139 }
140 if let Some(ref stop) = self.config.stop {
141 gen_config["stopSequences"] = json!(stop);
142 has_gen_config = true;
143 }
144 if has_gen_config {
145 body["generationConfig"] = gen_config;
146 }
147 }
148
149 if !request.tools.is_empty() {
150 body["tools"] = json!([{
151 "functionDeclarations": request.tools.iter().map(tool_def_to_gemini).collect::<Vec<_>>(),
152 }]);
153 }
154 if let Some(ref choice) = request.tool_choice {
155 body["tool_config"] = match choice {
156 ToolChoice::Auto => json!({"functionCallingConfig": {"mode": "AUTO"}}),
157 ToolChoice::Required => json!({"functionCallingConfig": {"mode": "ANY"}}),
158 ToolChoice::None => json!({"functionCallingConfig": {"mode": "NONE"}}),
159 ToolChoice::Specific(name) => json!({
160 "functionCallingConfig": {
161 "mode": "ANY",
162 "allowedFunctionNames": [name]
163 }
164 }),
165 };
166 }
167
168 let method = if stream {
169 "streamGenerateContent"
170 } else {
171 "generateContent"
172 };
173
174 let mut url = format!(
175 "{}/v1beta/models/{}:{}?key={}",
176 self.config.base_url, self.config.model, method, self.config.api_key
177 );
178 if stream {
179 url.push_str("&alt=sse");
180 }
181
182 ProviderRequest {
183 url,
184 headers: vec![("Content-Type".to_string(), "application/json".to_string())],
185 body,
186 }
187 }
188}
189
190fn tool_def_to_gemini(def: &ToolDefinition) -> Value {
191 json!({
192 "name": def.name,
193 "description": def.description,
194 "parameters": def.parameters,
195 })
196}
197
198fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapseError> {
199 check_error_status(resp)?;
200
201 let parts = resp.body["candidates"][0]["content"]["parts"]
202 .as_array()
203 .cloned()
204 .unwrap_or_default();
205
206 let mut text = String::new();
207 let mut tool_calls = Vec::new();
208
209 for part in &parts {
210 if let Some(t) = part["text"].as_str() {
211 text.push_str(t);
212 }
213 if let Some(fc) = part.get("functionCall") {
214 if let Some(name) = fc["name"].as_str() {
215 tool_calls.push(ToolCall {
216 id: format!("gemini-{}", tool_calls.len()),
217 name: name.to_string(),
218 arguments: fc["args"].clone(),
219 });
220 }
221 }
222 }
223
224 let usage = parse_usage(&resp.body["usageMetadata"]);
225
226 let message = if tool_calls.is_empty() {
227 Message::ai(text)
228 } else {
229 Message::ai_with_tool_calls(text, tool_calls)
230 };
231
232 Ok(ChatResponse { message, usage })
233}
234
235fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapseError> {
236 if resp.status == 429 {
237 let msg = resp.body["error"]["message"]
238 .as_str()
239 .unwrap_or("rate limited")
240 .to_string();
241 return Err(SynapseError::RateLimit(msg));
242 }
243 if resp.status >= 400 {
244 let msg = resp.body["error"]["message"]
245 .as_str()
246 .unwrap_or("unknown API error")
247 .to_string();
248 return Err(SynapseError::Model(format!(
249 "Gemini API error ({}): {}",
250 resp.status, msg
251 )));
252 }
253 Ok(())
254}
255
256fn parse_usage(usage: &Value) -> Option<TokenUsage> {
257 if usage.is_null() {
258 return None;
259 }
260 Some(TokenUsage {
261 input_tokens: usage["promptTokenCount"].as_u64().unwrap_or(0) as u32,
262 output_tokens: usage["candidatesTokenCount"].as_u64().unwrap_or(0) as u32,
263 total_tokens: usage["totalTokenCount"].as_u64().unwrap_or(0) as u32,
264 input_details: None,
265 output_details: None,
266 })
267}
268
269fn parse_stream_chunk(data: &str) -> Option<AIMessageChunk> {
270 let v: Value = serde_json::from_str(data).ok()?;
271 let parts = v["candidates"][0]["content"]["parts"]
272 .as_array()
273 .cloned()
274 .unwrap_or_default();
275
276 let mut content = String::new();
277 let mut tool_calls = Vec::new();
278
279 for part in &parts {
280 if let Some(t) = part["text"].as_str() {
281 content.push_str(t);
282 }
283 if let Some(fc) = part.get("functionCall") {
284 if let Some(name) = fc["name"].as_str() {
285 tool_calls.push(ToolCall {
286 id: format!("gemini-{}", tool_calls.len()),
287 name: name.to_string(),
288 arguments: fc["args"].clone(),
289 });
290 }
291 }
292 }
293
294 let usage = parse_usage(&v["usageMetadata"]);
295
296 Some(AIMessageChunk {
297 content,
298 tool_calls,
299 usage,
300 ..Default::default()
301 })
302}
303
304#[async_trait]
305impl ChatModel for GeminiChatModel {
306 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapseError> {
307 let provider_req = self.build_request(&request, false);
308 let resp = self.backend.send(provider_req).await?;
309 parse_response(&resp)
310 }
311
312 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
313 Box::pin(async_stream::stream! {
314 let provider_req = self.build_request(&request, true);
315 let byte_stream = self.backend.send_stream(provider_req).await;
316
317 let byte_stream = match byte_stream {
318 Ok(s) => s,
319 Err(e) => {
320 yield Err(e);
321 return;
322 }
323 };
324
325 use eventsource_stream::Eventsource;
326 use futures::StreamExt;
327
328 let mut event_stream = byte_stream
329 .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
330 .eventsource();
331
332 while let Some(event) = event_stream.next().await {
333 match event {
334 Ok(ev) => {
335 if let Some(chunk) = parse_stream_chunk(&ev.data) {
336 yield Ok(chunk);
337 }
338 }
339 Err(e) => {
340 yield Err(SynapseError::Model(format!("SSE parse error: {e}")));
341 break;
342 }
343 }
344 }
345 })
346 }
347}