Skip to main content

shodh_memory/
config.rs

1//! Configuration management for Shodh-Memory
2//!
3//! All configurable parameters in one place with environment variable overrides.
4//! Follows the principle: sensible defaults, configurable in production.
5
6use std::env;
7use std::path::PathBuf;
8use tracing::info;
9
10/// Legacy storage directory name used in versions <= 0.1.80.
11const LEGACY_STORAGE_DIR: &str = "shodh_memory_data";
12
13/// Returns the platform-appropriate default storage path.
14///
15/// Resolution order:
16/// 1. If `./shodh_memory_data` exists in the cwd (legacy location), use it and warn.
17///    This preserves data for users upgrading from <= 0.1.80.
18/// 2. Otherwise, use the platform data directory:
19///    - Linux: `~/.local/share/shodh-memory/`
20///    - macOS: `~/Library/Application Support/shodh-memory/`
21///    - Windows: `C:\Users\<user>\AppData\Roaming\shodh-memory\`
22/// 3. Falls back to `./shodh_memory_data` only if the home directory cannot be determined.
23pub fn default_storage_path() -> PathBuf {
24    let legacy_path = PathBuf::from(LEGACY_STORAGE_DIR);
25    if legacy_path.exists() && legacy_path.is_dir() {
26        eprintln!(
27            "[shodh-memory] Found legacy data at ./{LEGACY_STORAGE_DIR}/ in the current directory. \
28             Using it for backward compatibility. To migrate, move it to the platform default \
29             and unset SHODH_MEMORY_PATH. See: https://github.com/varun29ankuS/shodh-memory/issues/89"
30        );
31        return legacy_path;
32    }
33
34    dirs::data_dir()
35        .map(|p| p.join("shodh-memory"))
36        .unwrap_or_else(|| PathBuf::from(LEGACY_STORAGE_DIR))
37}
38
/// Cross-origin resource sharing (CORS) settings for the HTTP server.
///
/// Obtain one from [`CorsConfig::default`] (any origin, common REST methods and
/// headers) or [`CorsConfig::from_env`], then build the middleware with
/// [`CorsConfig::to_layer`].
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Origins permitted to make cross-origin requests; an empty list allows any origin.
    pub allowed_origins: Vec<String>,
    /// HTTP methods permitted for cross-origin requests.
    pub allowed_methods: Vec<String>,
    /// Request headers permitted for cross-origin requests.
    pub allowed_headers: Vec<String>,
    /// Whether to send `Access-Control-Allow-Credentials: true`.
    pub allow_credentials: bool,
    /// How long (seconds) browsers may cache a preflight response.
    pub max_age_seconds: u64,
}

impl Default for CorsConfig {
    /// Any origin (empty allow-list), the usual REST verbs, the usual headers,
    /// no credentials, and a one-day preflight cache.
    fn default() -> Self {
        const METHODS: [&str; 5] = ["GET", "POST", "PUT", "DELETE", "OPTIONS"];
        const HEADERS: [&str; 3] = ["Content-Type", "Authorization", "X-Request-ID"];
        const ONE_DAY_SECS: u64 = 24 * 60 * 60;

        /// Copies a list of string slices into owned strings.
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| item.to_string()).collect()
        }

        Self {
            allowed_origins: Vec::new(),
            allowed_methods: owned(&METHODS),
            allowed_headers: owned(&HEADERS),
            allow_credentials: false,
            max_age_seconds: ONE_DAY_SECS,
        }
    }
}
75
76impl CorsConfig {
77    /// Load from environment variables with production safety checks
78    ///
79    /// In production mode (SHODH_ENV=production), warns if CORS origins are not configured.
80    /// This prevents accidentally running in production with permissive CORS.
81    pub fn from_env() -> Self {
82        let mut config = Self::default();
83
84        if let Ok(origins) = env::var("SHODH_CORS_ORIGINS") {
85            config.allowed_origins = origins
86                .split(',')
87                .map(|s| s.trim().to_string())
88                .filter(|s| !s.is_empty())
89                .collect();
90        }
91
92        if let Ok(methods) = env::var("SHODH_CORS_METHODS") {
93            config.allowed_methods = methods
94                .split(',')
95                .map(|s| s.trim().to_uppercase())
96                .filter(|s| !s.is_empty())
97                .collect();
98        }
99
100        if let Ok(headers) = env::var("SHODH_CORS_HEADERS") {
101            config.allowed_headers = headers
102                .split(',')
103                .map(|s| s.trim().to_string())
104                .filter(|s| !s.is_empty())
105                .collect();
106        }
107
108        if let Ok(val) = env::var("SHODH_CORS_CREDENTIALS") {
109            config.allow_credentials = val.to_lowercase() == "true" || val == "1";
110        }
111
112        if let Ok(val) = env::var("SHODH_CORS_MAX_AGE") {
113            if let Ok(n) = val.parse() {
114                config.max_age_seconds = n;
115            }
116        }
117
118        // Safety check: warn if CORS origins are not configured
119        // Warns in ALL modes unless suppressed with SHODH_CORS_WARN=false
120        let is_production = env::var("SHODH_ENV")
121            .map(|v| {
122                let v = v.to_lowercase();
123                v == "production" || v == "prod"
124            })
125            .unwrap_or(false);
126
127        let cors_warn_suppressed = env::var("SHODH_CORS_WARN")
128            .map(|v| v.to_lowercase() == "false" || v == "0")
129            .unwrap_or(false);
130
131        if config.allowed_origins.is_empty() && !cors_warn_suppressed {
132            if is_production {
133                tracing::warn!(
134                    "⚠️  PRODUCTION WARNING: CORS allows all origins. Set SHODH_CORS_ORIGINS for security."
135                );
136            } else {
137                tracing::warn!(
138                    "CORS allows all origins (no SHODH_CORS_ORIGINS set). \
139                     Set SHODH_CORS_WARN=false to suppress this warning."
140                );
141            }
142        }
143
144        config
145    }
146
147    /// Check if any origin restrictions are configured
148    pub fn is_restricted(&self) -> bool {
149        !self.allowed_origins.is_empty()
150    }
151
152    /// Convert to tower-http CorsLayer
153    pub fn to_layer(&self) -> tower_http::cors::CorsLayer {
154        use tower_http::cors::{AllowOrigin, Any, CorsLayer};
155
156        let mut layer = CorsLayer::new();
157
158        // Configure allowed origins
159        if self.allowed_origins.is_empty() {
160            // Intentionally permissive - no origins configured
161            layer = layer.allow_origin(Any);
162        } else {
163            // Parse configured origins, tracking failures
164            let mut valid_origins = Vec::new();
165            let mut invalid_origins = Vec::new();
166
167            for origin_str in &self.allowed_origins {
168                match origin_str.parse::<axum::http::HeaderValue>() {
169                    Ok(origin) => valid_origins.push(origin),
170                    Err(_) => invalid_origins.push(origin_str.clone()),
171                }
172            }
173
174            // Log any invalid origins
175            for invalid in &invalid_origins {
176                tracing::warn!("CORS: Invalid origin '{}' - skipping", invalid);
177            }
178
179            if valid_origins.is_empty() {
180                // All configured origins failed to parse - this is a config error
181                // Do NOT fall back to permissive - that would be a security hole
182                tracing::error!(
183                    "CORS: All {} configured origin(s) failed to parse. \
184                     Rejecting all cross-origin requests. Fix SHODH_CORS_ORIGINS.",
185                    self.allowed_origins.len()
186                );
187                // Use an impossible origin to effectively deny all CORS
188                layer =
189                    layer.allow_origin(AllowOrigin::list(Vec::<axum::http::HeaderValue>::new()));
190            } else {
191                if !invalid_origins.is_empty() {
192                    tracing::info!(
193                        "CORS: Using {} valid origin(s), {} invalid skipped",
194                        valid_origins.len(),
195                        invalid_origins.len()
196                    );
197                }
198                layer = layer.allow_origin(AllowOrigin::list(valid_origins));
199            }
200        }
201
202        // Configure allowed methods
203        let methods: Vec<axum::http::Method> = self
204            .allowed_methods
205            .iter()
206            .filter_map(|m| m.parse().ok())
207            .collect();
208        if methods.is_empty() {
209            layer = layer.allow_methods(Any);
210        } else {
211            layer = layer.allow_methods(methods);
212        }
213
214        // Configure allowed headers
215        let headers: Vec<axum::http::HeaderName> = self
216            .allowed_headers
217            .iter()
218            .filter_map(|h| h.parse().ok())
219            .collect();
220        if headers.is_empty() {
221            layer = layer.allow_headers(Any);
222        } else {
223            layer = layer.allow_headers(headers);
224        }
225
226        // Configure credentials
227        if self.allow_credentials {
228            layer = layer.allow_credentials(true);
229        }
230
231        // Configure max age
232        layer = layer.max_age(std::time::Duration::from_secs(self.max_age_seconds));
233
234        layer
235    }
236}
237
238/// Server configuration loaded from environment with defaults
239#[derive(Debug, Clone)]
240pub struct ServerConfig {
241    /// Server host address (default: 127.0.0.1)
242    /// Set to 0.0.0.0 for Docker or network-accessible deployments
243    pub host: String,
244
245    /// Server port (default: 3030)
246    pub port: u16,
247
248    /// Storage path for RocksDB (default: platform data dir, e.g. ~/.local/share/shodh-memory/)
249    pub storage_path: PathBuf,
250
251    /// Maximum users to keep in memory LRU cache (default: 1000)
252    pub max_users_in_memory: usize,
253
254    /// Maximum audit log entries per user (default: 10000)
255    pub audit_max_entries_per_user: usize,
256
257    /// Audit log rotation check interval (default: 100)
258    pub audit_rotation_check_interval: usize,
259
260    /// Audit log retention days (default: 30)
261    pub audit_retention_days: u64,
262
263    /// Rate limit: requests per second (default: 4000 - LLM-friendly)
264    pub rate_limit_per_second: u64,
265
266    /// Rate limit: burst size (default: 8000 - allows rapid agent bursts)
267    pub rate_limit_burst: u32,
268
269    /// Maximum concurrent requests (default: 200)
270    pub max_concurrent_requests: usize,
271
272    /// Request timeout in seconds (default: 60)
273    /// Requests exceeding this duration are terminated with 408 status
274    pub request_timeout_secs: u64,
275
276    /// Whether running in production mode
277    pub is_production: bool,
278
279    /// CORS configuration
280    pub cors: CorsConfig,
281
282    /// Memory maintenance interval in seconds (default: 300 = 5 minutes)
283    /// Controls how often consolidation and activation decay run
284    pub maintenance_interval_secs: u64,
285
286    /// Activation decay factor per maintenance cycle (default: 0.95)
287    /// Memories lose 5% activation each cycle: A_new = A_old * 0.95
288    pub activation_decay_factor: f32,
289
290    /// Backup configuration
291    /// Automatic backup interval in seconds (default: 86400 = 24 hours)
292    /// Set to 0 to disable automatic backups
293    pub backup_interval_secs: u64,
294
295    /// Maximum backups to keep per user (default: 7)
296    /// Older backups are automatically purged
297    pub backup_max_count: usize,
298
299    /// Whether backups are enabled (default: true in production, false in dev)
300    pub backup_enabled: bool,
301
302    /// Maximum entities extracted per memory for graph insertion (default: 10)
303    /// Caps the number of NER/tag/regex entities to prevent O(n²) edge explosion
304    /// in the knowledge graph. 10 entities → max 45 co-occurrence edges.
305    pub max_entities_per_memory: usize,
306}
307
308impl Default for ServerConfig {
309    fn default() -> Self {
310        Self {
311            host: "127.0.0.1".to_string(),
312            port: 3030,
313            storage_path: default_storage_path(),
314            max_users_in_memory: 1000,
315            audit_max_entries_per_user: 10_000,
316            audit_rotation_check_interval: 100,
317            audit_retention_days: 30,
318            rate_limit_per_second: 4000,
319            rate_limit_burst: 8000,
320            max_concurrent_requests: 200,
321            request_timeout_secs: 60,
322            is_production: false,
323            cors: CorsConfig::default(),
324            maintenance_interval_secs: 3600, // 1 hour (aligns with biological consolidation timescales)
325            activation_decay_factor: 0.98, // 2% decay per cycle → 62% retained after 24hr, near-zero at 30 days
326            backup_interval_secs: 86400,   // 24 hours
327            backup_max_count: 7,           // Keep 7 backups (1 week of daily backups)
328            backup_enabled: false,         // Disabled by default, auto-enabled in production
329            max_entities_per_memory: 10,   // Cap entities per memory (10 → max 45 edges)
330        }
331    }
332}
333
334impl ServerConfig {
335    /// Load configuration from environment variables with defaults
336    #[allow(clippy::field_reassign_with_default)] // Environment overrides require mutable config
337    pub fn from_env() -> Self {
338        let mut config = Self::default();
339
340        // Check production mode first
341        config.is_production = env::var("SHODH_ENV")
342            .map(|v| {
343                let v = v.to_lowercase();
344                v == "production" || v == "prod"
345            })
346            .unwrap_or(false);
347
348        // Host (bind address)
349        if let Ok(val) = env::var("SHODH_HOST") {
350            config.host = val;
351        }
352
353        // Port
354        if let Ok(val) = env::var("SHODH_PORT") {
355            if let Ok(port) = val.parse() {
356                config.port = port;
357            }
358        }
359
360        // Storage path
361        if let Ok(val) = env::var("SHODH_MEMORY_PATH") {
362            config.storage_path = PathBuf::from(val);
363        }
364
365        // Max users in memory
366        if let Ok(val) = env::var("SHODH_MAX_USERS") {
367            if let Ok(n) = val.parse() {
368                config.max_users_in_memory = n;
369            }
370        }
371
372        // Audit settings
373        if let Ok(val) = env::var("SHODH_AUDIT_MAX_ENTRIES") {
374            if let Ok(n) = val.parse() {
375                config.audit_max_entries_per_user = n;
376            }
377        }
378
379        if let Ok(val) = env::var("SHODH_AUDIT_RETENTION_DAYS") {
380            if let Ok(n) = val.parse() {
381                config.audit_retention_days = n;
382            }
383        }
384
385        // Rate limiting
386        if let Ok(val) = env::var("SHODH_RATE_LIMIT") {
387            if let Ok(n) = val.parse() {
388                config.rate_limit_per_second = n;
389            }
390        }
391
392        if let Ok(val) = env::var("SHODH_RATE_BURST") {
393            if let Ok(n) = val.parse() {
394                config.rate_limit_burst = n;
395            }
396        }
397
398        // Concurrency
399        if let Ok(val) = env::var("SHODH_MAX_CONCURRENT") {
400            if let Ok(n) = val.parse() {
401                config.max_concurrent_requests = n;
402            }
403        }
404
405        // Request timeout
406        if let Ok(val) = env::var("SHODH_REQUEST_TIMEOUT") {
407            if let Ok(n) = val.parse() {
408                config.request_timeout_secs = n;
409            }
410        }
411
412        // CORS configuration
413        config.cors = CorsConfig::from_env();
414
415        // Memory maintenance settings
416        if let Ok(val) = env::var("SHODH_MAINTENANCE_INTERVAL") {
417            if let Ok(n) = val.parse() {
418                config.maintenance_interval_secs = n;
419            }
420        }
421
422        if let Ok(val) = env::var("SHODH_ACTIVATION_DECAY") {
423            if let Ok(n) = val.parse::<f32>() {
424                let clamped = n.clamp(0.5, 0.99);
425                if (clamped - n).abs() > f32::EPSILON {
426                    tracing::warn!(
427                        "SHODH_ACTIVATION_DECAY={} clamped to {} (valid range: 0.5–0.99)",
428                        n,
429                        clamped
430                    );
431                }
432                config.activation_decay_factor = clamped;
433            }
434        }
435
436        // Backup configuration
437        if let Ok(val) = env::var("SHODH_BACKUP_INTERVAL") {
438            if let Ok(n) = val.parse::<u64>() {
439                if n == 0 {
440                    tracing::warn!(
441                        "SHODH_BACKUP_INTERVAL=0 — backups will run every maintenance cycle"
442                    );
443                }
444                config.backup_interval_secs = n;
445            }
446        }
447
448        if let Ok(val) = env::var("SHODH_BACKUP_MAX_COUNT") {
449            if let Ok(n) = val.parse() {
450                config.backup_max_count = n;
451            }
452        }
453
454        // Auto-enable backups in production mode unless explicitly disabled
455        if let Ok(val) = env::var("SHODH_BACKUP_ENABLED") {
456            config.backup_enabled = val.to_lowercase() == "true" || val == "1";
457        } else if config.is_production {
458            // Auto-enable in production
459            config.backup_enabled = true;
460        }
461
462        // Entity extraction cap
463        if let Ok(val) = env::var("SHODH_MAX_ENTITIES") {
464            if let Ok(n) = val.parse::<usize>() {
465                let clamped = n.clamp(1, 50);
466                if clamped != n {
467                    tracing::warn!(
468                        "SHODH_MAX_ENTITIES={} clamped to {} (valid range: 1–50)",
469                        n,
470                        clamped
471                    );
472                }
473                config.max_entities_per_memory = clamped;
474            }
475        }
476
477        config
478    }
479
480    /// Log the current configuration
481    pub fn log(&self) {
482        info!("📋 Configuration:");
483        info!(
484            "   Mode: {}",
485            if self.is_production {
486                "PRODUCTION"
487            } else {
488                "Development"
489            }
490        );
491        info!("   Port: {}", self.port);
492        info!("   Storage: {:?}", self.storage_path);
493        info!("   Max users in memory: {}", self.max_users_in_memory);
494        if self.rate_limit_per_second > 0 {
495            info!(
496                "   Rate limit: {} req/sec (burst: {})",
497                self.rate_limit_per_second, self.rate_limit_burst
498            );
499        } else {
500            info!("   Rate limit: disabled");
501        }
502        info!("   Max concurrent: {}", self.max_concurrent_requests);
503        info!("   Request timeout: {}s", self.request_timeout_secs);
504        info!("   Audit retention: {} days", self.audit_retention_days);
505        if self.cors.is_restricted() {
506            info!("   CORS origins: {:?}", self.cors.allowed_origins);
507        } else {
508            info!("   CORS: Permissive (all origins allowed)");
509        }
510        info!(
511            "   Maintenance interval: {}s (decay factor: {:.2})",
512            self.maintenance_interval_secs, self.activation_decay_factor
513        );
514        if self.backup_enabled {
515            let interval_hours = self.backup_interval_secs / 3600;
516            info!(
517                "   Backup: enabled (every {}h, keep {})",
518                interval_hours, self.backup_max_count
519            );
520        } else {
521            info!("   Backup: disabled");
522        }
523    }
524}
525
/// Print every environment variable this module reads, with defaults, to stdout.
#[allow(unused)] // Public API - available for CLI help output
pub fn print_env_help() {
    println!("Shodh-Memory Configuration Environment Variables:");
    println!();
    println!("  SHODH_ENV              - Set to 'production' or 'prod' for production mode");
    println!(
        "  SHODH_HOST             - Bind address (default: 127.0.0.1, use 0.0.0.0 for Docker)"
    );
    println!("  SHODH_PORT             - Server port (default: 3030)");
    println!("  SHODH_MEMORY_PATH      - Storage directory (default: platform data dir, e.g. ~/.local/share/shodh-memory/)");
    println!("  SHODH_API_KEYS         - Comma-separated API keys (required in production)");
    println!("  SHODH_DEV_API_KEY      - Development API key (required in dev if SHODH_API_KEYS not set)");
    println!("  SHODH_MAX_USERS        - Max users in memory LRU (default: 1000)");
    println!("  SHODH_RATE_LIMIT       - Requests per second (default: 4000)");
    println!("  SHODH_RATE_BURST       - Burst size (default: 8000)");
    println!("  SHODH_MAX_CONCURRENT   - Max concurrent requests (default: 200)");
    println!("  SHODH_REQUEST_TIMEOUT  - Request timeout in seconds (default: 60)");
    println!("  SHODH_AUDIT_MAX_ENTRIES    - Max audit entries per user (default: 10000)");
    println!("  SHODH_AUDIT_RETENTION_DAYS - Audit log retention days (default: 30)");
    println!();
    println!("Memory Maintenance:");
    println!("  SHODH_MAINTENANCE_INTERVAL - Maintenance interval in seconds (default: 3600 = 1 hour)");
    println!("  SHODH_ACTIVATION_DECAY     - Activation decay factor per cycle, 0.5-0.99 (default: 0.98)");
    println!("  SHODH_MAX_ENTITIES         - Max entities extracted per memory, 1-50 (default: 10)");
    println!();
    println!("Integration APIs:");
    println!("  LINEAR_API_URL         - Linear GraphQL API URL (default: https://api.linear.app/graphql)");
    println!("  LINEAR_WEBHOOK_SECRET  - Linear webhook signing secret for HMAC verification");
    println!("  GITHUB_API_URL         - GitHub REST API URL (default: https://api.github.com)");
    println!("  GITHUB_WEBHOOK_SECRET  - GitHub webhook secret for HMAC verification");
    println!();
    println!("CORS Configuration:");
    println!("  SHODH_CORS_ORIGINS     - Comma-separated allowed origins (default: all)");
    println!("  SHODH_CORS_METHODS     - Comma-separated allowed methods (default: GET,POST,PUT,DELETE,OPTIONS)");
    println!("  SHODH_CORS_HEADERS     - Comma-separated allowed headers (default: Content-Type,Authorization,X-Request-ID)");
    println!("  SHODH_CORS_CREDENTIALS - Allow credentials true/false (default: false)");
    println!("  SHODH_CORS_MAX_AGE     - Preflight cache seconds (default: 86400)");
    println!("  SHODH_CORS_WARN        - Set to false/0 to suppress the 'all origins allowed' warning");
    println!();
    println!("Backup Configuration:");
    println!("  SHODH_BACKUP_ENABLED   - Enable automatic backups true/false (default: auto in production)");
    println!("  SHODH_BACKUP_INTERVAL  - Backup interval in seconds (default: 86400 = 24 hours)");
    println!("  SHODH_BACKUP_MAX_COUNT - Max backups to keep per user (default: 7)");
    println!();
    println!("  RUST_LOG               - Log level (e.g., info, debug, trace)");
    println!();
}
568
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;

    /// Sets environment variables for one test and restores their previous
    /// values when dropped, including when an assertion panics, so a failing
    /// test does not leak overrides into other tests.
    ///
    /// NOTE: the environment is process-global and tests run in parallel by
    /// default; tests that read the same variables can still race.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
    }

    impl EnvGuard {
        fn set(vars: &[(&'static str, &str)]) -> Self {
            let saved = vars
                .iter()
                .map(|&(key, value)| {
                    let previous = env::var_os(key);
                    env::set_var(key, value);
                    (key, previous)
                })
                .collect();
            Self { saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Restore in reverse so a key listed twice ends at its original value.
            for (key, previous) in self.saved.drain(..).rev() {
                match previous {
                    Some(value) => env::set_var(key, value),
                    None => env::remove_var(key),
                }
            }
        }
    }

    #[test]
    fn test_default_config() {
        let config = ServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3030);
        assert_eq!(config.max_users_in_memory, 1000);
        assert!(!config.is_production);
        assert_eq!(config.maintenance_interval_secs, 3600);
        assert!((config.activation_decay_factor - 0.98).abs() < f32::EPSILON);
        assert!(!config.backup_enabled);
        assert_eq!(config.max_entities_per_memory, 10);
    }

    #[test]
    fn test_env_override() {
        let _env = EnvGuard::set(&[("SHODH_PORT", "8080"), ("SHODH_MAX_USERS", "500")]);

        let config = ServerConfig::from_env();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_users_in_memory, 500);
    }

    #[test]
    fn test_cors_default_is_permissive() {
        let cors = CorsConfig::default();
        assert!(!cors.is_restricted());
        assert!(cors.allowed_origins.is_empty());
        assert!(!cors.allowed_methods.is_empty());
        assert!(!cors.allowed_headers.is_empty());
    }

    #[test]
    fn test_cors_with_origins_is_restricted() {
        let cors = CorsConfig {
            allowed_origins: vec!["https://example.com".to_string()],
            ..Default::default()
        };
        assert!(cors.is_restricted());
    }

    #[test]
    fn test_cors_to_layer_permissive() {
        let cors = CorsConfig::default();
        let _layer = cors.to_layer(); // Should not panic
    }

    #[test]
    fn test_cors_to_layer_restricted() {
        let cors = CorsConfig {
            allowed_origins: vec!["https://example.com".to_string()],
            ..Default::default()
        };
        let _layer = cors.to_layer(); // Should not panic
    }
}