Skip to main content

zag_agent/
session.rs

//! Session-to-worktree mapping store.
//!
//! Persists session-worktree mappings in `~/.zag/projects/<id>/sessions.json`
//! (or `~/.zag/sessions.json` when used outside a repository) so that
//! `zag run --resume <id>` can resume inside the correct workspace.
5
6use crate::config::Config;
7use crate::session_log::{GlobalSessionEntry, upsert_global_entry};
8use anyhow::{Context, Result};
9use chrono::{DateTime, FixedOffset};
10use log::debug;
11use serde::{Deserialize, Serialize};
12use std::path::PathBuf;
13
/// One persisted session record in `sessions.json`.
///
/// Fields marked `#[serde(default)]` may be absent from the JSON; they are then
/// filled with their default values when the store is loaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionEntry {
    /// Wrapper session ID (as opposed to [`Self::provider_session_id`]).
    pub session_id: String,
    /// Provider identifier for this session.
    pub provider: String,
    /// Model identifier; an empty string when absent from the JSON.
    #[serde(default)]
    pub model: String,
    /// Workspace path for the session. An empty string is reported as `None`
    /// in [`SessionInfo`].
    pub worktree_path: String,
    /// Worktree name (included in the debug log when the entry is added).
    pub worktree_name: String,
    /// Creation timestamp, expected to be RFC 3339. Values that fail to parse
    /// sort as the oldest in recency-ordered queries on [`SessionStore`].
    pub created_at: String,
    /// Provider-native session ID, once known.
    #[serde(default)]
    pub provider_session_id: Option<String>,
    /// Sandbox name associated with the session, if any.
    #[serde(default)]
    pub sandbox_name: Option<String>,
    /// Worktree flag; `false` when absent.
    /// NOTE(review): presumably marks a dedicated git worktree — confirm.
    #[serde(default)]
    pub is_worktree: bool,
    /// Discovery flag; `false` when absent.
    /// NOTE(review): presumably set for sessions found rather than launched
    /// by zag — confirm.
    #[serde(default)]
    pub discovered: bool,
    /// Where a discovered session came from, if recorded.
    #[serde(default)]
    pub discovery_source: Option<String>,
    /// Path to the session log. Entries with a log path are mirrored into the
    /// global session index by [`SessionStore::save`].
    #[serde(default)]
    pub log_path: Option<String>,
    /// How complete the recorded log is; `"partial"` when absent.
    #[serde(default = "default_log_completeness")]
    pub log_completeness: String,
    /// Human-readable session name for discovery.
    #[serde(default)]
    pub name: Option<String>,
    /// Short description of the session's purpose.
    #[serde(default)]
    pub description: Option<String>,
    /// Arbitrary tags for categorization and discovery.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Session IDs that this session depends on (must complete before this session starts).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub dependencies: Vec<String>,
    /// Session ID that this session is a retry of.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retried_from: Option<String>,
    /// Whether this is a long-lived interactive session (FIFO-based).
    #[serde(default, skip_serializing_if = "is_false")]
    pub interactive: bool,
    /// Exit-mode constraints captured at launch (`--exit`, `--json`,
    /// `--json-schema`). `None` means the session was launched without
    /// `--exit` and `zag ps kill` accepts any result. See
    /// [`crate::exit_mode::ExitConstraints`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub exit: Option<crate::exit_mode::ExitConstraints>,
}
62
/// Serde `skip_serializing_if` predicate: `true` when the flag is unset.
///
/// Takes `&bool` because serde passes fields by reference to this hook.
fn is_false(v: &bool) -> bool {
    let flag = *v;
    flag == false
}
66
/// In-memory collection of [`SessionEntry`] records.
///
/// Serialized as the `sessions.json` file, i.e. a JSON object with a single
/// `sessions` array.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionStore {
    /// All known sessions, in insertion order.
    pub sessions: Vec<SessionEntry>,
}
71
72impl SessionStore {
73    /// Path to the sessions file.
74    fn path(root: Option<&str>) -> PathBuf {
75        Config::agent_dir(root).join("sessions.json")
76    }
77
78    /// Load session store from disk. Returns empty store if file doesn't exist.
79    pub fn load(root: Option<&str>) -> Result<Self> {
80        let path = Self::path(root);
81        debug!("Loading session store from {}", path.display());
82        if !path.exists() {
83            debug!("Session store not found, using empty store");
84            return Ok(Self::default());
85        }
86        let content = std::fs::read_to_string(&path)
87            .with_context(|| format!("Failed to read sessions file: {}", path.display()))?;
88        let store: SessionStore = serde_json::from_str(&content)
89            .with_context(|| format!("Failed to parse sessions file: {}", path.display()))?;
90        debug!(
91            "Loaded {} sessions from {}",
92            store.sessions.len(),
93            path.display()
94        );
95        Ok(store)
96    }
97
98    /// Save session store to disk.
99    pub fn save(&self, root: Option<&str>) -> Result<()> {
100        let path = Self::path(root);
101        debug!(
102            "Saving {} sessions to {}",
103            self.sessions.len(),
104            path.display()
105        );
106        if let Some(parent) = path.parent() {
107            std::fs::create_dir_all(parent)
108                .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
109        }
110        let content = serde_json::to_string_pretty(self).context("Failed to serialize sessions")?;
111        crate::file_util::atomic_write_str(&path, &content)
112            .with_context(|| format!("Failed to write sessions file: {}", path.display()))?;
113        debug!("Session store saved to {}", path.display());
114
115        // Also upsert entries with log_path into the global session index
116        let global_dir = Config::global_base_dir();
117        let project = path
118            .parent()
119            .and_then(|p| p.file_name())
120            .map(|n| n.to_string_lossy().to_string())
121            .unwrap_or_default();
122        for entry in &self.sessions {
123            if let Some(ref log_path) = entry.log_path {
124                let _ = upsert_global_entry(
125                    &global_dir,
126                    GlobalSessionEntry {
127                        session_id: entry.session_id.clone(),
128                        project: project.clone(),
129                        log_path: log_path.clone(),
130                        provider: entry.provider.clone(),
131                        started_at: entry.created_at.clone(),
132                    },
133                );
134            }
135        }
136
137        Ok(())
138    }
139
140    /// Load all session stores across all projects under `~/.zag/projects/`.
141    pub fn load_all() -> Result<Self> {
142        let projects_dir = Config::global_base_dir().join("projects");
143        debug!("Loading all session stores from {}", projects_dir.display());
144        let mut all_sessions = Vec::new();
145        if let Ok(entries) = std::fs::read_dir(&projects_dir) {
146            for entry in entries.flatten() {
147                let sessions_path = entry.path().join("sessions.json");
148                if sessions_path.exists() {
149                    if let Ok(content) = std::fs::read_to_string(&sessions_path) {
150                        if let Ok(store) = serde_json::from_str::<SessionStore>(&content) {
151                            all_sessions.extend(store.sessions);
152                        }
153                    }
154                }
155            }
156        }
157        // Also load the global base directory sessions (non-repo usage)
158        let global_sessions = Config::global_base_dir().join("sessions.json");
159        if global_sessions.exists() {
160            if let Ok(content) = std::fs::read_to_string(&global_sessions) {
161                if let Ok(store) = serde_json::from_str::<SessionStore>(&content) {
162                    all_sessions.extend(store.sessions);
163                }
164            }
165        }
166        debug!("Loaded {} sessions across all projects", all_sessions.len());
167        Ok(Self {
168            sessions: all_sessions,
169        })
170    }
171
172    /// Add a session entry.
173    pub fn add(&mut self, entry: SessionEntry) {
174        self.sessions.retain(|existing| {
175            existing.session_id != entry.session_id
176                && !(entry.provider_session_id.is_some()
177                    && existing.provider_session_id == entry.provider_session_id)
178        });
179        debug!(
180            "Adding session: id={}, provider={}, worktree={}",
181            entry.session_id, entry.provider, entry.worktree_name
182        );
183        self.sessions.push(entry);
184    }
185
186    /// Find a session by ID.
187    pub fn find_by_session_id(&self, id: &str) -> Option<&SessionEntry> {
188        let result = self.sessions.iter().find(|e| e.session_id == id);
189        if result.is_some() {
190            debug!("Found session: {id}");
191        } else {
192            debug!("Session not found: {id}");
193        }
194        result
195    }
196
197    /// Find a session by provider-native session ID.
198    pub fn find_by_provider_session_id(&self, id: &str) -> Option<&SessionEntry> {
199        let result = self
200            .sessions
201            .iter()
202            .find(|e| e.provider_session_id.as_deref() == Some(id));
203        if result.is_some() {
204            debug!("Found provider session: {id}");
205        } else {
206            debug!("Provider session not found: {id}");
207        }
208        result
209    }
210
211    /// Find a session by either wrapper or provider-native ID.
212    pub fn find_by_any_id(&self, id: &str) -> Option<&SessionEntry> {
213        self.find_by_session_id(id)
214            .or_else(|| self.find_by_provider_session_id(id))
215    }
216
217    /// Get the most recently created session.
218    pub fn latest(&self) -> Option<&SessionEntry> {
219        self.sessions.iter().max_by(|a, b| {
220            parse_created_at(&a.created_at)
221                .cmp(&parse_created_at(&b.created_at))
222                .then_with(|| a.session_id.cmp(&b.session_id))
223        })
224    }
225
226    /// Update a wrapper session with the provider-native session ID.
227    pub fn set_provider_session_id(&mut self, session_id: &str, provider_session_id: String) {
228        if let Some(entry) = self
229            .sessions
230            .iter_mut()
231            .find(|e| e.session_id == session_id)
232        {
233            entry.provider_session_id = Some(provider_session_id);
234        }
235    }
236
237    /// Remove a session by ID.
238    pub fn remove(&mut self, session_id: &str) {
239        debug!("Removing session: {session_id}");
240        self.sessions.retain(|e| e.session_id != session_id);
241    }
242
243    /// List all sessions as `SessionInfo`, sorted by created_at descending (newest first).
244    pub fn list(&self) -> Vec<SessionInfo> {
245        let mut infos: Vec<SessionInfo> = self.sessions.iter().map(SessionInfo::from).collect();
246        infos.sort_by(|a, b| {
247            parse_created_at(&b.created_at)
248                .cmp(&parse_created_at(&a.created_at))
249                .then_with(|| b.session_id.cmp(&a.session_id))
250        });
251        infos
252    }
253
254    /// Get a session by any ID (wrapper or provider-native) as `SessionInfo`.
255    pub fn get(&self, id: &str) -> Option<SessionInfo> {
256        self.find_by_any_id(id).map(SessionInfo::from)
257    }
258
259    /// Find a session by exact name match. Returns the most recent if multiple match.
260    pub fn find_by_name(&self, name: &str) -> Option<&SessionEntry> {
261        self.sessions
262            .iter()
263            .filter(|e| e.name.as_deref() == Some(name))
264            .max_by(|a, b| {
265                parse_created_at(&a.created_at)
266                    .cmp(&parse_created_at(&b.created_at))
267                    .then_with(|| a.session_id.cmp(&b.session_id))
268            })
269    }
270
271    /// Find all sessions with a matching tag (case-insensitive).
272    pub fn find_by_tag(&self, tag: &str) -> Vec<&SessionEntry> {
273        let tag_lower = tag.to_lowercase();
274        self.sessions
275            .iter()
276            .filter(|e| e.tags.iter().any(|t| t.to_lowercase() == tag_lower))
277            .collect()
278    }
279}
280
281/// Public session info struct for programmatic API consumers.
/// Public session info struct for programmatic API consumers.
///
/// A read-only projection of [`SessionEntry`] built through its `From` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// Wrapper session ID.
    pub session_id: String,
    /// Provider identifier.
    pub provider: String,
    /// Model identifier (may be empty).
    pub model: String,
    /// Creation timestamp string, copied verbatim from the entry.
    pub created_at: String,
    /// Provider-native session ID, if known.
    pub provider_session_id: Option<String>,
    /// Workspace path; `None` when the entry's path was empty.
    pub worktree_path: Option<String>,
    /// Sandbox name, if any.
    pub sandbox_name: Option<String>,
    /// Log completeness marker copied from the entry.
    pub log_completeness: String,
    /// Human-readable session name; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Session description; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Tags; omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
}
299
300impl From<&SessionEntry> for SessionInfo {
301    fn from(e: &SessionEntry) -> Self {
302        Self {
303            session_id: e.session_id.clone(),
304            provider: e.provider.clone(),
305            model: e.model.clone(),
306            created_at: e.created_at.clone(),
307            provider_session_id: e.provider_session_id.clone(),
308            worktree_path: if e.worktree_path.is_empty() {
309                None
310            } else {
311                Some(e.worktree_path.clone())
312            },
313            sandbox_name: e.sandbox_name.clone(),
314            log_completeness: e.log_completeness.clone(),
315            name: e.name.clone(),
316            description: e.description.clone(),
317            tags: e.tags.clone(),
318        }
319    }
320}
321
/// Serde default for [`SessionEntry::log_completeness`] when the field is absent.
fn default_log_completeness() -> String {
    String::from("partial")
}
325
326fn parse_created_at(created_at: &str) -> Option<DateTime<FixedOffset>> {
327    DateTime::parse_from_rfc3339(created_at).ok()
328}
329
// Unit tests for this module live in the sibling file `session_tests.rs`.
#[cfg(test)]
#[path = "session_tests.rs"]
mod tests;