Skip to main content

zag_agent/
session.rs

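//! Session-to-worktree mapping store.
//!
//! Persists session-worktree mappings in `~/.zag/projects/<id>/sessions.json`
//! so that `zag run --resume <id>` can resume inside the correct workspace.
//! `SessionStore::load_all` additionally reads `~/.zag/sessions.json`, which holds
//! sessions from non-repo usage. On save, entries that have a log path are also
//! mirrored into the global session index.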
5
6use crate::config::Config;
7use crate::session_log::{GlobalSessionEntry, upsert_global_entry};
8use anyhow::{Context, Result};
9use chrono::{DateTime, FixedOffset};
10use log::debug;
11use serde::{Deserialize, Serialize};
12use std::path::PathBuf;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct SessionEntry {
16    pub session_id: String,
17    pub provider: String,
18    #[serde(default)]
19    pub model: String,
20    pub worktree_path: String,
21    pub worktree_name: String,
22    pub created_at: String,
23    #[serde(default)]
24    pub provider_session_id: Option<String>,
25    #[serde(default)]
26    pub sandbox_name: Option<String>,
27    #[serde(default)]
28    pub is_worktree: bool,
29    #[serde(default)]
30    pub discovered: bool,
31    #[serde(default)]
32    pub discovery_source: Option<String>,
33    #[serde(default)]
34    pub log_path: Option<String>,
35    #[serde(default = "default_log_completeness")]
36    pub log_completeness: String,
37    /// Human-readable session name for discovery.
38    #[serde(default)]
39    pub name: Option<String>,
40    /// Short description of the session's purpose.
41    #[serde(default)]
42    pub description: Option<String>,
43    /// Arbitrary tags for categorization and discovery.
44    #[serde(default)]
45    pub tags: Vec<String>,
46    /// Session IDs that this session depends on (must complete before this session starts).
47    #[serde(default, skip_serializing_if = "Vec::is_empty")]
48    pub dependencies: Vec<String>,
49    /// Session ID that this session is a retry of.
50    #[serde(default, skip_serializing_if = "Option::is_none")]
51    pub retried_from: Option<String>,
52    /// Whether this is a long-lived interactive session (FIFO-based).
53    #[serde(default, skip_serializing_if = "is_false")]
54    pub interactive: bool,
55}
56
/// Serde `skip_serializing_if` predicate: returns `true` when the flag is unset,
/// so `false` values are left out of the serialized JSON.
fn is_false(v: &bool) -> bool {
    matches!(*v, false)
}
60
61#[derive(Debug, Clone, Default, Serialize, Deserialize)]
62pub struct SessionStore {
63    pub sessions: Vec<SessionEntry>,
64}
65
66impl SessionStore {
67    /// Path to the sessions file.
68    fn path(root: Option<&str>) -> PathBuf {
69        Config::agent_dir(root).join("sessions.json")
70    }
71
72    /// Load session store from disk. Returns empty store if file doesn't exist.
73    pub fn load(root: Option<&str>) -> Result<Self> {
74        let path = Self::path(root);
75        debug!("Loading session store from {}", path.display());
76        if !path.exists() {
77            debug!("Session store not found, using empty store");
78            return Ok(Self::default());
79        }
80        let content = std::fs::read_to_string(&path)
81            .with_context(|| format!("Failed to read sessions file: {}", path.display()))?;
82        let store: SessionStore = serde_json::from_str(&content)
83            .with_context(|| format!("Failed to parse sessions file: {}", path.display()))?;
84        debug!(
85            "Loaded {} sessions from {}",
86            store.sessions.len(),
87            path.display()
88        );
89        Ok(store)
90    }
91
92    /// Save session store to disk.
93    pub fn save(&self, root: Option<&str>) -> Result<()> {
94        let path = Self::path(root);
95        debug!(
96            "Saving {} sessions to {}",
97            self.sessions.len(),
98            path.display()
99        );
100        if let Some(parent) = path.parent() {
101            std::fs::create_dir_all(parent)
102                .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
103        }
104        let content = serde_json::to_string_pretty(self).context("Failed to serialize sessions")?;
105        crate::file_util::atomic_write_str(&path, &content)
106            .with_context(|| format!("Failed to write sessions file: {}", path.display()))?;
107        debug!("Session store saved to {}", path.display());
108
109        // Also upsert entries with log_path into the global session index
110        let global_dir = Config::global_base_dir();
111        let project = path
112            .parent()
113            .and_then(|p| p.file_name())
114            .map(|n| n.to_string_lossy().to_string())
115            .unwrap_or_default();
116        for entry in &self.sessions {
117            if let Some(ref log_path) = entry.log_path {
118                let _ = upsert_global_entry(
119                    &global_dir,
120                    GlobalSessionEntry {
121                        session_id: entry.session_id.clone(),
122                        project: project.clone(),
123                        log_path: log_path.clone(),
124                        provider: entry.provider.clone(),
125                        started_at: entry.created_at.clone(),
126                    },
127                );
128            }
129        }
130
131        Ok(())
132    }
133
134    /// Load all session stores across all projects under `~/.zag/projects/`.
135    pub fn load_all() -> Result<Self> {
136        let projects_dir = Config::global_base_dir().join("projects");
137        debug!("Loading all session stores from {}", projects_dir.display());
138        let mut all_sessions = Vec::new();
139        if let Ok(entries) = std::fs::read_dir(&projects_dir) {
140            for entry in entries.flatten() {
141                let sessions_path = entry.path().join("sessions.json");
142                if sessions_path.exists() {
143                    if let Ok(content) = std::fs::read_to_string(&sessions_path) {
144                        if let Ok(store) = serde_json::from_str::<SessionStore>(&content) {
145                            all_sessions.extend(store.sessions);
146                        }
147                    }
148                }
149            }
150        }
151        // Also load the global base directory sessions (non-repo usage)
152        let global_sessions = Config::global_base_dir().join("sessions.json");
153        if global_sessions.exists() {
154            if let Ok(content) = std::fs::read_to_string(&global_sessions) {
155                if let Ok(store) = serde_json::from_str::<SessionStore>(&content) {
156                    all_sessions.extend(store.sessions);
157                }
158            }
159        }
160        debug!("Loaded {} sessions across all projects", all_sessions.len());
161        Ok(Self {
162            sessions: all_sessions,
163        })
164    }
165
166    /// Add a session entry.
167    pub fn add(&mut self, entry: SessionEntry) {
168        self.sessions.retain(|existing| {
169            existing.session_id != entry.session_id
170                && !(entry.provider_session_id.is_some()
171                    && existing.provider_session_id == entry.provider_session_id)
172        });
173        debug!(
174            "Adding session: id={}, provider={}, worktree={}",
175            entry.session_id, entry.provider, entry.worktree_name
176        );
177        self.sessions.push(entry);
178    }
179
180    /// Find a session by ID.
181    pub fn find_by_session_id(&self, id: &str) -> Option<&SessionEntry> {
182        let result = self.sessions.iter().find(|e| e.session_id == id);
183        if result.is_some() {
184            debug!("Found session: {}", id);
185        } else {
186            debug!("Session not found: {}", id);
187        }
188        result
189    }
190
191    /// Find a session by provider-native session ID.
192    pub fn find_by_provider_session_id(&self, id: &str) -> Option<&SessionEntry> {
193        let result = self
194            .sessions
195            .iter()
196            .find(|e| e.provider_session_id.as_deref() == Some(id));
197        if result.is_some() {
198            debug!("Found provider session: {}", id);
199        } else {
200            debug!("Provider session not found: {}", id);
201        }
202        result
203    }
204
205    /// Find a session by either wrapper or provider-native ID.
206    pub fn find_by_any_id(&self, id: &str) -> Option<&SessionEntry> {
207        self.find_by_session_id(id)
208            .or_else(|| self.find_by_provider_session_id(id))
209    }
210
211    /// Get the most recently created session.
212    pub fn latest(&self) -> Option<&SessionEntry> {
213        self.sessions.iter().max_by(|a, b| {
214            parse_created_at(&a.created_at)
215                .cmp(&parse_created_at(&b.created_at))
216                .then_with(|| a.session_id.cmp(&b.session_id))
217        })
218    }
219
220    /// Update a wrapper session with the provider-native session ID.
221    pub fn set_provider_session_id(&mut self, session_id: &str, provider_session_id: String) {
222        if let Some(entry) = self
223            .sessions
224            .iter_mut()
225            .find(|e| e.session_id == session_id)
226        {
227            entry.provider_session_id = Some(provider_session_id);
228        }
229    }
230
231    /// Remove a session by ID.
232    pub fn remove(&mut self, session_id: &str) {
233        debug!("Removing session: {}", session_id);
234        self.sessions.retain(|e| e.session_id != session_id);
235    }
236
237    /// List all sessions as `SessionInfo`, sorted by created_at descending (newest first).
238    pub fn list(&self) -> Vec<SessionInfo> {
239        let mut infos: Vec<SessionInfo> = self.sessions.iter().map(SessionInfo::from).collect();
240        infos.sort_by(|a, b| {
241            parse_created_at(&b.created_at)
242                .cmp(&parse_created_at(&a.created_at))
243                .then_with(|| b.session_id.cmp(&a.session_id))
244        });
245        infos
246    }
247
248    /// Get a session by any ID (wrapper or provider-native) as `SessionInfo`.
249    pub fn get(&self, id: &str) -> Option<SessionInfo> {
250        self.find_by_any_id(id).map(SessionInfo::from)
251    }
252
253    /// Find a session by exact name match. Returns the most recent if multiple match.
254    pub fn find_by_name(&self, name: &str) -> Option<&SessionEntry> {
255        self.sessions
256            .iter()
257            .filter(|e| e.name.as_deref() == Some(name))
258            .max_by(|a, b| {
259                parse_created_at(&a.created_at)
260                    .cmp(&parse_created_at(&b.created_at))
261                    .then_with(|| a.session_id.cmp(&b.session_id))
262            })
263    }
264
265    /// Find all sessions with a matching tag (case-insensitive).
266    pub fn find_by_tag(&self, tag: &str) -> Vec<&SessionEntry> {
267        let tag_lower = tag.to_lowercase();
268        self.sessions
269            .iter()
270            .filter(|e| e.tags.iter().any(|t| t.to_lowercase() == tag_lower))
271            .collect()
272    }
273}
274
275/// Public session info struct for programmatic API consumers.
276#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct SessionInfo {
278    pub session_id: String,
279    pub provider: String,
280    pub model: String,
281    pub created_at: String,
282    pub provider_session_id: Option<String>,
283    pub worktree_path: Option<String>,
284    pub sandbox_name: Option<String>,
285    pub log_completeness: String,
286    #[serde(default, skip_serializing_if = "Option::is_none")]
287    pub name: Option<String>,
288    #[serde(default, skip_serializing_if = "Option::is_none")]
289    pub description: Option<String>,
290    #[serde(default, skip_serializing_if = "Vec::is_empty")]
291    pub tags: Vec<String>,
292}
293
294impl From<&SessionEntry> for SessionInfo {
295    fn from(e: &SessionEntry) -> Self {
296        Self {
297            session_id: e.session_id.clone(),
298            provider: e.provider.clone(),
299            model: e.model.clone(),
300            created_at: e.created_at.clone(),
301            provider_session_id: e.provider_session_id.clone(),
302            worktree_path: if e.worktree_path.is_empty() {
303                None
304            } else {
305                Some(e.worktree_path.clone())
306            },
307            sandbox_name: e.sandbox_name.clone(),
308            log_completeness: e.log_completeness.clone(),
309            name: e.name.clone(),
310            description: e.description.clone(),
311            tags: e.tags.clone(),
312        }
313    }
314}
315
/// Serde default for `SessionEntry::log_completeness` when the field is missing.
fn default_log_completeness() -> String {
    String::from("partial")
}
319
320fn parse_created_at(created_at: &str) -> Option<DateTime<FixedOffset>> {
321    DateTime::parse_from_rfc3339(created_at).ok()
322}
323
// Unit tests for this module live in the sibling file `session_tests.rs`
// (pulled in via `#[path]`) and are compiled only for test builds.
#[cfg(test)]
#[path = "session_tests.rs"]
mod tests;