1use std::path::PathBuf;
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use serde::{Deserialize, Serialize};
15
16use crate::relay::RelayClient;
17use crate::Result;
18
/// How long a cached billing status stays valid, in seconds (two hours).
const BILLING_CACHE_TTL_SECS: u64 = 2 * 60 * 60;

/// Usage fraction of `facts_limit` that must be exceeded before a quota warning is shown.
const QUOTA_WARNING_THRESHOLD: f64 = 0.80;

/// Extraction interval used when neither the server flag nor the
/// `TOTALRECLAW_EXTRACT_INTERVAL` environment variable supplies one.
const DEFAULT_EXTRACTION_INTERVAL: u32 = 3;

/// Cap on facts per extraction used when the server sends no override.
const DEFAULT_MAX_FACTS_PER_EXTRACTION: u32 = 15;

/// Candidate-pool size for non-Pro tiers (overridable via `CANDIDATE_POOL_MAX_FREE`).
const DEFAULT_CANDIDATE_POOL_FREE: usize = 100;

/// Candidate-pool size for the Pro tier (overridable via `CANDIDATE_POOL_MAX_PRO`).
const DEFAULT_CANDIDATE_POOL_PRO: usize = 250;
36
37#[derive(Debug, Clone, Serialize, Deserialize, Default)]
51pub struct FeatureFlags {
52 pub llm_dedup: Option<bool>,
53 pub extraction_interval: Option<u32>,
54 pub max_facts_per_extraction: Option<u32>,
55 pub max_candidate_pool: Option<usize>,
56 pub custom_extract_interval: Option<bool>,
57 pub min_extract_interval: Option<u32>,
58
59 pub cosine_threshold: Option<f64>,
63 pub relevance_threshold: Option<f64>,
64 pub semantic_skip_threshold: Option<f64>,
65 pub min_importance: Option<u32>,
66 pub cache_ttl_ms: Option<u64>,
67 pub trapdoor_batch_size: Option<usize>,
68 pub subgraph_page_size: Option<usize>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct BillingCache {
78 pub tier: String,
79 pub facts_used: u64,
80 pub facts_limit: u64,
81 #[serde(default)]
82 pub features: FeatureFlags,
83 pub checked_at: u64,
85}
86
87impl BillingCache {
88 pub fn is_fresh(&self) -> bool {
90 let now_ms = SystemTime::now()
91 .duration_since(UNIX_EPOCH)
92 .unwrap_or_default()
93 .as_millis() as u64;
94 now_ms.saturating_sub(self.checked_at) < BILLING_CACHE_TTL_SECS * 1000
95 }
96
97 pub fn quota_fraction(&self) -> f64 {
99 if self.facts_limit == 0 {
100 return 0.0;
101 }
102 self.facts_used as f64 / self.facts_limit as f64
103 }
104
105 pub fn is_quota_warning(&self) -> bool {
107 self.quota_fraction() > QUOTA_WARNING_THRESHOLD
108 }
109
110 pub fn quota_warning_message(&self) -> Option<String> {
112 if !self.is_quota_warning() {
113 return None;
114 }
115 let pct = (self.quota_fraction() * 100.0).round() as u32;
116 Some(format!(
117 "Memory usage at {}% ({}/{} memories). Upgrade to Pro for unlimited storage.",
118 pct, self.facts_used, self.facts_limit
119 ))
120 }
121
122 pub fn is_pro(&self) -> bool {
124 self.tier == "pro"
125 }
126}
127
128fn cache_path() -> PathBuf {
134 crate::setup::config_dir().join("billing-cache.json")
135}
136
137pub fn read_cache() -> Option<BillingCache> {
139 let path = cache_path();
140 let data = std::fs::read_to_string(&path).ok()?;
141 let cache: BillingCache = serde_json::from_str(&data).ok()?;
142 if cache.is_fresh() {
143 Some(cache)
144 } else {
145 None
146 }
147}
148
149pub fn write_cache(cache: &BillingCache) {
151 let dir = crate::setup::config_dir();
152 let _ = std::fs::create_dir_all(&dir);
153 let path = cache_path();
154 if let Ok(data) = serde_json::to_string(cache) {
155 let _ = std::fs::write(&path, data);
156 }
157}
158
159pub fn invalidate_cache() {
161 let _ = std::fs::remove_file(cache_path());
162}
163
164pub async fn fetch_billing_status(relay: &RelayClient) -> Result<BillingCache> {
172 if let Some(cached) = read_cache() {
174 return Ok(cached);
175 }
176
177 let status = relay.billing_status().await?;
179
180 let features: FeatureFlags = status
182 .features
183 .and_then(|v| serde_json::from_value(v).ok())
184 .unwrap_or_default();
185
186 let now_ms = SystemTime::now()
187 .duration_since(UNIX_EPOCH)
188 .unwrap_or_default()
189 .as_millis() as u64;
190
191 let cache = BillingCache {
192 tier: status.tier.unwrap_or_else(|| "free".into()),
193 facts_used: status.facts_used.unwrap_or(0),
194 facts_limit: status.facts_limit.unwrap_or(500),
195 features,
196 checked_at: now_ms,
197 };
198
199 write_cache(&cache);
200 Ok(cache)
201}
202
203pub fn get_extraction_interval(cache: Option<&BillingCache>) -> u32 {
211 if let Some(c) = cache {
212 if let Some(interval) = c.features.extraction_interval {
213 return interval;
214 }
215 }
216 std::env::var("TOTALRECLAW_EXTRACT_INTERVAL")
217 .ok()
218 .and_then(|s| s.parse().ok())
219 .unwrap_or(DEFAULT_EXTRACTION_INTERVAL)
220}
221
222pub fn get_max_facts_per_extraction(cache: Option<&BillingCache>) -> u32 {
226 if let Some(c) = cache {
227 if let Some(max) = c.features.max_facts_per_extraction {
228 return max;
229 }
230 }
231 DEFAULT_MAX_FACTS_PER_EXTRACTION
232}
233
234pub fn get_max_candidate_pool(cache: Option<&BillingCache>) -> usize {
238 if let Some(c) = cache {
240 if let Some(pool) = c.features.max_candidate_pool {
241 return pool;
242 }
243 }
244
245 let is_pro = cache.map_or(false, |c| c.is_pro());
247 if is_pro {
248 std::env::var("CANDIDATE_POOL_MAX_PRO")
249 .ok()
250 .and_then(|s| s.parse().ok())
251 .unwrap_or(DEFAULT_CANDIDATE_POOL_PRO)
252 } else {
253 std::env::var("CANDIDATE_POOL_MAX_FREE")
254 .ok()
255 .and_then(|s| s.parse().ok())
256 .unwrap_or(DEFAULT_CANDIDATE_POOL_FREE)
257 }
258}
259
260pub fn is_llm_dedup_enabled(cache: Option<&BillingCache>) -> bool {
264 if let Some(c) = cache {
265 if c.features.llm_dedup == Some(false) {
266 return false;
267 }
268 }
269 true
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Current wall-clock time in milliseconds since the Unix epoch.
    fn now_ms() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64
    }

    /// `true` when the environment variable `key` is not set.
    ///
    /// The getters consult env overrides before their compiled-in defaults, so
    /// assertions about those defaults are only made when the override is
    /// absent; otherwise the suite would fail in a shell that sets it.
    fn env_unset(key: &str) -> bool {
        std::env::var_os(key).is_none()
    }

    /// Builds a cache snapshot stamped with the current time.
    fn cache_with(
        tier: &str,
        facts_used: u64,
        facts_limit: u64,
        features: FeatureFlags,
    ) -> BillingCache {
        BillingCache {
            tier: tier.into(),
            facts_used,
            facts_limit,
            features,
            checked_at: now_ms(),
        }
    }

    #[test]
    fn test_billing_cache_fresh() {
        let cache = cache_with("free", 100, 500, FeatureFlags::default());
        assert!(cache.is_fresh());

        // Three hours old: past the two-hour TTL.
        let expired = BillingCache {
            checked_at: cache.checked_at - 3 * 60 * 60 * 1000,
            ..cache.clone()
        };
        assert!(!expired.is_fresh());
    }

    #[test]
    fn test_quota_fraction() {
        let cache = cache_with("free", 420, 500, FeatureFlags::default());
        assert!((cache.quota_fraction() - 0.84).abs() < 0.01);
        assert!(cache.is_quota_warning());

        let low_usage = BillingCache {
            facts_used: 100,
            ..cache
        };
        assert!(!low_usage.is_quota_warning());
    }

    #[test]
    fn test_quota_fraction_zero_limit() {
        let cache = cache_with("free", 42, 0, FeatureFlags::default());
        assert_eq!(cache.quota_fraction(), 0.0);
        assert!(!cache.is_quota_warning());
        assert!(cache.quota_warning_message().is_none());
    }

    #[test]
    fn test_quota_warning_threshold_is_exclusive() {
        // Exactly 80% does not warn; anything above does.
        assert!(!cache_with("free", 400, 500, FeatureFlags::default()).is_quota_warning());
        assert!(cache_with("free", 401, 500, FeatureFlags::default()).is_quota_warning());
    }

    #[test]
    fn test_quota_warning_message() {
        let cache = cache_with("free", 450, 500, FeatureFlags::default());
        let msg = cache
            .quota_warning_message()
            .expect("90% usage is above the warning threshold");
        assert!(msg.contains("90%"));
        assert!(msg.contains("(450/500 memories)"));

        let quiet = cache_with("free", 100, 500, FeatureFlags::default());
        assert!(quiet.quota_warning_message().is_none());
    }

    #[test]
    fn test_feature_flags_extraction_interval() {
        let cache = cache_with(
            "pro",
            0,
            0,
            FeatureFlags {
                extraction_interval: Some(5),
                ..Default::default()
            },
        );
        assert_eq!(get_extraction_interval(Some(&cache)), 5);

        if env_unset("TOTALRECLAW_EXTRACT_INTERVAL") {
            assert_eq!(get_extraction_interval(None), DEFAULT_EXTRACTION_INTERVAL);
        }
    }

    #[test]
    fn test_max_facts_per_extraction() {
        assert_eq!(
            get_max_facts_per_extraction(None),
            DEFAULT_MAX_FACTS_PER_EXTRACTION
        );

        let cache = cache_with(
            "free",
            0,
            500,
            FeatureFlags {
                max_facts_per_extraction: Some(7),
                ..Default::default()
            },
        );
        assert_eq!(get_max_facts_per_extraction(Some(&cache)), 7);
    }

    #[test]
    fn test_feature_flags_max_candidate_pool() {
        let cache = cache_with(
            "pro",
            0,
            0,
            FeatureFlags {
                max_candidate_pool: Some(300),
                ..Default::default()
            },
        );
        assert_eq!(get_max_candidate_pool(Some(&cache)), 300);

        if env_unset("CANDIDATE_POOL_MAX_FREE") {
            let free_cache = BillingCache {
                tier: "free".into(),
                features: FeatureFlags::default(),
                ..cache.clone()
            };
            assert_eq!(
                get_max_candidate_pool(Some(&free_cache)),
                DEFAULT_CANDIDATE_POOL_FREE
            );
            assert_eq!(get_max_candidate_pool(None), DEFAULT_CANDIDATE_POOL_FREE);
        }

        if env_unset("CANDIDATE_POOL_MAX_PRO") {
            let pro_no_override = BillingCache {
                features: FeatureFlags::default(),
                ..cache
            };
            assert_eq!(
                get_max_candidate_pool(Some(&pro_no_override)),
                DEFAULT_CANDIDATE_POOL_PRO
            );
        }
    }

    #[test]
    fn test_feature_flags_deserialization() {
        let json = r#"{
            "llm_dedup": true,
            "extraction_interval": 3,
            "max_facts_per_extraction": 15,
            "max_candidate_pool": 200
        }"#;
        let flags: FeatureFlags = serde_json::from_str(json).unwrap();
        assert_eq!(flags.llm_dedup, Some(true));
        assert_eq!(flags.extraction_interval, Some(3));
        assert_eq!(flags.max_facts_per_extraction, Some(15));
        assert_eq!(flags.max_candidate_pool, Some(200));

        // Missing keys deserialize as `None`.
        let empty: FeatureFlags = serde_json::from_str("{}").unwrap();
        assert!(empty.llm_dedup.is_none());
        assert!(empty.extraction_interval.is_none());
        assert!(empty.max_candidate_pool.is_none());
    }

    #[test]
    fn test_billing_cache_without_features_uses_defaults() {
        let json = r#"{"tier":"free","facts_used":1,"facts_limit":500,"checked_at":0}"#;
        let cache: BillingCache = serde_json::from_str(json).unwrap();
        assert_eq!(cache.tier, "free");
        assert_eq!(cache.facts_used, 1);
        assert!(cache.features.llm_dedup.is_none());
        assert!(cache.features.max_candidate_pool.is_none());
    }

    #[test]
    fn test_llm_dedup_kill_switch() {
        let disabled = cache_with(
            "free",
            0,
            500,
            FeatureFlags {
                llm_dedup: Some(false),
                ..Default::default()
            },
        );
        assert!(!is_llm_dedup_enabled(Some(&disabled)));

        let enabled = cache_with(
            "free",
            0,
            500,
            FeatureFlags {
                llm_dedup: Some(true),
                ..Default::default()
            },
        );
        assert!(is_llm_dedup_enabled(Some(&enabled)));

        let unset = cache_with("free", 0, 500, FeatureFlags::default());
        assert!(is_llm_dedup_enabled(Some(&unset)));

        assert!(is_llm_dedup_enabled(None));
    }
}