Skip to main content

sqz_engine/
engine.rs

1use std::path::Path;
2use std::sync::{Arc, Mutex};
3
4use crate::ast_parser::AstParser;
5use crate::budget_tracker::{BudgetTracker, UsageReport};
6use crate::cache_manager::CacheManager;
7use crate::cost_calculator::{CostCalculator, SessionCostSummary};
8use crate::ctx_format::CtxFormat;
9use crate::error::{Result, SqzError};
10use crate::model_router::ModelRouter;
11use crate::pin_manager::PinManager;
12use crate::pipeline::CompressionPipeline;
13use crate::plugin_api::PluginLoader;
14use crate::preset::{Preset, PresetParser};
15use crate::session_store::{SessionStore, SessionSummary};
16use crate::terse_mode::TerseMode;
17use crate::types::{CompressedContent, PinEntry, SessionId};
18
19/// Top-level facade that wires all sqz_engine modules together.
20///
21/// # Concurrency design
22///
23/// `SqzEngine` is designed for single-threaded use on the main thread.
24/// The only cross-thread sharing happens during preset hot-reload: the
25/// file-watcher callback runs on a background thread and needs to update
26/// the preset, pipeline, and model router. These three fields are wrapped
27/// in `Arc<Mutex<>>` specifically for that purpose. All other fields are
28/// owned directly — no unnecessary synchronization.
29pub struct SqzEngine {
30    // --- Hot-reloadable state (shared with file-watcher thread) ---
31    preset: Arc<Mutex<Preset>>,
32    pipeline: Arc<Mutex<CompressionPipeline>>,
33    model_router: Arc<Mutex<ModelRouter>>,
34
35    // --- Single-owner state (no cross-thread sharing needed) ---
36    session_store: SessionStore,
37    #[allow(dead_code)] // used internally by compress pipeline; public API pending
38    cache_manager: CacheManager,
39    budget_tracker: BudgetTracker,
40    cost_calculator: CostCalculator,
41    ast_parser: AstParser,
42    terse_mode: TerseMode,
43    pin_manager: PinManager,
44    _plugin_loader: PluginLoader,
45}
46
47impl SqzEngine {
48    /// Create a new engine with the default preset and a persistent session store.
49    ///
50    /// Sessions are stored in `~/.sqz/sessions.db` for cross-session continuity.
51    /// Falls back to a temp-file store if the home directory is unavailable.
52    pub fn new() -> Result<Self> {
53        let preset = Preset::default();
54        let store_path = Self::default_store_path();
55        Self::with_preset_and_store(preset, &store_path)
56    }
57
58    /// Resolve the default session store path: `~/.sqz/sessions.db`.
59    /// Falls back to a temp-file path if home dir is unavailable.
60    fn default_store_path() -> std::path::PathBuf {
61        if let Some(home) = dirs_next::home_dir() {
62            let sqz_dir = home.join(".sqz");
63            if std::fs::create_dir_all(&sqz_dir).is_ok() {
64                return sqz_dir.join("sessions.db");
65            }
66        }
67        // Fallback: temp dir with unique name
68        let dir = std::env::temp_dir();
69        dir.join(format!(
70            "sqz_session_{}_{}.db",
71            std::process::id(),
72            std::time::SystemTime::now()
73                .duration_since(std::time::UNIX_EPOCH)
74                .map(|d| d.as_nanos())
75                .unwrap_or(0)
76        ))
77    }
78
79    /// Create with a custom preset and a file-backed session store.
80    ///
81    /// Opens a single SQLite connection for the session store. The cache
82    /// manager and pin manager share the same store via separate connections
83    /// (SQLite WAL mode supports concurrent readers).
84    pub fn with_preset_and_store(preset: Preset, store_path: &Path) -> Result<Self> {
85        let pipeline = CompressionPipeline::new(&preset);
86        let window_size = preset.budget.default_window_size;
87
88        // One connection per consumer. SQLite WAL mode handles concurrency.
89        let session_store = SessionStore::open_or_create(store_path)?;
90        let cache_store = SessionStore::open_or_create(store_path)?;
91        let pin_store = SessionStore::open_or_create(store_path)?;
92
93        Ok(SqzEngine {
94            preset: Arc::new(Mutex::new(preset.clone())),
95            pipeline: Arc::new(Mutex::new(pipeline)),
96            model_router: Arc::new(Mutex::new(ModelRouter::new(&preset))),
97            session_store,
98            cache_manager: CacheManager::new(cache_store, 512 * 1024 * 1024),
99            budget_tracker: BudgetTracker::new(window_size, &preset),
100            cost_calculator: CostCalculator::with_defaults(),
101            ast_parser: AstParser::new(),
102            terse_mode: TerseMode,
103            pin_manager: PinManager::new(pin_store),
104            _plugin_loader: PluginLoader::new(Path::new("plugins")),
105        })
106    }
107
108    /// Compress input text using the current preset.
109    pub fn compress(&self, input: &str) -> Result<CompressedContent> {
110        let preset = self.preset.lock()
111            .map_err(|_| SqzError::Other("preset lock poisoned".into()))?;
112        let pipeline = self.pipeline.lock()
113            .map_err(|_| SqzError::Other("pipeline lock poisoned".into()))?;
114        let ctx = crate::pipeline::SessionContext {
115            session_id: "engine".to_string(),
116        };
117        pipeline.compress(input, &ctx, &preset)
118    }
119
120    /// Export a session to CTX format.
121    pub fn export_ctx(&self, session_id: &str) -> Result<String> {
122        let session = self.session_store.load_session(session_id.to_string())?;
123        CtxFormat::serialize(&session)
124    }
125
126    /// Import a CTX string and save as a new session.
127    pub fn import_ctx(&self, ctx: &str) -> Result<SessionId> {
128        let session = CtxFormat::deserialize(ctx)?;
129        self.session_store.save_session(&session)
130    }
131
132    /// Pin a conversation turn.
133    pub fn pin(&self, session_id: &str, turn_index: usize, reason: &str, tokens: u32) -> Result<PinEntry> {
134        self.pin_manager.pin(session_id, turn_index, reason, tokens)
135    }
136
137    /// Unpin a conversation turn.
138    pub fn unpin(&self, session_id: &str, turn_index: usize) -> Result<()> {
139        self.pin_manager.unpin(session_id, turn_index)
140    }
141
142    /// Search sessions by keyword.
143    pub fn search_sessions(&self, query: &str) -> Result<Vec<SessionSummary>> {
144        self.session_store.search(query)
145    }
146
147    /// Get usage report for an agent.
148    pub fn usage_report(&self, agent_id: &str) -> UsageReport {
149        self.budget_tracker.usage_report(agent_id.to_string())
150    }
151
152    /// Get cost summary for a session.
153    pub fn cost_summary(&self, session_id: &str) -> Result<SessionCostSummary> {
154        let session = self.session_store.load_session(session_id.to_string())?;
155        Ok(self.cost_calculator.session_summary(&session))
156    }
157
158    /// Reload the preset from a TOML string (hot-reload support).
159    pub fn reload_preset(&mut self, toml: &str) -> Result<()> {
160        let new_preset = PresetParser::parse(toml)?;
161        if let Ok(mut pipeline) = self.pipeline.lock() {
162            pipeline.reload_preset(&new_preset)?;
163        }
164        if let Ok(mut router) = self.model_router.lock() {
165            *router = ModelRouter::new(&new_preset);
166        }
167        if let Ok(mut preset) = self.preset.lock() {
168            *preset = new_preset;
169        }
170        Ok(())
171    }
172
173    /// Spawn a background thread that watches `path` for preset file changes.
174    ///
175    /// Only the preset, pipeline, and model_router are shared with the watcher
176    /// thread (via `Arc<Mutex<>>`). All other engine state stays on the main thread.
177    pub fn watch_preset_file(&self, path: &Path) -> Result<notify::RecommendedWatcher> {
178        use notify::{Event, EventKind, RecursiveMode, Watcher};
179
180        let preset_arc = Arc::clone(&self.preset);
181        let pipeline_arc = Arc::clone(&self.pipeline);
182        let router_arc = Arc::clone(&self.model_router);
183        let watched_path = path.to_owned();
184
185        let mut watcher = notify::recommended_watcher(move |res: notify::Result<Event>| {
186            if let Ok(event) = res {
187                if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) {
188                    match std::fs::read_to_string(&watched_path) {
189                        Ok(toml_str) => match PresetParser::parse(&toml_str) {
190                            Ok(new_preset) => {
191                                if let Ok(mut p) = pipeline_arc.lock() {
192                                    let _ = p.reload_preset(&new_preset);
193                                }
194                                if let Ok(mut r) = router_arc.lock() {
195                                    *r = ModelRouter::new(&new_preset);
196                                }
197                                if let Ok(mut pr) = preset_arc.lock() {
198                                    *pr = new_preset;
199                                }
200                            }
201                            Err(e) => eprintln!("[sqz] invalid preset: {e}"),
202                        },
203                        Err(e) => eprintln!("[sqz] preset read error: {e}"),
204                    }
205                }
206            }
207        })
208        .map_err(|e| SqzError::Other(format!("watcher error: {e}")))?;
209
210        watcher
211            .watch(path, RecursiveMode::NonRecursive)
212            .map_err(|e| SqzError::Other(format!("watch error: {e}")))?;
213
214        Ok(watcher)
215    }
216
217    /// Access the underlying `SessionStore`.
218    pub fn session_store(&self) -> &SessionStore {
219        &self.session_store
220    }
221
222    /// Access the `AstParser`.
223    pub fn ast_parser(&self) -> &AstParser {
224        &self.ast_parser
225    }
226
227    /// Access the `TerseMode` helper.
228    pub fn terse_mode(&self) -> &TerseMode {
229        &self.terse_mode
230    }
231}
232
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{BudgetState, CorrectionLog, ModelFamily, SessionState};
    use chrono::Utc;
    use std::path::PathBuf;

    /// Build an engine whose store lives in a fresh temporary directory.
    ///
    /// Tests use this instead of `SqzEngine::new()` so they never open the
    /// user's real `~/.sqz/sessions.db` and do not share one database file
    /// across test threads. The returned `TempDir` must be kept alive for as
    /// long as the engine is used; dropping it deletes the directory.
    fn temp_engine() -> (tempfile::TempDir, SqzEngine) {
        let dir = tempfile::tempdir().unwrap();
        let store_path = dir.path().join("store.db");
        let engine = SqzEngine::with_preset_and_store(Preset::default(), &store_path).unwrap();
        (dir, engine)
    }

    /// Minimal empty session with the given id and a fixed summary.
    fn make_session(id: &str) -> SessionState {
        let now = Utc::now();
        SessionState {
            id: id.to_string(),
            project_dir: PathBuf::from("/tmp/test"),
            conversation: vec![],
            corrections: CorrectionLog::default(),
            pins: vec![],
            learnings: vec![],
            compressed_summary: "test session".to_string(),
            budget: BudgetState {
                window_size: 200_000,
                consumed: 0,
                pinned: 0,
                model_family: ModelFamily::AnthropicClaude,
            },
            tool_usage: vec![],
            created_at: now,
            updated_at: now,
        }
    }

    // This is the one test that deliberately exercises `new()` and therefore
    // the default store path resolution.
    #[test]
    fn test_engine_new() {
        let engine = SqzEngine::new();
        assert!(engine.is_ok(), "SqzEngine::new() should succeed");
    }

    #[test]
    fn test_compress_plain_text() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress("hello world");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().data, "hello world");
    }

    #[test]
    fn test_compress_json_applies_toon() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress(r#"{"name":"Alice","age":30}"#).unwrap();
        assert!(result.data.starts_with("TOON:"), "JSON should be TOON-encoded");
    }

    #[test]
    fn test_export_import_ctx_round_trip() {
        let (_dir, engine) = temp_engine();

        let session = make_session("sess-rt");
        engine.session_store().save_session(&session).unwrap();

        let ctx = engine.export_ctx("sess-rt").unwrap();
        let imported_id = engine.import_ctx(&ctx).unwrap();
        assert_eq!(imported_id, "sess-rt");
    }

    #[test]
    fn test_search_sessions() {
        let (_dir, engine) = temp_engine();

        let mut session = make_session("sess-search");
        session.compressed_summary = "authentication refactor".to_string();
        engine.session_store().save_session(&session).unwrap();

        let results = engine.search_sessions("authentication").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "sess-search");
    }

    #[test]
    fn test_usage_report_starts_at_zero() {
        let (_dir, engine) = temp_engine();
        let report = engine.usage_report("default");
        assert_eq!(report.consumed, 0);
        assert_eq!(report.available, report.allocated);
    }

    #[test]
    fn test_cost_summary() {
        let (_dir, engine) = temp_engine();

        let session = make_session("sess-cost");
        engine.session_store().save_session(&session).unwrap();

        let summary = engine.cost_summary("sess-cost").unwrap();
        assert_eq!(summary.total_tokens, 0);
        assert!((summary.total_usd - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_reload_preset_updates_state() {
        let (_dir, mut engine) = temp_engine();
        let toml = r#"
[preset]
name = "reloaded"
version = "2.0"

[compression]
stages = []

[tool_selection]
max_tools = 5
similarity_threshold = 0.7

[budget]
warning_threshold = 0.70
ceiling_threshold = 0.85
default_window_size = 200000

[terse_mode]
enabled = false
level = "moderate"

[model]
family = "anthropic"
primary = "claude-sonnet-4-20250514"
complexity_threshold = 0.4
"#;
        assert!(engine.reload_preset(toml).is_ok());
        // Verify the preset was actually updated
        let preset = engine.preset.lock().unwrap();
        assert_eq!(preset.preset.name, "reloaded");
    }

    #[test]
    fn test_reload_invalid_preset_returns_error() {
        let (_dir, mut engine) = temp_engine();
        let result = engine.reload_preset("not valid toml [[[");
        assert!(result.is_err(), "invalid TOML should return error");
    }

    #[test]
    fn test_export_nonexistent_session_returns_error() {
        let (_dir, engine) = temp_engine();
        let result = engine.export_ctx("does-not-exist");
        assert!(result.is_err());
    }

    #[test]
    fn test_import_invalid_ctx_returns_error() {
        let (_dir, engine) = temp_engine();
        let result = engine.import_ctx("not valid json {{{");
        assert!(result.is_err());
    }
}
390}