Skip to main content

unifly_api/session/
session_cache.rs

// Persistent session cache for session API auth.
//
// Stores session cookies and CSRF tokens on disk so subsequent CLI
// invocations can skip the login handshake (especially valuable when
// MFA/TOTP is enabled). Cache files live under `$XDG_CACHE_HOME/unifly/`
// (falling back to `~/.cache/unifly/`, or the platform cache directory on
// Windows) and are named `{profile_name}_{url_hash}.json`, where `url_hash`
// is a hash of the full controller URL.
//
// Security: cache files are restricted to owner-only (0600) permissions on Unix.
// Expiry: parsed from the JWT `exp` claim with a 60-second safety margin.
// Validation (not performed in this module): a lightweight probe to
// `/api/s/{site}/self` confirms the session is still alive before trusting
// the cache.
12
13use std::fs;
14use std::io::Write;
15use std::path::{Path, PathBuf};
16use std::time::{SystemTime, UNIX_EPOCH};
17
18use base64::Engine;
19use base64::engine::general_purpose::URL_SAFE_NO_PAD;
20use serde::{Deserialize, Serialize};
21use tracing::{debug, warn};
22
23/// Cached session data persisted to disk.
24#[derive(Debug, Serialize, Deserialize)]
25struct CachedSession {
26    /// Session cookie string (e.g. `TOKEN=abc...`).
27    cookie: String,
28    /// CSRF token for UniFi OS.
29    csrf_token: Option<String>,
30    /// Unix timestamp when the session expires.
31    expires_at: u64,
32}
33
34/// Handle for reading/writing a session cache file.
35pub struct SessionCache {
36    path: PathBuf,
37}
38
39impl SessionCache {
40    /// Create a new cache handle for the given profile and controller URL.
41    ///
42    /// Returns `None` if the cache directory can't be determined.
43    pub fn new(profile_name: &str, controller_url: &str) -> Option<Self> {
44        let cache_dir = cache_dir()?;
45        // Hash the URL to avoid path-unsafe characters
46        let url_hash = simple_hash(controller_url);
47        let filename = format!("{profile_name}_{url_hash}.json");
48        Some(Self {
49            path: cache_dir.join(filename),
50        })
51    }
52
53    /// Load a cached session if it exists and hasn't expired.
54    ///
55    /// Returns `(cookie_header, csrf_token)` on success.
56    pub fn load(&self) -> Option<(String, Option<String>)> {
57        let data = fs::read_to_string(&self.path).ok()?;
58        let session: CachedSession = serde_json::from_str(&data).ok()?;
59
60        let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs();
61
62        if now >= session.expires_at {
63            debug!("cached session expired, removing");
64            self.clear();
65            return None;
66        }
67
68        debug!(
69            expires_in_secs = session.expires_at.saturating_sub(now),
70            "loaded cached session"
71        );
72        Some((session.cookie, session.csrf_token))
73    }
74
75    /// Save a session to disk with atomic write.
76    pub fn save(&self, cookie: &str, csrf_token: Option<&str>, expires_at: u64) {
77        let session = CachedSession {
78            cookie: cookie.to_owned(),
79            csrf_token: csrf_token.map(str::to_owned),
80            expires_at,
81        };
82
83        let Ok(json) = serde_json::to_string_pretty(&session) else {
84            warn!("failed to serialize session cache");
85            return;
86        };
87
88        if let Err(e) = atomic_write(&self.path, json.as_bytes()) {
89            warn!(error = %e, "failed to write session cache");
90        } else {
91            debug!("session cached to {}", self.path.display());
92        }
93    }
94
95    /// Remove the cache file.
96    pub fn clear(&self) {
97        let _ = fs::remove_file(&self.path);
98    }
99}
100
101/// Extract the `exp` claim from a JWT token string.
102///
103/// Parses the payload (second segment) of the JWT to read the expiry.
104/// Returns `None` if the token is malformed or missing `exp`.
105pub fn jwt_expiry(token: &str) -> Option<u64> {
106    // JWT cookies are "TOKEN=eyJ...", extract just the JWT value
107    let jwt = token.split(';').next()?.split('=').nth(1)?;
108
109    let parts: Vec<&str> = jwt.split('.').collect();
110    if parts.len() != 3 {
111        return None;
112    }
113
114    let payload = URL_SAFE_NO_PAD.decode(parts[1]).ok()?;
115    let claims: serde_json::Value = serde_json::from_slice(&payload).ok()?;
116    claims["exp"].as_u64()
117}
118
/// Expiry to use when no `exp` claim can be read: two hours from now,
/// expressed as Unix seconds.
///
/// Yields `0` (which `SessionCache::load` treats as already expired) if the
/// system clock reports a time before the Unix epoch.
pub fn fallback_expiry() -> u64 {
    const FALLBACK_TTL_SECS: u64 = 2 * 60 * 60;
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs() + FALLBACK_TTL_SECS,
        Err(_) => 0,
    }
}
125
/// Seconds to shave off a JWT's `exp` claim before caching, so a cached
/// session is discarded slightly before the controller would reject it.
pub const EXPIRY_MARGIN_SECS: u64 = 60;
128
129// ── Internals ──────────────────────────────────────────────────────
130
131/// Resolve the cache directory.
132///
133/// Unix (including macOS): `$XDG_CACHE_HOME/unifly/` or `~/.cache/unifly/`
134/// Windows: platform-native via `ProjectDirs`
135///
136/// On macOS, migrates from the old `~/Library/Caches/unifly/` location automatically.
137fn cache_dir() -> Option<PathBuf> {
138    let dir = platform_cache_dir()?;
139    #[cfg(target_os = "macos")]
140    migrate_macos_cache(&dir);
141    Some(dir)
142}
143
/// Unix cache directory: `$XDG_CACHE_HOME/unifly`, else `$HOME/.cache/unifly`.
///
/// Per the XDG Base Directory spec, an unset, empty, or relative
/// `$XDG_CACHE_HOME` is ignored. A relative or empty `$HOME` is rejected as
/// well; otherwise session secrets would land relative to the current
/// working directory. Returns `None` if no absolute base can be found.
#[cfg(not(windows))]
fn platform_cache_dir() -> Option<PathBuf> {
    let absolute_var = |name: &str| {
        std::env::var_os(name)
            .map(PathBuf::from)
            .filter(|p| p.is_absolute())
    };
    let base = absolute_var("XDG_CACHE_HOME")
        .or_else(|| absolute_var("HOME").map(|home| home.join(".cache")))?;
    Some(base.join("unifly"))
}
151
/// Windows cache directory: the per-user cache dir reported by `ProjectDirs`
/// for the `unifly` application, or `None` if it cannot be determined.
#[cfg(windows)]
fn platform_cache_dir() -> Option<PathBuf> {
    let project = directories::ProjectDirs::from("", "", "unifly")?;
    Some(project.cache_dir().to_path_buf())
}
156
/// Move a cache left in the legacy `ProjectDirs` location (the old
/// `~/Library/Caches/unifly/`) to `new_dir`.
///
/// Attempted at most once per process. Skipped when both locations are the
/// same, when there is nothing to move, or when `new_dir` already exists (it
/// is never overwritten). Failures are logged at debug level and otherwise
/// ignored: losing the old cache only costs one extra login.
#[cfg(target_os = "macos")]
fn migrate_macos_cache(new_dir: &Path) {
    use std::sync::Once;

    static MIGRATION: Once = Once::new();
    MIGRATION.call_once(|| {
        let Some(legacy_dir) =
            directories::ProjectDirs::from("", "", "unifly").map(|d| d.cache_dir().to_owned())
        else {
            return;
        };
        if legacy_dir == *new_dir || !legacy_dir.exists() || new_dir.exists() {
            return;
        }
        if let Some(parent) = new_dir.parent() {
            let _ = fs::create_dir_all(parent);
        }
        match fs::rename(&legacy_dir, new_dir) {
            Ok(()) => debug!(
                "migrated session cache from {} to {}",
                legacy_dir.display(),
                new_dir.display()
            ),
            // e.g. EXDEV when the two locations sit on different filesystems.
            Err(e) => debug!(
                error = %e,
                "failed to migrate session cache from {} to {}",
                legacy_dir.display(),
                new_dir.display()
            ),
        }
    });
}
182
/// Map a string to a stable, file-name-safe 16-hex-digit fragment.
///
/// djb2 over the UTF-8 bytes (seed 5381, `h = h * 33 + byte`) with wrapping
/// 64-bit arithmetic. Not cryptographic; it only needs to be deterministic.
fn simple_hash(s: &str) -> String {
    const SEED: u64 = 5381;
    let digest = s
        .bytes()
        .fold(SEED, |acc, byte| acc.wrapping_mul(33).wrapping_add(u64::from(byte)));
    format!("{digest:016x}")
}
191
/// Atomically replace `path` with `data`.
///
/// The bytes go to a process-unique temporary file in the same directory.
/// That file is synced to disk and then renamed over `path`, so readers see
/// either the old or the complete new contents. Missing parent directories
/// are created. On Unix the temporary file is created owner-only (0600)
/// *before* any data is written, so the secret is never exposed to other users.
///
/// # Errors
///
/// Returns any I/O error from creating the directory, writing, syncing, or
/// renaming. On failure the temporary file is removed on a best-effort basis.
fn atomic_write(path: &Path, data: &[u8]) -> std::io::Result<()> {
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }

    // The PID keeps concurrent CLI invocations from writing into the same temp file.
    let tmp_path = path.with_extension(format!("{}.tmp", std::process::id()));

    let write_tmp = || -> std::io::Result<()> {
        // A leftover from a crashed process that had the same PID would make
        // `create_new` fail; it holds nothing worth keeping.
        let _ = fs::remove_file(&tmp_path);

        let mut options = fs::OpenOptions::new();
        options.write(true).create_new(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::OpenOptionsExt;
            options.mode(0o600);
        }
        let mut file = options.open(&tmp_path)?;

        // `mode` is filtered through the umask; pin the exact permissions
        // before any data is written.
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            file.set_permissions(fs::Permissions::from_mode(0o600))?;
        }

        file.write_all(data)?;
        // `File::flush` is a no-op; `sync_all` makes sure the bytes are on disk
        // before the rename publishes them.
        file.sync_all()
    };

    // `fs::rename` replaces an existing destination on every platform std
    // supports (including Windows), so no pre-delete is needed; deleting
    // first would break atomicity.
    let result = write_tmp().and_then(|()| fs::rename(&tmp_path, path));
    if result.is_err() {
        let _ = fs::remove_file(&tmp_path);
    }
    result
}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unsigned `header.payload.sig` JWT whose payload is `claims_json`.
    fn fake_jwt(claims_json: &str) -> String {
        let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"HS256"}"#);
        let payload = URL_SAFE_NO_PAD.encode(claims_json);
        format!("{header}.{payload}.sig")
    }

    /// Cache handle pointing at `name` inside a fresh temporary directory.
    fn temp_cache(dir: &tempfile::TempDir, name: &str) -> SessionCache {
        SessionCache {
            path: dir.path().join(name),
        }
    }

    #[test]
    fn jwt_expiry_parses_valid_token() {
        let token = format!("TOKEN={}", fake_jwt(r#"{"exp":1700000000}"#));
        assert_eq!(jwt_expiry(&token), Some(1_700_000_000));
    }

    #[test]
    fn jwt_expiry_ignores_trailing_cookie_attributes() {
        let token = format!("TOKEN={}; Path=/; HttpOnly", fake_jwt(r#"{"exp":1700000000}"#));
        assert_eq!(jwt_expiry(&token), Some(1_700_000_000));
    }

    #[test]
    fn jwt_expiry_returns_none_without_exp_claim() {
        let token = format!("TOKEN={}", fake_jwt(r#"{"sub":"admin"}"#));
        assert_eq!(jwt_expiry(&token), None);
    }

    #[test]
    fn jwt_expiry_returns_none_for_garbage() {
        assert_eq!(jwt_expiry("not-a-jwt"), None);
        assert_eq!(jwt_expiry("TOKEN=a.b"), None);
    }

    #[test]
    fn simple_hash_is_deterministic() {
        let a = simple_hash("https://192.168.1.1");
        let b = simple_hash("https://192.168.1.1");
        assert_eq!(a, b);
        assert_ne!(a, simple_hash("https://10.0.0.1"));
    }

    #[test]
    fn simple_hash_matches_known_value() {
        // djb2 seed 5381 rendered as 16 hex digits.
        assert_eq!(simple_hash(""), "0000000000001505");
    }

    #[test]
    fn fallback_expiry_is_about_two_hours_ahead() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock after epoch")
            .as_secs();
        let exp = fallback_expiry();
        assert!((now + 7200..=now + 7205).contains(&exp));
    }

    #[test]
    fn session_cache_round_trips() {
        let dir = tempfile::tempdir().expect("tmpdir");
        let cache = temp_cache(&dir, "test.json");

        cache.save("TOKEN=abc", Some("csrf123"), fallback_expiry());
        let loaded = cache.load().expect("cache should load");
        assert_eq!(loaded.0, "TOKEN=abc");
        assert_eq!(loaded.1.as_deref(), Some("csrf123"));
    }

    #[test]
    fn missing_cache_returns_none() {
        let dir = tempfile::tempdir().expect("tmpdir");
        let cache = temp_cache(&dir, "absent.json");
        assert!(cache.load().is_none());
    }

    #[test]
    fn expired_session_returns_none_and_is_removed() {
        let dir = tempfile::tempdir().expect("tmpdir");
        let cache = temp_cache(&dir, "expired.json");

        cache.save("TOKEN=old", None, 0); // expired at epoch
        assert!(cache.path.exists());
        assert!(cache.load().is_none());
        assert!(!cache.path.exists(), "expired cache should be deleted");
    }

    #[test]
    fn clear_removes_file_and_is_idempotent() {
        let dir = tempfile::tempdir().expect("tmpdir");
        let cache = temp_cache(&dir, "clear.json");

        cache.save("TOKEN=abc", None, fallback_expiry());
        assert!(cache.path.exists());
        cache.clear();
        assert!(!cache.path.exists());
        cache.clear(); // second call on a missing file must not panic
    }

    #[cfg(unix)]
    #[test]
    fn saved_cache_is_owner_only() {
        use std::os::unix::fs::PermissionsExt;

        let dir = tempfile::tempdir().expect("tmpdir");
        let cache = temp_cache(&dir, "perms.json");

        cache.save("TOKEN=abc", None, fallback_expiry());
        let mode = fs::metadata(&cache.path).expect("metadata").permissions().mode();
        assert_eq!(mode & 0o777, 0o600);
    }
}