Skip to main content

zag_agent/
session.rs

//! Session-to-worktree mapping store.
//!
//! Persists session-worktree mappings in `~/.zag/projects/<id>/sessions.json`
//! (with a `sessions.json` in the global base directory for non-repo usage)
//! so that `zag run --resume <id>` can resume inside the correct workspace.
5
6use crate::config::Config;
7use crate::session_log::{GlobalSessionEntry, upsert_global_entry};
8use anyhow::{Context, Result};
9use chrono::{DateTime, FixedOffset};
10use log::debug;
11use serde::{Deserialize, Serialize};
12use std::path::PathBuf;
13
/// One persisted session record, as stored in a `sessions.json` file.
///
/// Fields marked `#[serde(default)]` may be absent from older files and are
/// filled with their default value when the store is loaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionEntry {
    /// Wrapper session ID; the key used by `SessionStore::find_by_session_id`.
    pub session_id: String,
    /// Provider identifier for this session.
    pub provider: String,
    /// Model identifier; an empty string when missing from the stored file.
    #[serde(default)]
    pub model: String,
    /// Path of the workspace the session runs in. An empty string means
    /// "no worktree" and is exposed as `None` in `SessionInfo`.
    pub worktree_path: String,
    /// Name of the worktree associated with this session.
    pub worktree_name: String,
    /// Creation timestamp, expected to be RFC 3339. Values that fail to parse
    /// are ordered as older than every valid timestamp when sorting sessions.
    pub created_at: String,
    /// Provider-native session ID, once known
    /// (set via `SessionStore::set_provider_session_id`).
    #[serde(default)]
    pub provider_session_id: Option<String>,
    /// Name of the sandbox associated with the session, if any.
    #[serde(default)]
    pub sandbox_name: Option<String>,
    /// Whether the session's workspace is a worktree; defaults to `false`.
    /// NOTE(review): exact semantics are not visible in this file — confirm.
    #[serde(default)]
    pub is_worktree: bool,
    /// Whether this entry was discovered rather than created directly
    /// (presumably — confirm against the discovery code); defaults to `false`.
    #[serde(default)]
    pub discovered: bool,
    /// Origin of a discovered entry; presumably only set when `discovered` is true.
    #[serde(default)]
    pub discovery_source: Option<String>,
    /// Path to the session log. When set, `SessionStore::save` also upserts
    /// this entry into the global session index.
    #[serde(default)]
    pub log_path: Option<String>,
    /// Completeness marker for the session log; defaults to `"partial"`
    /// (via `default_log_completeness`) when absent from the file.
    #[serde(default = "default_log_completeness")]
    pub log_completeness: String,
    /// Human-readable session name for discovery.
    /// `SessionStore::find_by_name` matches it exactly.
    #[serde(default)]
    pub name: Option<String>,
    /// Short description of the session's purpose.
    #[serde(default)]
    pub description: Option<String>,
    /// Arbitrary tags for categorization and discovery.
    /// `SessionStore::find_by_tag` matches them case-insensitively.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Session IDs that this session depends on (must complete before this session starts).
    /// Omitted from the serialized JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub dependencies: Vec<String>,
    /// Session ID that this session is a retry of.
    /// Omitted from the serialized JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retried_from: Option<String>,
}
53
/// In-memory contents of a `sessions.json` file.
///
/// `Default` yields an empty store, which is what `SessionStore::load` returns
/// when the file does not exist yet.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionStore {
    /// All session entries; `SessionStore::add` appends new entries at the end.
    pub sessions: Vec<SessionEntry>,
}
58
59impl SessionStore {
60    /// Path to the sessions file.
61    fn path(root: Option<&str>) -> PathBuf {
62        Config::agent_dir(root).join("sessions.json")
63    }
64
65    /// Load session store from disk. Returns empty store if file doesn't exist.
66    pub fn load(root: Option<&str>) -> Result<Self> {
67        let path = Self::path(root);
68        debug!("Loading session store from {}", path.display());
69        if !path.exists() {
70            debug!("Session store not found, using empty store");
71            return Ok(Self::default());
72        }
73        let content = std::fs::read_to_string(&path)
74            .with_context(|| format!("Failed to read sessions file: {}", path.display()))?;
75        let store: SessionStore = serde_json::from_str(&content)
76            .with_context(|| format!("Failed to parse sessions file: {}", path.display()))?;
77        debug!(
78            "Loaded {} sessions from {}",
79            store.sessions.len(),
80            path.display()
81        );
82        Ok(store)
83    }
84
85    /// Save session store to disk.
86    pub fn save(&self, root: Option<&str>) -> Result<()> {
87        let path = Self::path(root);
88        debug!(
89            "Saving {} sessions to {}",
90            self.sessions.len(),
91            path.display()
92        );
93        if let Some(parent) = path.parent() {
94            std::fs::create_dir_all(parent)
95                .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
96        }
97        let content = serde_json::to_string_pretty(self).context("Failed to serialize sessions")?;
98        crate::file_util::atomic_write_str(&path, &content)
99            .with_context(|| format!("Failed to write sessions file: {}", path.display()))?;
100        debug!("Session store saved to {}", path.display());
101
102        // Also upsert entries with log_path into the global session index
103        let global_dir = Config::global_base_dir();
104        let project = path
105            .parent()
106            .and_then(|p| p.file_name())
107            .map(|n| n.to_string_lossy().to_string())
108            .unwrap_or_default();
109        for entry in &self.sessions {
110            if let Some(ref log_path) = entry.log_path {
111                let _ = upsert_global_entry(
112                    &global_dir,
113                    GlobalSessionEntry {
114                        session_id: entry.session_id.clone(),
115                        project: project.clone(),
116                        log_path: log_path.clone(),
117                        provider: entry.provider.clone(),
118                        started_at: entry.created_at.clone(),
119                    },
120                );
121            }
122        }
123
124        Ok(())
125    }
126
127    /// Load all session stores across all projects under `~/.zag/projects/`.
128    pub fn load_all() -> Result<Self> {
129        let projects_dir = Config::global_base_dir().join("projects");
130        debug!("Loading all session stores from {}", projects_dir.display());
131        let mut all_sessions = Vec::new();
132        if let Ok(entries) = std::fs::read_dir(&projects_dir) {
133            for entry in entries.flatten() {
134                let sessions_path = entry.path().join("sessions.json");
135                if sessions_path.exists() {
136                    if let Ok(content) = std::fs::read_to_string(&sessions_path) {
137                        if let Ok(store) = serde_json::from_str::<SessionStore>(&content) {
138                            all_sessions.extend(store.sessions);
139                        }
140                    }
141                }
142            }
143        }
144        // Also load the global base directory sessions (non-repo usage)
145        let global_sessions = Config::global_base_dir().join("sessions.json");
146        if global_sessions.exists() {
147            if let Ok(content) = std::fs::read_to_string(&global_sessions) {
148                if let Ok(store) = serde_json::from_str::<SessionStore>(&content) {
149                    all_sessions.extend(store.sessions);
150                }
151            }
152        }
153        debug!("Loaded {} sessions across all projects", all_sessions.len());
154        Ok(Self {
155            sessions: all_sessions,
156        })
157    }
158
159    /// Add a session entry.
160    pub fn add(&mut self, entry: SessionEntry) {
161        self.sessions.retain(|existing| {
162            existing.session_id != entry.session_id
163                && !(entry.provider_session_id.is_some()
164                    && existing.provider_session_id == entry.provider_session_id)
165        });
166        debug!(
167            "Adding session: id={}, provider={}, worktree={}",
168            entry.session_id, entry.provider, entry.worktree_name
169        );
170        self.sessions.push(entry);
171    }
172
173    /// Find a session by ID.
174    pub fn find_by_session_id(&self, id: &str) -> Option<&SessionEntry> {
175        let result = self.sessions.iter().find(|e| e.session_id == id);
176        if result.is_some() {
177            debug!("Found session: {}", id);
178        } else {
179            debug!("Session not found: {}", id);
180        }
181        result
182    }
183
184    /// Find a session by provider-native session ID.
185    pub fn find_by_provider_session_id(&self, id: &str) -> Option<&SessionEntry> {
186        let result = self
187            .sessions
188            .iter()
189            .find(|e| e.provider_session_id.as_deref() == Some(id));
190        if result.is_some() {
191            debug!("Found provider session: {}", id);
192        } else {
193            debug!("Provider session not found: {}", id);
194        }
195        result
196    }
197
198    /// Find a session by either wrapper or provider-native ID.
199    pub fn find_by_any_id(&self, id: &str) -> Option<&SessionEntry> {
200        self.find_by_session_id(id)
201            .or_else(|| self.find_by_provider_session_id(id))
202    }
203
204    /// Get the most recently created session.
205    pub fn latest(&self) -> Option<&SessionEntry> {
206        self.sessions.iter().max_by(|a, b| {
207            parse_created_at(&a.created_at)
208                .cmp(&parse_created_at(&b.created_at))
209                .then_with(|| a.session_id.cmp(&b.session_id))
210        })
211    }
212
213    /// Update a wrapper session with the provider-native session ID.
214    pub fn set_provider_session_id(&mut self, session_id: &str, provider_session_id: String) {
215        if let Some(entry) = self
216            .sessions
217            .iter_mut()
218            .find(|e| e.session_id == session_id)
219        {
220            entry.provider_session_id = Some(provider_session_id);
221        }
222    }
223
224    /// Remove a session by ID.
225    pub fn remove(&mut self, session_id: &str) {
226        debug!("Removing session: {}", session_id);
227        self.sessions.retain(|e| e.session_id != session_id);
228    }
229
230    /// List all sessions as `SessionInfo`, sorted by created_at descending (newest first).
231    pub fn list(&self) -> Vec<SessionInfo> {
232        let mut infos: Vec<SessionInfo> = self.sessions.iter().map(SessionInfo::from).collect();
233        infos.sort_by(|a, b| {
234            parse_created_at(&b.created_at)
235                .cmp(&parse_created_at(&a.created_at))
236                .then_with(|| b.session_id.cmp(&a.session_id))
237        });
238        infos
239    }
240
241    /// Get a session by any ID (wrapper or provider-native) as `SessionInfo`.
242    pub fn get(&self, id: &str) -> Option<SessionInfo> {
243        self.find_by_any_id(id).map(SessionInfo::from)
244    }
245
246    /// Find a session by exact name match. Returns the most recent if multiple match.
247    pub fn find_by_name(&self, name: &str) -> Option<&SessionEntry> {
248        self.sessions
249            .iter()
250            .filter(|e| e.name.as_deref() == Some(name))
251            .max_by(|a, b| {
252                parse_created_at(&a.created_at)
253                    .cmp(&parse_created_at(&b.created_at))
254                    .then_with(|| a.session_id.cmp(&b.session_id))
255            })
256    }
257
258    /// Find all sessions with a matching tag (case-insensitive).
259    pub fn find_by_tag(&self, tag: &str) -> Vec<&SessionEntry> {
260        let tag_lower = tag.to_lowercase();
261        self.sessions
262            .iter()
263            .filter(|e| e.tags.iter().any(|t| t.to_lowercase() == tag_lower))
264            .collect()
265    }
266}
267
/// Public session info struct for programmatic API consumers.
///
/// A read-only view of a stored `SessionEntry`, built via
/// `From<&SessionEntry>`; bookkeeping fields of the entry are not exposed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// Wrapper session ID.
    pub session_id: String,
    /// Provider identifier.
    pub provider: String,
    /// Model identifier (may be an empty string).
    pub model: String,
    /// Creation timestamp string, expected to be RFC 3339.
    pub created_at: String,
    /// Provider-native session ID, if known.
    pub provider_session_id: Option<String>,
    /// Workspace path; `None` when the stored entry has an empty path.
    pub worktree_path: Option<String>,
    /// Name of the associated sandbox, if any.
    pub sandbox_name: Option<String>,
    /// Log completeness marker copied from the entry (`"partial"` by default).
    pub log_completeness: String,
    /// Human-readable session name; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Short description of the session; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Tags for categorization; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
}
286
287impl From<&SessionEntry> for SessionInfo {
288    fn from(e: &SessionEntry) -> Self {
289        Self {
290            session_id: e.session_id.clone(),
291            provider: e.provider.clone(),
292            model: e.model.clone(),
293            created_at: e.created_at.clone(),
294            provider_session_id: e.provider_session_id.clone(),
295            worktree_path: if e.worktree_path.is_empty() {
296                None
297            } else {
298                Some(e.worktree_path.clone())
299            },
300            sandbox_name: e.sandbox_name.clone(),
301            log_completeness: e.log_completeness.clone(),
302            name: e.name.clone(),
303            description: e.description.clone(),
304            tags: e.tags.clone(),
305        }
306    }
307}
308
/// Serde default for `SessionEntry::log_completeness`: entries stored without
/// the field are treated as having a `"partial"` log.
fn default_log_completeness() -> String {
    String::from("partial")
}
312
313fn parse_created_at(created_at: &str) -> Option<DateTime<FixedOffset>> {
314    DateTime::parse_from_rfc3339(created_at).ok()
315}
316
317#[cfg(test)]
318#[path = "session_tests.rs"]
319mod tests;