Skip to main content

systemprompt_cloud/cli_session/
store.rs

1use anyhow::Result;
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs;
6use std::path::{Path, PathBuf};
7use systemprompt_identifiers::TenantId;
8
9use super::{CliSession, SessionKey, LOCAL_SESSION_KEY};
10
/// Schema version stamped into the `version` field of every store created
/// by [`SessionStore::new`].
const STORE_VERSION: u32 = 1;
12
/// Serializable index of CLI sessions plus a pointer to the active one.
///
/// The store is persisted as JSON (see `SessionStore::save` /
/// `SessionStore::load`), so the field names and their order are part of the
/// on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionStore {
    /// Format version; set to `STORE_VERSION` when the store is created.
    pub version: u32,
    /// Sessions keyed by the string produced by `SessionKey::as_storage_key`.
    pub sessions: HashMap<String, CliSession>,
    /// Storage key of the currently active session, if one has been selected.
    pub active_key: Option<String>,
    /// Name of the active profile. Left out of the serialized output when
    /// `None`, and read back as `None` when absent from the input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_profile_name: Option<String>,
    /// Timestamp of the most recent mutation made through the store's methods.
    pub updated_at: DateTime<Utc>,
}
22
23impl Default for SessionStore {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl SessionStore {
30    #[must_use]
31    pub fn new() -> Self {
32        Self {
33            version: STORE_VERSION,
34            sessions: HashMap::new(),
35            active_key: None,
36            active_profile_name: None,
37            updated_at: Utc::now(),
38        }
39    }
40
41    #[must_use]
42    pub fn get_valid_session(&self, key: &SessionKey) -> Option<&CliSession> {
43        self.sessions
44            .get(&key.as_storage_key())
45            .filter(|s| !s.is_expired() && s.has_valid_credentials())
46    }
47
48    pub fn get_valid_session_mut(&mut self, key: &SessionKey) -> Option<&mut CliSession> {
49        self.sessions
50            .get_mut(&key.as_storage_key())
51            .filter(|s| !s.is_expired() && s.has_valid_credentials())
52    }
53
54    #[must_use]
55    pub fn get_session(&self, key: &SessionKey) -> Option<&CliSession> {
56        self.sessions.get(&key.as_storage_key())
57    }
58
59    pub fn upsert_session(&mut self, key: &SessionKey, session: CliSession) {
60        self.sessions.insert(key.as_storage_key(), session);
61        self.updated_at = Utc::now();
62    }
63
64    pub fn remove_session(&mut self, key: &SessionKey) -> Option<CliSession> {
65        let storage_key = key.as_storage_key();
66        let removed = self.sessions.remove(&storage_key);
67        if removed.is_some() {
68            self.updated_at = Utc::now();
69        }
70        removed
71    }
72
73    pub fn set_active(&mut self, key: &SessionKey) {
74        self.active_key = Some(key.as_storage_key());
75        self.updated_at = Utc::now();
76    }
77
78    pub fn set_active_with_profile(&mut self, key: &SessionKey, profile_name: &str) {
79        self.active_key = Some(key.as_storage_key());
80        self.active_profile_name = Some(profile_name.to_string());
81        self.updated_at = Utc::now();
82    }
83
84    pub fn set_active_with_profile_path(
85        &mut self,
86        key: &SessionKey,
87        profile_name: &str,
88        profile_path: PathBuf,
89    ) {
90        self.active_key = Some(key.as_storage_key());
91        self.active_profile_name = Some(profile_name.to_string());
92
93        if let Some(session) = self.sessions.get_mut(&key.as_storage_key()) {
94            session.update_profile_path(profile_path);
95        }
96
97        self.updated_at = Utc::now();
98    }
99
100    #[must_use]
101    pub fn active_session_key(&self) -> Option<SessionKey> {
102        self.active_key.as_ref().map(|k| {
103            if k == LOCAL_SESSION_KEY {
104                SessionKey::Local
105            } else {
106                k.strip_prefix("tenant_")
107                    .map(|id| SessionKey::Tenant(TenantId::new(id)))
108                    .unwrap_or(SessionKey::Local)
109            }
110        })
111    }
112
113    #[must_use]
114    pub fn active_session(&self) -> Option<&CliSession> {
115        self.active_session_key()
116            .and_then(|key| self.get_valid_session(&key))
117    }
118
119    pub fn prune_expired(&mut self) -> usize {
120        let expired_keys: Vec<String> = self
121            .sessions
122            .iter()
123            .filter(|(_, s)| s.is_expired())
124            .map(|(k, _)| k.clone())
125            .collect();
126
127        let count = expired_keys.len();
128        for key in &expired_keys {
129            self.sessions.remove(key);
130        }
131
132        if count > 0 {
133            self.updated_at = Utc::now();
134        }
135        count
136    }
137
138    #[must_use]
139    pub fn all_sessions(&self) -> Vec<(&String, &CliSession)> {
140        self.sessions.iter().collect()
141    }
142
143    #[must_use]
144    pub fn len(&self) -> usize {
145        self.sessions.len()
146    }
147
148    #[must_use]
149    pub fn is_empty(&self) -> bool {
150        self.sessions.is_empty()
151    }
152
153    #[must_use]
154    pub fn load(sessions_dir: &Path) -> Option<Self> {
155        let index_path = sessions_dir.join("index.json");
156        let content = fs::read_to_string(&index_path)
157            .map_err(|e| tracing::debug!(error = %e, "No session store found"))
158            .ok()?;
159        serde_json::from_str(&content)
160            .map_err(|e| tracing::warn!(error = %e, "Failed to parse session store"))
161            .ok()
162    }
163
164    pub fn load_or_create(sessions_dir: &Path) -> Result<Self> {
165        Self::load(sessions_dir).map_or_else(|| Ok(Self::new()), Ok)
166    }
167
168    pub fn save(&self, sessions_dir: &Path) -> Result<()> {
169        fs::create_dir_all(sessions_dir)?;
170
171        let gitignore_path = sessions_dir.join(".gitignore");
172        if !gitignore_path.exists() {
173            fs::write(&gitignore_path, "*\n")?;
174        }
175
176        let index_path = sessions_dir.join("index.json");
177        let content = serde_json::to_string_pretty(self)?;
178        let temp_path = index_path.with_extension("tmp");
179        fs::write(&temp_path, &content)?;
180
181        #[cfg(unix)]
182        {
183            use std::os::unix::fs::PermissionsExt;
184            let mut perms = fs::metadata(&temp_path)?.permissions();
185            perms.set_mode(0o600);
186            fs::set_permissions(&temp_path, perms)?;
187        }
188
189        fs::rename(&temp_path, &index_path)?;
190        Ok(())
191    }
192}