Skip to main content

totalreclaw_memory/
billing.rs

1//! Billing cache with 2-hour TTL and feature flags parsing.
2//!
3//! Matches the TypeScript plugin's billing cache (`~/.totalreclaw/billing-cache.json`).
4//!
5//! Feature flags from the relay's `GET /v1/billing/status` response drive:
6//! - Extraction interval (`extraction_interval`)
7//! - Max facts per extraction (`max_facts_per_extraction`)
8//! - Max candidate pool size (`max_candidate_pool`)
9//! - LLM dedup kill-switch (`llm_dedup`)
10
11use std::path::PathBuf;
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use serde::{Deserialize, Serialize};
15
16use crate::relay::RelayClient;
17use crate::Result;
18
/// Lifetime of a persisted billing entry: two hours (7200 seconds), matching
/// all other clients.
const BILLING_CACHE_TTL_SECS: u64 = 2 * 60 * 60;

/// Usage fraction (80%) that must be strictly exceeded before a quota warning
/// is produced.
const QUOTA_WARNING_THRESHOLD: f64 = 0.8;

/// Built-in extraction interval, in turns, used when neither the server nor
/// `TOTALRECLAW_EXTRACT_INTERVAL` provides one.
const DEFAULT_EXTRACTION_INTERVAL: u32 = 3;

/// Built-in cap on facts per extraction when the server sends none.
const DEFAULT_MAX_FACTS_PER_EXTRACTION: u32 = 15;

/// Built-in search candidate pool size for non-Pro (free) users.
const DEFAULT_CANDIDATE_POOL_FREE: usize = 100;

/// Built-in search candidate pool size for Pro users.
const DEFAULT_CANDIDATE_POOL_PRO: usize = 250;
36
37// ---------------------------------------------------------------------------
38// Feature flags (from relay billing response)
39// ---------------------------------------------------------------------------
40
41/// Feature flags parsed from the billing status response.
42///
43/// The relay returns these in the `features` JSON blob on the billing
44/// status endpoint. Clients consult them at the call-site when resolving
45/// tuning knobs; env-var fallbacks are retained for self-hosted deployments.
46///
47/// See `docs/guides/env-vars-reference.md` — as of the v1 env var cleanup,
48/// managed-service clients read tuning knobs from this struct and never
49/// from env vars.
50#[derive(Debug, Clone, Serialize, Deserialize, Default)]
51pub struct FeatureFlags {
52    pub llm_dedup: Option<bool>,
53    pub extraction_interval: Option<u32>,
54    pub max_facts_per_extraction: Option<u32>,
55    pub max_candidate_pool: Option<usize>,
56    pub custom_extract_interval: Option<bool>,
57    pub min_extract_interval: Option<u32>,
58
59    // Tuning knobs moved to server-side delivery in the v1 env cleanup.
60    // Optional — when absent, clients fall back to their built-in defaults
61    // (or their self-hosted env-var overrides).
62    pub cosine_threshold: Option<f64>,
63    pub relevance_threshold: Option<f64>,
64    pub semantic_skip_threshold: Option<f64>,
65    pub min_importance: Option<u32>,
66    pub cache_ttl_ms: Option<u64>,
67    pub trapdoor_batch_size: Option<usize>,
68    pub subgraph_page_size: Option<usize>,
69}
70
71// ---------------------------------------------------------------------------
72// Billing cache entry (persisted to disk)
73// ---------------------------------------------------------------------------
74
75/// Cached billing status, matching the TypeScript `BillingCache` interface.
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct BillingCache {
78    pub tier: String,
79    pub facts_used: u64,
80    pub facts_limit: u64,
81    #[serde(default)]
82    pub features: FeatureFlags,
83    /// Unix epoch millis when this cache was written.
84    pub checked_at: u64,
85}
86
87impl BillingCache {
88    /// Whether this cache entry is still valid (within TTL).
89    pub fn is_fresh(&self) -> bool {
90        let now_ms = SystemTime::now()
91            .duration_since(UNIX_EPOCH)
92            .unwrap_or_default()
93            .as_millis() as u64;
94        now_ms.saturating_sub(self.checked_at) < BILLING_CACHE_TTL_SECS * 1000
95    }
96
97    /// Quota usage as a fraction (0.0 to 1.0).
98    pub fn quota_fraction(&self) -> f64 {
99        if self.facts_limit == 0 {
100            return 0.0;
101        }
102        self.facts_used as f64 / self.facts_limit as f64
103    }
104
105    /// Whether the user is above the 80% quota warning threshold.
106    pub fn is_quota_warning(&self) -> bool {
107        self.quota_fraction() > QUOTA_WARNING_THRESHOLD
108    }
109
110    /// Human-readable quota warning message (or None if under threshold).
111    pub fn quota_warning_message(&self) -> Option<String> {
112        if !self.is_quota_warning() {
113            return None;
114        }
115        let pct = (self.quota_fraction() * 100.0).round() as u32;
116        Some(format!(
117            "Memory usage at {}% ({}/{} memories). Upgrade to Pro for unlimited storage.",
118            pct, self.facts_used, self.facts_limit
119        ))
120    }
121
122    /// Is this a Pro tier user?
123    pub fn is_pro(&self) -> bool {
124        self.tier == "pro"
125    }
126}
127
128// ---------------------------------------------------------------------------
129// Disk persistence
130// ---------------------------------------------------------------------------
131
132/// Get the billing cache file path (`~/.totalreclaw/billing-cache.json`).
133fn cache_path() -> PathBuf {
134    crate::setup::config_dir().join("billing-cache.json")
135}
136
137/// Read the billing cache from disk. Returns `None` if missing, expired, or corrupt.
138pub fn read_cache() -> Option<BillingCache> {
139    let path = cache_path();
140    let data = std::fs::read_to_string(&path).ok()?;
141    let cache: BillingCache = serde_json::from_str(&data).ok()?;
142    if cache.is_fresh() {
143        Some(cache)
144    } else {
145        None
146    }
147}
148
149/// Write the billing cache to disk. Best-effort (does not error on failure).
150pub fn write_cache(cache: &BillingCache) {
151    let dir = crate::setup::config_dir();
152    let _ = std::fs::create_dir_all(&dir);
153    let path = cache_path();
154    if let Ok(data) = serde_json::to_string(cache) {
155        let _ = std::fs::write(&path, data);
156    }
157}
158
159/// Invalidate (delete) the billing cache. Used on 403 responses.
160pub fn invalidate_cache() {
161    let _ = std::fs::remove_file(cache_path());
162}
163
164// ---------------------------------------------------------------------------
165// Fetch + cache from relay
166// ---------------------------------------------------------------------------
167
168/// Fetch billing status from the relay server, update the local cache, and return it.
169///
170/// If the cache is fresh, returns the cached value without a network call.
171pub async fn fetch_billing_status(relay: &RelayClient) -> Result<BillingCache> {
172    // Return cached if fresh
173    if let Some(cached) = read_cache() {
174        return Ok(cached);
175    }
176
177    // Fetch from relay
178    let status = relay.billing_status().await?;
179
180    // Parse feature flags from the `features` JSON blob
181    let features: FeatureFlags = status
182        .features
183        .and_then(|v| serde_json::from_value(v).ok())
184        .unwrap_or_default();
185
186    let now_ms = SystemTime::now()
187        .duration_since(UNIX_EPOCH)
188        .unwrap_or_default()
189        .as_millis() as u64;
190
191    let cache = BillingCache {
192        tier: status.tier.unwrap_or_else(|| "free".into()),
193        facts_used: status.facts_used.unwrap_or(0),
194        facts_limit: status.facts_limit.unwrap_or(500),
195        features,
196        checked_at: now_ms,
197    };
198
199    write_cache(&cache);
200    Ok(cache)
201}
202
203// ---------------------------------------------------------------------------
204// Feature flag accessors (with env overrides + defaults)
205// ---------------------------------------------------------------------------
206
207/// Get the effective extraction interval.
208///
209/// Priority: server-side config (from billing cache) > env var > default (3).
210pub fn get_extraction_interval(cache: Option<&BillingCache>) -> u32 {
211    if let Some(c) = cache {
212        if let Some(interval) = c.features.extraction_interval {
213            return interval;
214        }
215    }
216    std::env::var("TOTALRECLAW_EXTRACT_INTERVAL")
217        .ok()
218        .and_then(|s| s.parse().ok())
219        .unwrap_or(DEFAULT_EXTRACTION_INTERVAL)
220}
221
222/// Get the max facts per extraction.
223///
224/// Priority: server-side config > default (15).
225pub fn get_max_facts_per_extraction(cache: Option<&BillingCache>) -> u32 {
226    if let Some(c) = cache {
227        if let Some(max) = c.features.max_facts_per_extraction {
228            return max;
229        }
230    }
231    DEFAULT_MAX_FACTS_PER_EXTRACTION
232}
233
234/// Get the max candidate pool size for search.
235///
236/// Priority: server-side config > env var > tier default.
237pub fn get_max_candidate_pool(cache: Option<&BillingCache>) -> usize {
238    // Server-side value first
239    if let Some(c) = cache {
240        if let Some(pool) = c.features.max_candidate_pool {
241            return pool;
242        }
243    }
244
245    // Env overrides
246    let is_pro = cache.map_or(false, |c| c.is_pro());
247    if is_pro {
248        std::env::var("CANDIDATE_POOL_MAX_PRO")
249            .ok()
250            .and_then(|s| s.parse().ok())
251            .unwrap_or(DEFAULT_CANDIDATE_POOL_PRO)
252    } else {
253        std::env::var("CANDIDATE_POOL_MAX_FREE")
254            .ok()
255            .and_then(|s| s.parse().ok())
256            .unwrap_or(DEFAULT_CANDIDATE_POOL_FREE)
257    }
258}
259
260/// Whether LLM-guided dedup is enabled.
261///
262/// Always true unless the server explicitly sets `llm_dedup: false` (kill-switch).
263pub fn is_llm_dedup_enabled(cache: Option<&BillingCache>) -> bool {
264    if let Some(c) = cache {
265        if c.features.llm_dedup == Some(false) {
266            return false;
267        }
268    }
269    true
270}
271
272// ---------------------------------------------------------------------------
273// Tests
274// ---------------------------------------------------------------------------
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Current wall-clock time in Unix epoch millis.
    fn now_ms() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64
    }

    /// A cache entry with default (all-`None`) feature flags.
    fn cache_entry(tier: &str, facts_used: u64, facts_limit: u64, checked_at: u64) -> BillingCache {
        BillingCache {
            tier: tier.into(),
            facts_used,
            facts_limit,
            features: FeatureFlags::default(),
            checked_at,
        }
    }

    /// The accessors fall back to env vars, so assertions about built-in
    /// defaults are only meaningful when the override is not set in the
    /// environment running the tests.
    fn env_unset(key: &str) -> bool {
        std::env::var_os(key).is_none()
    }

    #[test]
    fn test_billing_cache_fresh() {
        let now = now_ms();
        let cache = cache_entry("free", 100, 500, now);
        assert!(cache.is_fresh());

        // Just inside the TTL (one minute of slack for slow test runs).
        let nearly_expired = BillingCache {
            checked_at: now - (BILLING_CACHE_TTL_SECS * 1000 - 60_000),
            ..cache.clone()
        };
        assert!(nearly_expired.is_fresh());

        // Expired cache (3 hours ago)
        let expired = BillingCache {
            checked_at: now - 3 * 60 * 60 * 1000,
            ..cache.clone()
        };
        assert!(!expired.is_fresh());
    }

    #[test]
    fn test_quota_fraction() {
        let cache = cache_entry("free", 420, 500, 0);
        assert!((cache.quota_fraction() - 0.84).abs() < 0.01);
        assert!(cache.is_quota_warning());

        let low_usage = BillingCache {
            facts_used: 100,
            ..cache
        };
        assert!(!low_usage.is_quota_warning());
    }

    #[test]
    fn test_quota_zero_limit_never_warns() {
        let cache = cache_entry("free", 1_000, 0, 0);
        assert_eq!(cache.quota_fraction(), 0.0);
        assert!(!cache.is_quota_warning());
        assert!(cache.quota_warning_message().is_none());
    }

    #[test]
    fn test_quota_warning_is_strictly_above_threshold() {
        // Exactly 80% does not warn: the comparison is `>`.
        assert!(!cache_entry("free", 400, 500, 0).is_quota_warning());
        assert!(cache_entry("free", 401, 500, 0).is_quota_warning());
    }

    #[test]
    fn test_quota_warning_message() {
        let cache = cache_entry("free", 450, 500, now_ms());
        assert_eq!(
            cache.quota_warning_message().as_deref(),
            Some("Memory usage at 90% (450/500 memories). Upgrade to Pro for unlimited storage.")
        );

        let quiet = cache_entry("free", 100, 500, now_ms());
        assert!(quiet.quota_warning_message().is_none());
    }

    #[test]
    fn test_feature_flags_extraction_interval() {
        let cache = BillingCache {
            features: FeatureFlags {
                extraction_interval: Some(5),
                ..Default::default()
            },
            ..cache_entry("pro", 0, 0, 0)
        };
        // A server-side value wins regardless of env overrides.
        assert_eq!(get_extraction_interval(Some(&cache)), 5);

        // Without cache (and without an env override), returns default.
        if env_unset("TOTALRECLAW_EXTRACT_INTERVAL") {
            assert_eq!(get_extraction_interval(None), DEFAULT_EXTRACTION_INTERVAL);
        }
    }

    #[test]
    fn test_max_facts_per_extraction() {
        let cache = BillingCache {
            features: FeatureFlags {
                max_facts_per_extraction: Some(7),
                ..Default::default()
            },
            ..cache_entry("free", 0, 500, 0)
        };
        assert_eq!(get_max_facts_per_extraction(Some(&cache)), 7);
        assert_eq!(get_max_facts_per_extraction(None), DEFAULT_MAX_FACTS_PER_EXTRACTION);
    }

    #[test]
    fn test_feature_flags_max_candidate_pool() {
        let cache = BillingCache {
            features: FeatureFlags {
                max_candidate_pool: Some(300),
                ..Default::default()
            },
            ..cache_entry("pro", 0, 0, 0)
        };
        assert_eq!(get_max_candidate_pool(Some(&cache)), 300);

        // Free tier default (no server override, no env override).
        if env_unset("CANDIDATE_POOL_MAX_FREE") {
            let free_cache = cache_entry("free", 0, 0, 0);
            assert_eq!(get_max_candidate_pool(Some(&free_cache)), DEFAULT_CANDIDATE_POOL_FREE);
            // No cache at all is treated like the free tier.
            assert_eq!(get_max_candidate_pool(None), DEFAULT_CANDIDATE_POOL_FREE);
        }

        // Pro tier default (no server override, no env override).
        if env_unset("CANDIDATE_POOL_MAX_PRO") {
            let pro_no_override = cache_entry("pro", 0, 0, 0);
            assert_eq!(get_max_candidate_pool(Some(&pro_no_override)), DEFAULT_CANDIDATE_POOL_PRO);
        }
    }

    #[test]
    fn test_feature_flags_deserialization() {
        let json = r#"{
            "llm_dedup": true,
            "extraction_interval": 3,
            "max_facts_per_extraction": 15,
            "max_candidate_pool": 200
        }"#;
        let flags: FeatureFlags = serde_json::from_str(json).unwrap();
        assert_eq!(flags.llm_dedup, Some(true));
        assert_eq!(flags.extraction_interval, Some(3));
        assert_eq!(flags.max_facts_per_extraction, Some(15));
        assert_eq!(flags.max_candidate_pool, Some(200));
        // Fields absent from the JSON deserialize as `None`.
        assert_eq!(flags.cosine_threshold, None);
    }

    #[test]
    fn test_billing_cache_without_features_deserializes() {
        let json = r#"{"tier":"pro","facts_used":1,"facts_limit":0,"checked_at":5}"#;
        let cache: BillingCache = serde_json::from_str(json).unwrap();
        assert!(cache.is_pro());
        assert_eq!(cache.facts_used, 1);
        assert_eq!(cache.checked_at, 5);
        assert_eq!(cache.features.llm_dedup, None);
        assert_eq!(cache.features.extraction_interval, None);
    }

    #[test]
    fn test_llm_dedup_kill_switch() {
        let cache = BillingCache {
            features: FeatureFlags {
                llm_dedup: Some(false),
                ..Default::default()
            },
            ..cache_entry("free", 0, 500, 0)
        };
        assert!(!is_llm_dedup_enabled(Some(&cache)));
        assert!(is_llm_dedup_enabled(None)); // Default: enabled

        // Only an explicit `false` disables dedup.
        let explicit_on = BillingCache {
            features: FeatureFlags {
                llm_dedup: Some(true),
                ..Default::default()
            },
            ..cache.clone()
        };
        assert!(is_llm_dedup_enabled(Some(&explicit_on)));
        assert!(is_llm_dedup_enabled(Some(&cache_entry("free", 0, 500, 0))));
    }
}