Skip to main content

zag_agent/providers/
gemini.rs

1use crate::agent::{Agent, ModelSize};
2use crate::output::AgentOutput;
3use crate::sandbox::SandboxConfig;
4use crate::session_log::{
5    BackfilledSession, HistoricalLogAdapter, LiveLogAdapter, LiveLogContext, LogCompleteness,
6    LogEventKind, LogSourceKind, SessionLogMetadata, SessionLogWriter,
7};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use log::info;
11use std::collections::HashSet;
12use std::path::Path;
13use std::process::Stdio;
14use tokio::fs;
15use tokio::process::Command;
16
17/// Return the Gemini tmp directory: `~/.gemini/tmp/`.
18pub fn tmp_dir() -> Option<std::path::PathBuf> {
19    dirs::home_dir().map(|h| h.join(".gemini/tmp"))
20}
21
/// Model name a freshly constructed [`Gemini`] starts with.
///
/// `"auto"` is treated specially when building CLI arguments: no `--model`
/// flag is passed for it.
pub const DEFAULT_MODEL: &str = "auto";

/// Model names this provider advertises through `Agent::available_models`.
pub const AVAILABLE_MODELS: &[&str] = &[
    "auto",
    "gemini-3-pro-preview",
    "gemini-3-flash-preview",
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
];
32
/// Configuration for driving the `gemini` command-line tool as an agent.
pub struct Gemini {
    /// System prompt; when non-empty it is written to `<root>/.gemini/system.md`
    /// before a run.
    system_prompt: String,
    /// Model name; passed as `--model` unless empty or `"auto"`.
    model: String,
    /// Working directory for the child process and base for the `.gemini`
    /// directory; the current directory (`.`) is used when unset.
    root: Option<String>,
    /// When true, `--approval-mode yolo` is passed to the CLI.
    skip_permissions: bool,
    /// Value for `--output-format`; only applied to non-interactive runs.
    output_format: Option<String>,
    /// Extra directories, each passed as `--include-directories <dir>`.
    add_dirs: Vec<String>,
    /// When true, non-interactive runs capture the CLI output and return it
    /// instead of inheriting stdout.
    capture_output: bool,
    /// When set, the command is built by `crate::sandbox::build_sandbox_command`
    /// instead of invoking `gemini` directly.
    sandbox: Option<SandboxConfig>,
    /// Value for `--max-turns`, if any.
    max_turns: Option<u32>,
}
44
/// Live log adapter that mirrors a Gemini chat file into a session log while
/// the session is running.
pub struct GeminiLiveLogAdapter {
    /// Context of the wrapped session; its `started_at` is used to ignore chat
    /// files last modified before the session began.
    ctx: LiveLogContext,
    /// Chat file being followed; `None` until one has been discovered.
    session_path: Option<std::path::PathBuf>,
    /// Ids of messages already emitted, so repeated polls do not duplicate them.
    emitted_message_ids: std::collections::HashSet<String>,
}
50
/// Stateless adapter that imports existing Gemini chat files found under
/// [`tmp_dir`] as backfilled sessions.
pub struct GeminiHistoricalLogAdapter;
52
53impl Gemini {
54    pub fn new() -> Self {
55        Self {
56            system_prompt: String::new(),
57            model: DEFAULT_MODEL.to_string(),
58            root: None,
59            skip_permissions: false,
60            output_format: None,
61            add_dirs: Vec::new(),
62            capture_output: false,
63            sandbox: None,
64            max_turns: None,
65        }
66    }
67
68    fn get_base_path(&self) -> &Path {
69        self.root.as_ref().map(Path::new).unwrap_or(Path::new("."))
70    }
71
72    async fn write_system_file(&self) -> Result<()> {
73        let base = self.get_base_path();
74        log::debug!("Writing Gemini system file to {}", base.display());
75        let gemini_dir = base.join(".gemini");
76        fs::create_dir_all(&gemini_dir).await?;
77        fs::write(gemini_dir.join("system.md"), &self.system_prompt).await?;
78        Ok(())
79    }
80
81    /// Build the argument list for a run/exec invocation.
82    fn build_run_args(&self, interactive: bool, prompt: Option<&str>) -> Vec<String> {
83        let mut args = Vec::new();
84
85        if self.skip_permissions {
86            args.extend(["--approval-mode", "yolo"].map(String::from));
87        }
88
89        if !self.model.is_empty() && self.model != "auto" {
90            args.extend(["--model".to_string(), self.model.clone()]);
91        }
92
93        for dir in &self.add_dirs {
94            args.extend(["--include-directories".to_string(), dir.clone()]);
95        }
96
97        if !interactive && let Some(ref format) = self.output_format {
98            args.extend(["--output-format".to_string(), format.clone()]);
99        }
100
101        if let Some(turns) = self.max_turns {
102            args.extend(["--max-turns".to_string(), turns.to_string()]);
103        }
104
105        if let Some(p) = prompt {
106            args.push(p.to_string());
107        }
108
109        args
110    }
111
112    /// Create a `Command` either directly or wrapped in sandbox.
113    fn make_command(&self, agent_args: Vec<String>) -> Command {
114        if let Some(ref sb) = self.sandbox {
115            let std_cmd = crate::sandbox::build_sandbox_command(sb, agent_args);
116            Command::from(std_cmd)
117        } else {
118            let mut cmd = Command::new("gemini");
119            if let Some(ref root) = self.root {
120                cmd.current_dir(root);
121            }
122            cmd.args(&agent_args);
123            cmd
124        }
125    }
126
127    async fn execute(
128        &self,
129        interactive: bool,
130        prompt: Option<&str>,
131    ) -> Result<Option<AgentOutput>> {
132        if !self.system_prompt.is_empty() {
133            log::debug!(
134                "Gemini system prompt (written to system.md): {}",
135                self.system_prompt
136            );
137            self.write_system_file().await?;
138        }
139
140        let agent_args = self.build_run_args(interactive, prompt);
141        log::debug!("Gemini command: gemini {}", agent_args.join(" "));
142        if let Some(p) = prompt {
143            log::debug!("Gemini user prompt: {}", p);
144        }
145        let mut cmd = self.make_command(agent_args);
146
147        if !self.system_prompt.is_empty() {
148            cmd.env("GEMINI_SYSTEM_MD", "true");
149        }
150
151        if interactive {
152            cmd.stdin(Stdio::inherit())
153                .stdout(Stdio::inherit())
154                .stderr(Stdio::inherit());
155            let status = cmd
156                .status()
157                .await
158                .context("Failed to execute 'gemini' CLI. Is it installed and in PATH?")?;
159            if !status.success() {
160                anyhow::bail!("Gemini command failed with status: {}", status);
161            }
162            Ok(None)
163        } else if self.capture_output {
164            let text = crate::process::run_captured(&mut cmd, "Gemini").await?;
165            log::debug!("Gemini raw response ({} bytes): {}", text.len(), text);
166            Ok(Some(AgentOutput::from_text("gemini", &text)))
167        } else {
168            cmd.stdin(Stdio::inherit()).stdout(Stdio::inherit());
169            crate::process::run_with_captured_stderr(&mut cmd).await?;
170            Ok(None)
171        }
172    }
173}
174
// Unit tests for this module live in the sibling file `gemini_tests.rs` and
// are only compiled for test builds.
#[cfg(test)]
#[path = "gemini_tests.rs"]
mod tests;
178
179impl Default for Gemini {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
185impl GeminiLiveLogAdapter {
186    pub fn new(ctx: LiveLogContext) -> Self {
187        Self {
188            ctx,
189            session_path: None,
190            emitted_message_ids: HashSet::new(),
191        }
192    }
193
194    fn discover_session_path(&self) -> Option<std::path::PathBuf> {
195        let gemini_tmp = tmp_dir()?;
196        let mut best: Option<(std::time::SystemTime, std::path::PathBuf)> = None;
197        let projects = std::fs::read_dir(gemini_tmp).ok()?;
198        for project in projects.flatten() {
199            let chats = project.path().join("chats");
200            let files = std::fs::read_dir(chats).ok()?;
201            for file in files.flatten() {
202                let path = file.path();
203                let metadata = file.metadata().ok()?;
204                let modified = metadata.modified().ok()?;
205                let started_at = std::time::SystemTime::UNIX_EPOCH
206                    + std::time::Duration::from_secs(self.ctx.started_at.timestamp().max(0) as u64);
207                if modified < started_at {
208                    continue;
209                }
210                if best
211                    .as_ref()
212                    .map(|(current, _)| modified > *current)
213                    .unwrap_or(true)
214                {
215                    best = Some((modified, path));
216                }
217            }
218        }
219        best.map(|(_, path)| path)
220    }
221}
222
223#[async_trait]
224impl LiveLogAdapter for GeminiLiveLogAdapter {
225    async fn poll(&mut self, writer: &SessionLogWriter) -> Result<()> {
226        if self.session_path.is_none() {
227            self.session_path = self.discover_session_path();
228            if let Some(path) = &self.session_path {
229                writer.add_source_path(path.to_string_lossy().to_string())?;
230            }
231        }
232        let Some(path) = self.session_path.as_ref() else {
233            return Ok(());
234        };
235        let content = match std::fs::read_to_string(path) {
236            Ok(content) => content,
237            Err(_) => return Ok(()),
238        };
239        let json: serde_json::Value = match serde_json::from_str(&content) {
240            Ok(json) => json,
241            Err(_) => {
242                writer.emit(
243                    LogSourceKind::ProviderFile,
244                    LogEventKind::ParseWarning {
245                        message: "Failed to parse Gemini chat file".to_string(),
246                        raw: None,
247                    },
248                )?;
249                return Ok(());
250            }
251        };
252        if let Some(session_id) = json.get("sessionId").and_then(|value| value.as_str()) {
253            writer.set_provider_session_id(Some(session_id.to_string()))?;
254        }
255        if let Some(messages) = json.get("messages").and_then(|value| value.as_array()) {
256            for message in messages {
257                let message_id = message
258                    .get("id")
259                    .and_then(|value| value.as_str())
260                    .unwrap_or_default()
261                    .to_string();
262                if message_id.is_empty() || !self.emitted_message_ids.insert(message_id.clone()) {
263                    continue;
264                }
265                match message.get("type").and_then(|value| value.as_str()) {
266                    Some("user") => writer.emit(
267                        LogSourceKind::ProviderFile,
268                        LogEventKind::UserMessage {
269                            role: "user".to_string(),
270                            content: message
271                                .get("content")
272                                .and_then(|value| value.as_str())
273                                .unwrap_or_default()
274                                .to_string(),
275                            message_id: Some(message_id.clone()),
276                        },
277                    )?,
278                    Some("gemini") => {
279                        writer.emit(
280                            LogSourceKind::ProviderFile,
281                            LogEventKind::AssistantMessage {
282                                content: message
283                                    .get("content")
284                                    .and_then(|value| value.as_str())
285                                    .unwrap_or_default()
286                                    .to_string(),
287                                message_id: Some(message_id.clone()),
288                            },
289                        )?;
290                        if let Some(thoughts) =
291                            message.get("thoughts").and_then(|value| value.as_array())
292                        {
293                            for thought in thoughts {
294                                writer.emit(
295                                    LogSourceKind::ProviderFile,
296                                    LogEventKind::Reasoning {
297                                        content: thought
298                                            .get("description")
299                                            .and_then(|value| value.as_str())
300                                            .unwrap_or_default()
301                                            .to_string(),
302                                        message_id: Some(message_id.clone()),
303                                    },
304                                )?;
305                            }
306                        }
307                        writer.emit(
308                            LogSourceKind::ProviderFile,
309                            LogEventKind::ProviderStatus {
310                                message: "Gemini message metadata".to_string(),
311                                data: Some(serde_json::json!({
312                                    "tokens": message.get("tokens"),
313                                    "model": message.get("model"),
314                                })),
315                            },
316                        )?;
317                    }
318                    _ => {}
319                }
320            }
321        }
322
323        Ok(())
324    }
325}
326
327impl HistoricalLogAdapter for GeminiHistoricalLogAdapter {
328    fn backfill(&self, _root: Option<&str>) -> Result<Vec<BackfilledSession>> {
329        let mut sessions = Vec::new();
330        let Some(gemini_tmp) = tmp_dir() else {
331            return Ok(sessions);
332        };
333        let projects = match std::fs::read_dir(gemini_tmp) {
334            Ok(projects) => projects,
335            Err(_) => return Ok(sessions),
336        };
337        for project in projects.flatten() {
338            let chats = project.path().join("chats");
339            let files = match std::fs::read_dir(chats) {
340                Ok(files) => files,
341                Err(_) => continue,
342            };
343            for file in files.flatten() {
344                let path = file.path();
345                info!("Scanning Gemini history: {}", path.display());
346                let content = match std::fs::read_to_string(&path) {
347                    Ok(content) => content,
348                    Err(_) => continue,
349                };
350                let json: serde_json::Value = match serde_json::from_str(&content) {
351                    Ok(json) => json,
352                    Err(_) => continue,
353                };
354                let Some(session_id) = json.get("sessionId").and_then(|value| value.as_str())
355                else {
356                    continue;
357                };
358                let mut events = Vec::new();
359                if let Some(messages) = json.get("messages").and_then(|value| value.as_array()) {
360                    for message in messages {
361                        let message_id = message
362                            .get("id")
363                            .and_then(|value| value.as_str())
364                            .map(str::to_string);
365                        match message.get("type").and_then(|value| value.as_str()) {
366                            Some("user") => events.push((
367                                LogSourceKind::Backfill,
368                                LogEventKind::UserMessage {
369                                    role: "user".to_string(),
370                                    content: message
371                                        .get("content")
372                                        .and_then(|value| value.as_str())
373                                        .unwrap_or_default()
374                                        .to_string(),
375                                    message_id: message_id.clone(),
376                                },
377                            )),
378                            Some("gemini") => {
379                                events.push((
380                                    LogSourceKind::Backfill,
381                                    LogEventKind::AssistantMessage {
382                                        content: message
383                                            .get("content")
384                                            .and_then(|value| value.as_str())
385                                            .unwrap_or_default()
386                                            .to_string(),
387                                        message_id: message_id.clone(),
388                                    },
389                                ));
390                                if let Some(thoughts) =
391                                    message.get("thoughts").and_then(|value| value.as_array())
392                                {
393                                    for thought in thoughts {
394                                        events.push((
395                                            LogSourceKind::Backfill,
396                                            LogEventKind::Reasoning {
397                                                content: thought
398                                                    .get("description")
399                                                    .and_then(|value| value.as_str())
400                                                    .unwrap_or_default()
401                                                    .to_string(),
402                                                message_id: message_id.clone(),
403                                            },
404                                        ));
405                                    }
406                                }
407                            }
408                            _ => {}
409                        }
410                    }
411                }
412                sessions.push(BackfilledSession {
413                    metadata: SessionLogMetadata {
414                        provider: "gemini".to_string(),
415                        wrapper_session_id: session_id.to_string(),
416                        provider_session_id: Some(session_id.to_string()),
417                        workspace_path: None,
418                        command: "backfill".to_string(),
419                        model: None,
420                        resumed: false,
421                        backfilled: true,
422                    },
423                    completeness: LogCompleteness::Full,
424                    source_paths: vec![path.to_string_lossy().to_string()],
425                    events,
426                });
427            }
428        }
429        Ok(sessions)
430    }
431}
432
433#[async_trait]
434impl Agent for Gemini {
    /// Provider identifier for this agent: always `"gemini"`.
    fn name(&self) -> &str {
        "gemini"
    }

    /// Model used when none is selected (`DEFAULT_MODEL`).
    fn default_model() -> &'static str {
        DEFAULT_MODEL
    }

    /// Map an abstract size tier to a concrete Gemini 2.5 model name.
    fn model_for_size(size: ModelSize) -> &'static str {
        match size {
            ModelSize::Small => "gemini-2.5-flash-lite",
            ModelSize::Medium => "gemini-2.5-flash",
            ModelSize::Large => "gemini-2.5-pro",
        }
    }

    /// All model names this provider advertises (`AVAILABLE_MODELS`).
    fn available_models() -> &'static [&'static str] {
        AVAILABLE_MODELS
    }
454
    /// Current system prompt (empty when none is set).
    fn system_prompt(&self) -> &str {
        &self.system_prompt
    }

    /// Replace the system prompt.
    fn set_system_prompt(&mut self, prompt: String) {
        self.system_prompt = prompt;
    }

    /// Currently selected model name.
    fn get_model(&self) -> &str {
        &self.model
    }

    /// Select the model; no validation against `AVAILABLE_MODELS` is done here.
    fn set_model(&mut self, model: String) {
        self.model = model;
    }

    /// Set the root directory used as working directory and `.gemini` base.
    fn set_root(&mut self, root: String) {
        self.root = Some(root);
    }

    /// Enable or disable passing `--approval-mode yolo` to the CLI.
    fn set_skip_permissions(&mut self, skip: bool) {
        self.skip_permissions = skip;
    }

    /// Set (or clear) the `--output-format` value for non-interactive runs.
    fn set_output_format(&mut self, format: Option<String>) {
        self.output_format = format;
    }

    /// Replace the list of directories passed as `--include-directories`.
    fn set_add_dirs(&mut self, dirs: Vec<String>) {
        self.add_dirs = dirs;
    }

    /// Choose whether non-interactive runs capture and return the CLI output.
    fn set_capture_output(&mut self, capture: bool) {
        self.capture_output = capture;
    }

    /// Run subsequent commands through the given sandbox configuration.
    fn set_sandbox(&mut self, config: SandboxConfig) {
        self.sandbox = Some(config);
    }

    /// Set the value passed as `--max-turns`.
    fn set_max_turns(&mut self, turns: u32) {
        self.max_turns = Some(turns);
    }

    /// Expose the concrete type for downcasting (shared reference).
    fn as_any_ref(&self) -> &dyn std::any::Any {
        self
    }

    /// Expose the concrete type for downcasting (mutable reference).
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Run the CLI non-interactively; see `execute` for what is returned.
    async fn run(&self, prompt: Option<&str>) -> Result<Option<AgentOutput>> {
        self.execute(false, prompt).await
    }
510
511    async fn run_interactive(&self, prompt: Option<&str>) -> Result<()> {
512        self.execute(true, prompt).await?;
513        Ok(())
514    }
515
516    async fn run_resume(&self, session_id: Option<&str>, _last: bool) -> Result<()> {
517        let mut args = Vec::new();
518
519        if let Some(id) = session_id {
520            args.extend(["--resume".to_string(), id.to_string()]);
521        } else {
522            args.extend(["--resume".to_string(), "latest".to_string()]);
523        }
524
525        if self.skip_permissions {
526            args.extend(["--approval-mode", "yolo"].map(String::from));
527        }
528
529        if !self.model.is_empty() && self.model != "auto" {
530            args.extend(["--model".to_string(), self.model.clone()]);
531        }
532
533        for dir in &self.add_dirs {
534            args.extend(["--include-directories".to_string(), dir.clone()]);
535        }
536
537        let mut cmd = self.make_command(args);
538
539        cmd.stdin(Stdio::inherit())
540            .stdout(Stdio::inherit())
541            .stderr(Stdio::inherit());
542
543        let status = cmd
544            .status()
545            .await
546            .context("Failed to execute 'gemini' CLI. Is it installed and in PATH?")?;
547        if !status.success() {
548            anyhow::bail!("Gemini resume failed with status: {}", status);
549        }
550        Ok(())
551    }
552
553    async fn cleanup(&self) -> Result<()> {
554        log::debug!("Cleaning up Gemini agent resources");
555        let base = self.get_base_path();
556        let gemini_dir = base.join(".gemini");
557        let system_file = gemini_dir.join("system.md");
558
559        if system_file.exists() {
560            fs::remove_file(&system_file).await?;
561        }
562
563        if gemini_dir.exists()
564            && fs::read_dir(&gemini_dir)
565                .await?
566                .next_entry()
567                .await?
568                .is_none()
569        {
570            fs::remove_dir(&gemini_dir).await?;
571        }
572
573        Ok(())
574    }
575}