Skip to main content

sqz_engine/
engine.rs

1use std::path::Path;
2use std::sync::{Arc, Mutex};
3
4use crate::ast_parser::AstParser;
5use crate::budget_tracker::{BudgetTracker, UsageReport};
6use crate::cache_manager::CacheManager;
7use crate::confidence_router::ConfidenceRouter;
8use crate::cost_calculator::{CostCalculator, SessionCostSummary};
9use crate::ctx_format::CtxFormat;
10use crate::error::{Result, SqzError};
11use crate::model_router::ModelRouter;
12use crate::pin_manager::PinManager;
13use crate::pipeline::CompressionPipeline;
14use crate::plugin_api::PluginLoader;
15use crate::preset::{Preset, PresetParser};
16use crate::session_store::{SessionStore, SessionSummary};
17use crate::terse_mode::TerseMode;
18use crate::types::{CompressedContent, PinEntry, Provenance, SessionId};
19use crate::verifier::Verifier;
20
21/// Top-level facade that wires all sqz_engine modules together.
22///
23/// # Concurrency design
24///
25/// `SqzEngine` is designed for single-threaded use on the main thread.
26/// The only cross-thread sharing happens during preset hot-reload: the
27/// file-watcher callback runs on a background thread and needs to update
28/// the preset, pipeline, and model router. These three fields are wrapped
29/// in `Arc<Mutex<>>` specifically for that purpose. All other fields are
30/// owned directly — no unnecessary synchronization.
31pub struct SqzEngine {
32    // --- Hot-reloadable state (shared with file-watcher thread) ---
33    preset: Arc<Mutex<Preset>>,
34    pipeline: Arc<Mutex<CompressionPipeline>>,
35    model_router: Arc<Mutex<ModelRouter>>,
36
37    // --- Single-owner state (no cross-thread sharing needed) ---
38    session_store: SessionStore,
39    #[allow(dead_code)] // used internally by compress pipeline; public API pending
40    cache_manager: CacheManager,
41    budget_tracker: BudgetTracker,
42    cost_calculator: CostCalculator,
43    ast_parser: AstParser,
44    terse_mode: TerseMode,
45    pin_manager: PinManager,
46    confidence_router: ConfidenceRouter,
47    _plugin_loader: PluginLoader,
48}
49
50impl SqzEngine {
51    /// Create a new engine with the default preset and a persistent session store.
52    ///
53    /// Sessions are stored in `~/.sqz/sessions.db` for cross-session continuity.
54    /// Falls back to a temp-file store if the home directory is unavailable.
55    pub fn new() -> Result<Self> {
56        let preset = Preset::default();
57        let store_path = Self::default_store_path();
58        Self::with_preset_and_store(preset, &store_path)
59    }
60
61    /// Resolve the default session store path: `~/.sqz/sessions.db`.
62    /// Falls back to a temp-file path if home dir is unavailable.
63    fn default_store_path() -> std::path::PathBuf {
64        if let Some(home) = dirs_next::home_dir() {
65            let sqz_dir = home.join(".sqz");
66            if std::fs::create_dir_all(&sqz_dir).is_ok() {
67                return sqz_dir.join("sessions.db");
68            }
69        }
70        // Fallback: temp dir with unique name
71        let dir = std::env::temp_dir();
72        dir.join(format!(
73            "sqz_session_{}_{}.db",
74            std::process::id(),
75            std::time::SystemTime::now()
76                .duration_since(std::time::UNIX_EPOCH)
77                .map(|d| d.as_nanos())
78                .unwrap_or(0)
79        ))
80    }
81
82    /// Create with a custom preset and a file-backed session store.
83    ///
84    /// Opens a single SQLite connection for the session store. The cache
85    /// manager and pin manager share the same store via separate connections
86    /// (SQLite WAL mode supports concurrent readers).
87    pub fn with_preset_and_store(preset: Preset, store_path: &Path) -> Result<Self> {
88        let pipeline = CompressionPipeline::new(&preset);
89        let window_size = preset.budget.default_window_size;
90
91        // One connection per consumer. SQLite WAL mode handles concurrency.
92        let session_store = SessionStore::open_or_create(store_path)?;
93        let cache_store = SessionStore::open_or_create(store_path)?;
94        let pin_store = SessionStore::open_or_create(store_path)?;
95
96        Ok(SqzEngine {
97            preset: Arc::new(Mutex::new(preset.clone())),
98            pipeline: Arc::new(Mutex::new(pipeline)),
99            model_router: Arc::new(Mutex::new(ModelRouter::new(&preset))),
100            session_store,
101            cache_manager: CacheManager::new(cache_store, 512 * 1024 * 1024),
102            budget_tracker: BudgetTracker::new(window_size, &preset),
103            cost_calculator: CostCalculator::with_defaults(),
104            ast_parser: AstParser::new(),
105            terse_mode: TerseMode,
106            pin_manager: PinManager::new(pin_store),
107            confidence_router: ConfidenceRouter::new(),
108            _plugin_loader: PluginLoader::new(Path::new("plugins")),
109        })
110    }
111
112    /// Compress input text using the current preset.
113    ///
114    /// Two-pass pipeline:
115    /// 1. Route to compression mode based on content entropy and risk patterns.
116    /// 2. Compress using the pipeline (safe preset for Safe mode, default otherwise).
117    /// 3. Verify invariants (error lines, JSON keys, diff hunks, etc.).
118    /// 4. If verification confidence is low, fall back to safe mode and re-compress.
119    pub fn compress(&self, input: &str) -> Result<CompressedContent> {
120        let preset = self.preset.lock()
121            .map_err(|_| SqzError::Other("preset lock poisoned".into()))?;
122        let pipeline = self.pipeline.lock()
123            .map_err(|_| SqzError::Other("pipeline lock poisoned".into()))?;
124        let ctx = crate::pipeline::SessionContext {
125            session_id: "engine".to_string(),
126        };
127
128        // Step 1: Route — check content risk before compressing
129        let mode = self.confidence_router.route(input);
130
131        // Step 2: If Safe mode, skip aggressive pipeline and go straight to safe compress
132        if mode == crate::confidence_router::CompressionMode::Safe {
133            eprintln!("[sqz] fallback: safe mode — content classified as high-risk (stack trace / migration / secret)");
134            return self.compress_safe(input, &pipeline, &ctx);
135        }
136
137        // Step 3: Compress with the configured pipeline
138        let mut result = pipeline.compress(input, &ctx, &preset)?;
139
140        // Step 4: Verify invariants
141        let verify = Verifier::verify(input, &result.data);
142        let fallback = verify.fallback_triggered;
143        result.verify = Some(verify);
144
145        // Step 5: If verifier signals low confidence, re-compress with safe settings
146        if fallback && result.data != input {
147            eprintln!("[sqz] fallback: verifier confidence {:.2} below threshold — re-compressing in safe mode",
148                result.verify.as_ref().map(|v| v.confidence).unwrap_or(0.0));
149            let safe_result = self.compress_safe(input, &pipeline, &ctx)?;
150            return Ok(safe_result);
151        }
152
153        Ok(result)
154    }
155
156    /// Defensive compression: any input in, `CompressedContent` out, guaranteed.
157    ///
158    /// Unlike `compress()` which returns `Result`, this method never returns
159    /// an error. On any internal failure it returns the original input
160    /// unchanged with a 1.0 compression ratio. This makes it safe to call
161    /// from contexts where error handling is impractical (e.g. shell hooks,
162    /// browser extension bridges).
163    pub fn compress_or_passthrough(&self, input: &str) -> CompressedContent {
164        match self.compress(input) {
165            Ok(result) => result,
166            Err(_) => {
167                let tokens = (input.len() as u32 + 3) / 4;
168                CompressedContent {
169                    data: input.to_string(),
170                    tokens_compressed: tokens,
171                    tokens_original: tokens,
172                    stages_applied: vec![],
173                    compression_ratio: 1.0,
174                    provenance: crate::types::Provenance::default(),
175                    verify: None,
176                }
177            }
178        }
179    }
180
181    /// Compress with explicit mode override, bypassing the confidence router.
182    ///
183    /// - `CompressionMode::Safe` → safe pipeline only (ANSI strip + condense)
184    /// - `CompressionMode::Default` → standard pipeline
185    /// - `CompressionMode::Aggressive` → standard pipeline (aggressive preset TBD)
186    pub fn compress_with_mode(&self, input: &str, mode: crate::confidence_router::CompressionMode) -> Result<CompressedContent> {
187        let pipeline = self.pipeline.lock()
188            .map_err(|_| SqzError::Other("pipeline lock poisoned".into()))?;
189        let ctx = crate::pipeline::SessionContext {
190            session_id: "engine".to_string(),
191        };
192
193        match mode {
194            crate::confidence_router::CompressionMode::Safe => {
195                self.compress_safe(input, &pipeline, &ctx)
196            }
197            _ => {
198                // Default and Aggressive: run normal pipeline + verify
199                drop(pipeline); // release lock before calling compress()
200                self.compress(input)
201            }
202        }
203    }
204
205    /// Safe-mode compression: minimal transforms only (ANSI strip + condense).
206    fn compress_safe(
207        &self,
208        input: &str,
209        pipeline: &crate::pipeline::CompressionPipeline,
210        ctx: &crate::pipeline::SessionContext,
211    ) -> Result<CompressedContent> {
212        use crate::preset::{
213            CompressionConfig, CondenseConfig, CustomTransformsConfig, BudgetConfig,
214            ModelConfig, PresetMeta, TerseModeConfig, TerseLevel, ToolSelectionConfig,
215        };
216
217        let safe_preset = Preset {
218            preset: PresetMeta {
219                name: "safe".to_string(),
220                version: "1.0".to_string(),
221                description: "Safe fallback — minimal compression".to_string(),
222            },
223            compression: CompressionConfig {
224                stages: vec!["condense".to_string()],
225                keep_fields: None,
226                strip_fields: None,
227                condense: Some(CondenseConfig { enabled: true, max_repeated_lines: 3 }),
228                git_diff_fold: None,
229                strip_nulls: None,
230                flatten: None,
231                truncate_strings: None,
232                collapse_arrays: None,
233                custom_transforms: Some(CustomTransformsConfig { enabled: false }),
234            },
235            tool_selection: ToolSelectionConfig {
236                max_tools: 5,
237                similarity_threshold: 0.7,
238                default_tools: vec![],
239            },
240            budget: BudgetConfig {
241                warning_threshold: 0.70,
242                ceiling_threshold: 0.85,
243                default_window_size: 200_000,
244                agents: Default::default(),
245            },
246            terse_mode: TerseModeConfig { enabled: false, level: TerseLevel::Moderate },
247            model: ModelConfig {
248                family: "anthropic".to_string(),
249                primary: String::new(),
250                local: String::new(),
251                complexity_threshold: 0.4,
252                pricing: None,
253            },
254        };
255
256        let mut result = pipeline.compress(input, ctx, &safe_preset)?;
257        let verify = Verifier::verify(input, &result.data);
258        result.verify = Some(verify);
259        result.provenance = Provenance {
260            label: Some("safe-fallback".to_string()),
261            ..Default::default()
262        };
263        Ok(result)
264    }
265
266    /// Compress with explicit provenance metadata attached to the result.
267    pub fn compress_with_provenance(
268        &self,
269        input: &str,
270        provenance: Provenance,
271    ) -> Result<CompressedContent> {
272        let mut result = self.compress(input)?;
273        result.provenance = provenance;
274        Ok(result)
275    }
276
277    /// Export a session to CTX format.
278    pub fn export_ctx(&self, session_id: &str) -> Result<String> {
279        let session = self.session_store.load_session(session_id.to_string())?;
280        CtxFormat::serialize(&session)
281    }
282
283    /// Import a CTX string and save as a new session.
284    pub fn import_ctx(&self, ctx: &str) -> Result<SessionId> {
285        let session = CtxFormat::deserialize(ctx)?;
286        self.session_store.save_session(&session)
287    }
288
289    /// Pin a conversation turn.
290    pub fn pin(&self, session_id: &str, turn_index: usize, reason: &str, tokens: u32) -> Result<PinEntry> {
291        self.pin_manager.pin(session_id, turn_index, reason, tokens)
292    }
293
294    /// Unpin a conversation turn.
295    pub fn unpin(&self, session_id: &str, turn_index: usize) -> Result<()> {
296        self.pin_manager.unpin(session_id, turn_index)
297    }
298
299    /// Search sessions by keyword.
300    pub fn search_sessions(&self, query: &str) -> Result<Vec<SessionSummary>> {
301        self.session_store.search(query)
302    }
303
304    /// Get usage report for an agent.
305    pub fn usage_report(&self, agent_id: &str) -> UsageReport {
306        self.budget_tracker.usage_report(agent_id.to_string())
307    }
308
309    /// Get cost summary for a session.
310    pub fn cost_summary(&self, session_id: &str) -> Result<SessionCostSummary> {
311        let session = self.session_store.load_session(session_id.to_string())?;
312        Ok(self.cost_calculator.session_summary(&session))
313    }
314
315    /// Reload the preset from a TOML string (hot-reload support).
316    pub fn reload_preset(&mut self, toml: &str) -> Result<()> {
317        let new_preset = PresetParser::parse(toml)?;
318        if let Ok(mut pipeline) = self.pipeline.lock() {
319            pipeline.reload_preset(&new_preset)?;
320        }
321        if let Ok(mut router) = self.model_router.lock() {
322            *router = ModelRouter::new(&new_preset);
323        }
324        if let Ok(mut preset) = self.preset.lock() {
325            *preset = new_preset;
326        }
327        Ok(())
328    }
329
330    /// Spawn a background thread that watches `path` for preset file changes.
331    ///
332    /// Only the preset, pipeline, and model_router are shared with the watcher
333    /// thread (via `Arc<Mutex<>>`). All other engine state stays on the main thread.
334    pub fn watch_preset_file(&self, path: &Path) -> Result<notify::RecommendedWatcher> {
335        use notify::{Event, EventKind, RecursiveMode, Watcher};
336
337        let preset_arc = Arc::clone(&self.preset);
338        let pipeline_arc = Arc::clone(&self.pipeline);
339        let router_arc = Arc::clone(&self.model_router);
340        let watched_path = path.to_owned();
341
342        let mut watcher = notify::recommended_watcher(move |res: notify::Result<Event>| {
343            if let Ok(event) = res {
344                if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) {
345                    match std::fs::read_to_string(&watched_path) {
346                        Ok(toml_str) => match PresetParser::parse(&toml_str) {
347                            Ok(new_preset) => {
348                                if let Ok(mut p) = pipeline_arc.lock() {
349                                    let _ = p.reload_preset(&new_preset);
350                                }
351                                if let Ok(mut r) = router_arc.lock() {
352                                    *r = ModelRouter::new(&new_preset);
353                                }
354                                if let Ok(mut pr) = preset_arc.lock() {
355                                    *pr = new_preset;
356                                }
357                            }
358                            Err(e) => eprintln!("[sqz] invalid preset: {e}"),
359                        },
360                        Err(e) => eprintln!("[sqz] preset read error: {e}"),
361                    }
362                }
363            }
364        })
365        .map_err(|e| SqzError::Other(format!("watcher error: {e}")))?;
366
367        watcher
368            .watch(path, RecursiveMode::NonRecursive)
369            .map_err(|e| SqzError::Other(format!("watch error: {e}")))?;
370
371        Ok(watcher)
372    }
373
374    /// Access the underlying `SessionStore`.
375    pub fn session_store(&self) -> &SessionStore {
376        &self.session_store
377    }
378
379    /// Access the `CacheManager` for persistent dedup.
380    pub fn cache_manager(&self) -> &CacheManager {
381        &self.cache_manager
382    }
383
384    /// Access the `AstParser`.
385    pub fn ast_parser(&self) -> &AstParser {
386        &self.ast_parser
387    }
388
389    /// Access the `TerseMode` helper.
390    pub fn terse_mode(&self) -> &TerseMode {
391        &self.terse_mode
392    }
393
394    /// Reorder context sections using the LITM positioner to mitigate
395    /// the "Lost In The Middle" attention bias in long-context models.
396    ///
397    /// Places highest-priority sections at the beginning and end of the
398    /// context window, lowest-priority in the middle.
399    pub fn reorder_context(
400        &self,
401        sections: &mut Vec<crate::litm_positioner::ContextSection>,
402        strategy: crate::litm_positioner::LitmStrategy,
403    ) {
404        let positioner = crate::litm_positioner::LitmPositioner::new(strategy);
405        positioner.reorder(sections);
406    }
407
408    /// Route content to the appropriate compression mode based on entropy
409    /// and risk pattern analysis.
410    pub fn route_compression_mode(&self, content: &str) -> crate::confidence_router::CompressionMode {
411        self.confidence_router.route(content)
412    }
413}
414
415
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{BudgetState, CorrectionLog, ModelFamily, SessionState};
    use chrono::Utc;
    use std::path::PathBuf;

    /// Build an engine over a throwaway store.
    ///
    /// Tests use this instead of `SqzEngine::new()` so they never read or
    /// write the developer's real `~/.sqz/sessions.db` and never share a
    /// database file with a test running in parallel. The `TempDir` is
    /// returned so the directory outlives the engine.
    fn temp_engine() -> (tempfile::TempDir, SqzEngine) {
        let dir = tempfile::tempdir().unwrap();
        let store_path = dir.path().join("store.db");
        let engine = SqzEngine::with_preset_and_store(Preset::default(), &store_path).unwrap();
        (dir, engine)
    }

    /// Minimal empty session with the given id.
    fn make_session(id: &str) -> SessionState {
        let now = Utc::now();
        SessionState {
            id: id.to_string(),
            project_dir: PathBuf::from("/tmp/test"),
            conversation: vec![],
            corrections: CorrectionLog::default(),
            pins: vec![],
            learnings: vec![],
            compressed_summary: "test session".to_string(),
            budget: BudgetState {
                window_size: 200_000,
                consumed: 0,
                pinned: 0,
                model_family: ModelFamily::AnthropicClaude,
            },
            tool_usage: vec![],
            created_at: now,
            updated_at: now,
        }
    }

    // The one test that deliberately exercises the default store path.
    #[test]
    fn test_engine_new() {
        let engine = SqzEngine::new();
        assert!(engine.is_ok(), "SqzEngine::new() should succeed");
    }

    #[test]
    fn test_compress_or_passthrough_returns_result_on_valid_input() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress_or_passthrough("hello world");
        assert_eq!(result.data, "hello world");
        assert!(result.tokens_original > 0);
    }

    #[test]
    fn test_compress_or_passthrough_never_panics_on_empty() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress_or_passthrough("");
        assert_eq!(result.data, "");
        assert_eq!(result.compression_ratio, 1.0);
    }

    #[test]
    fn test_compress_or_passthrough_handles_json() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress_or_passthrough(r#"{"key":"value"}"#);
        // Should compress successfully — data may be TOON-encoded
        assert!(!result.data.is_empty());
    }

    #[test]
    fn test_compress_or_passthrough_handles_binary_garbage() {
        let (_dir, engine) = temp_engine();
        // Feed it something weird — should never panic, always return something
        let garbage = "\x00\x01\x02\x7f invalid control chars \t\n\r";
        let result = engine.compress_or_passthrough(garbage);
        assert!(!result.data.is_empty());
    }

    #[test]
    fn test_compress_plain_text() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress("hello world");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().data, "hello world");
    }

    #[test]
    fn test_compress_json_applies_toon() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress(r#"{"name":"Alice","age":30}"#).unwrap();
        assert!(result.data.starts_with("TOON:"), "JSON should be TOON-encoded");
    }

    #[test]
    fn test_export_import_ctx_round_trip() {
        let (_dir, engine) = temp_engine();

        let session = make_session("sess-rt");
        engine.session_store().save_session(&session).unwrap();

        let ctx = engine.export_ctx("sess-rt").unwrap();
        let imported_id = engine.import_ctx(&ctx).unwrap();
        assert_eq!(imported_id, "sess-rt");
    }

    #[test]
    fn test_search_sessions() {
        let (_dir, engine) = temp_engine();

        let mut session = make_session("sess-search");
        session.compressed_summary = "authentication refactor".to_string();
        engine.session_store().save_session(&session).unwrap();

        let results = engine.search_sessions("authentication").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "sess-search");
    }

    #[test]
    fn test_usage_report_starts_at_zero() {
        let (_dir, engine) = temp_engine();
        let report = engine.usage_report("default");
        assert_eq!(report.consumed, 0);
        assert_eq!(report.available, report.allocated);
    }

    #[test]
    fn test_cost_summary() {
        let (_dir, engine) = temp_engine();

        let session = make_session("sess-cost");
        engine.session_store().save_session(&session).unwrap();

        let summary = engine.cost_summary("sess-cost").unwrap();
        assert_eq!(summary.total_tokens, 0);
        assert!((summary.total_usd - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_reload_preset_updates_state() {
        let (_dir, mut engine) = temp_engine();
        let toml = r#"
[preset]
name = "reloaded"
version = "2.0"

[compression]
stages = []

[tool_selection]
max_tools = 5
similarity_threshold = 0.7

[budget]
warning_threshold = 0.70
ceiling_threshold = 0.85
default_window_size = 200000

[terse_mode]
enabled = false
level = "moderate"

[model]
family = "anthropic"
primary = "claude-sonnet-4-20250514"
complexity_threshold = 0.4
"#;
        assert!(engine.reload_preset(toml).is_ok());
        // Verify the preset was actually updated
        let preset = engine.preset.lock().unwrap();
        assert_eq!(preset.preset.name, "reloaded");
    }

    #[test]
    fn test_reload_invalid_preset_returns_error() {
        let (_dir, mut engine) = temp_engine();
        let result = engine.reload_preset("not valid toml [[[");
        assert!(result.is_err(), "invalid TOML should return error");
    }

    #[test]
    fn test_export_nonexistent_session_returns_error() {
        let (_dir, engine) = temp_engine();
        let result = engine.export_ctx("does-not-exist");
        assert!(result.is_err());
    }

    #[test]
    fn test_import_invalid_ctx_returns_error() {
        let (_dir, engine) = temp_engine();
        let result = engine.import_ctx("not valid json {{{");
        assert!(result.is_err());
    }
}