Skip to main content

systemprompt_cloud/cli_session/
store.rs

1use anyhow::Result;
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs;
6use std::path::Path;
7use systemprompt_identifiers::TenantId;
8
9use super::{CliSession, SessionKey, LOCAL_SESSION_KEY};
10
/// Schema version stamped into every store created by [`SessionStore::new`]
/// and therefore written to `index.json` by [`SessionStore::save`].
const STORE_VERSION: u32 = 1;
12
/// Index of CLI sessions, persisted as JSON in `<sessions_dir>/index.json`
/// by [`SessionStore::save`] and read back by [`SessionStore::load`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionStore {
    /// Schema version; stores created via `new()` use `STORE_VERSION`.
    pub version: u32,
    /// Sessions keyed by the string produced by `SessionKey::as_storage_key`.
    pub sessions: HashMap<String, CliSession>,
    /// Storage key of the currently active session, if one has been set.
    pub active_key: Option<String>,
    /// Profile name recorded by `set_active_with_profile`; omitted from the
    /// serialized JSON when `None` and defaulted to `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_profile_name: Option<String>,
    /// Set to `Utc::now()` on creation and by the store's mutating methods.
    pub updated_at: DateTime<Utc>,
}
22
23impl Default for SessionStore {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl SessionStore {
30    #[must_use]
31    pub fn new() -> Self {
32        Self {
33            version: STORE_VERSION,
34            sessions: HashMap::new(),
35            active_key: None,
36            active_profile_name: None,
37            updated_at: Utc::now(),
38        }
39    }
40
41    #[must_use]
42    pub fn get_valid_session(&self, key: &SessionKey) -> Option<&CliSession> {
43        self.sessions
44            .get(&key.as_storage_key())
45            .filter(|s| !s.is_expired() && s.has_valid_credentials())
46    }
47
48    pub fn get_valid_session_mut(&mut self, key: &SessionKey) -> Option<&mut CliSession> {
49        self.sessions
50            .get_mut(&key.as_storage_key())
51            .filter(|s| !s.is_expired() && s.has_valid_credentials())
52    }
53
54    #[must_use]
55    pub fn get_session(&self, key: &SessionKey) -> Option<&CliSession> {
56        self.sessions.get(&key.as_storage_key())
57    }
58
59    pub fn upsert_session(&mut self, key: &SessionKey, session: CliSession) {
60        self.sessions.insert(key.as_storage_key(), session);
61        self.updated_at = Utc::now();
62    }
63
64    pub fn remove_session(&mut self, key: &SessionKey) -> Option<CliSession> {
65        let storage_key = key.as_storage_key();
66        let removed = self.sessions.remove(&storage_key);
67        if removed.is_some() {
68            self.updated_at = Utc::now();
69        }
70        removed
71    }
72
73    pub fn set_active(&mut self, key: &SessionKey) {
74        self.active_key = Some(key.as_storage_key());
75        self.updated_at = Utc::now();
76    }
77
78    pub fn set_active_with_profile(&mut self, key: &SessionKey, profile_name: &str) {
79        self.active_key = Some(key.as_storage_key());
80        self.active_profile_name = Some(profile_name.to_string());
81        self.updated_at = Utc::now();
82    }
83
84    #[must_use]
85    pub fn active_session_key(&self) -> Option<SessionKey> {
86        self.active_key.as_ref().map(|k| {
87            if k == LOCAL_SESSION_KEY {
88                SessionKey::Local
89            } else {
90                k.strip_prefix("tenant_")
91                    .map(|id| SessionKey::Tenant(TenantId::new(id)))
92                    .unwrap_or(SessionKey::Local)
93            }
94        })
95    }
96
97    #[must_use]
98    pub fn active_session(&self) -> Option<&CliSession> {
99        self.active_session_key()
100            .and_then(|key| self.get_valid_session(&key))
101    }
102
103    pub fn prune_expired(&mut self) -> usize {
104        let expired_keys: Vec<String> = self
105            .sessions
106            .iter()
107            .filter(|(_, s)| s.is_expired())
108            .map(|(k, _)| k.clone())
109            .collect();
110
111        let count = expired_keys.len();
112        for key in &expired_keys {
113            self.sessions.remove(key);
114        }
115
116        if count > 0 {
117            self.updated_at = Utc::now();
118        }
119        count
120    }
121
122    #[must_use]
123    pub fn all_sessions(&self) -> Vec<(&String, &CliSession)> {
124        self.sessions.iter().collect()
125    }
126
127    #[must_use]
128    pub fn len(&self) -> usize {
129        self.sessions.len()
130    }
131
132    #[must_use]
133    pub fn is_empty(&self) -> bool {
134        self.sessions.is_empty()
135    }
136
137    #[must_use]
138    pub fn load(sessions_dir: &Path) -> Option<Self> {
139        let index_path = sessions_dir.join("index.json");
140        let content = fs::read_to_string(&index_path)
141            .map_err(|e| tracing::debug!(error = %e, "No session store found"))
142            .ok()?;
143        serde_json::from_str(&content)
144            .map_err(|e| tracing::warn!(error = %e, "Failed to parse session store"))
145            .ok()
146    }
147
148    pub fn load_or_create(sessions_dir: &Path) -> Result<Self> {
149        Self::load(sessions_dir).map_or_else(|| Ok(Self::new()), Ok)
150    }
151
152    pub fn save(&self, sessions_dir: &Path) -> Result<()> {
153        fs::create_dir_all(sessions_dir)?;
154
155        let gitignore_path = sessions_dir.join(".gitignore");
156        if !gitignore_path.exists() {
157            fs::write(&gitignore_path, "*\n")?;
158        }
159
160        let index_path = sessions_dir.join("index.json");
161        let content = serde_json::to_string_pretty(self)?;
162        let temp_path = index_path.with_extension("tmp");
163        fs::write(&temp_path, &content)?;
164
165        #[cfg(unix)]
166        {
167            use std::os::unix::fs::PermissionsExt;
168            let mut perms = fs::metadata(&temp_path)?.permissions();
169            perms.set_mode(0o600);
170            fs::set_permissions(&temp_path, perms)?;
171        }
172
173        fs::rename(&temp_path, &index_path)?;
174        Ok(())
175    }
176}