Skip to main content

sparrow/
extras.rs

1use crate::engine::{Engine, Task};
2use crate::event::Event;
3use crate::memory::{Fact, Memory};
4use std::sync::Arc;
5use tokio::sync::mpsc;
6
7// ─── Auto-distillation ─────────────────────────────────────────────────────────
8
9/// After a successful run, extract durable facts about the user
10/// from the conversation trajectory.
11/// §3.8: "after sessions, distill durable facts/preferences into identity/facts_about_user"
12pub struct Distiller;
13
14impl Distiller {
15    /// Analyze run events and extract durable facts about the user and project.
16    /// Called automatically after every successful run (§3.8).
17    pub async fn distill(memory: &Arc<dyn Memory>, events: &[Event], task_description: &str) {
18        let mut facts = Vec::new();
19
20        // ── 1. Languages & frameworks (from file paths) ──────────────────────
21        let mut lang_hints: Vec<String> = Vec::new();
22        let mut framework_hints: Vec<String> = Vec::new();
23        let mut tool_usage: std::collections::HashMap<String, u32> =
24            std::collections::HashMap::new();
25        let mut style_hints: Vec<String> = Vec::new();
26        let mut pref_hints: Vec<String> = Vec::new();
27        let mut convention_hints: Vec<String> = Vec::new();
28        let mut directive_hints: Vec<String> = Vec::new();
29
30        for event in events {
31            match event {
32                Event::ToolUseProposed { name, args, .. } => {
33                    // Track tool usage frequency
34                    *tool_usage.entry(name.clone()).or_insert(0) += 1;
35
36                    // Languages from file extensions
37                    if let Some(path) = args.get("path").and_then(|v| v.as_str()) {
38                        detect_languages(path, &mut lang_hints);
39                        detect_conventions(path, &mut convention_hints);
40                    }
41                    // Frameworks from content
42                    if let Some(content) = args.get("content").and_then(|v| v.as_str()) {
43                        detect_frameworks(content, &mut framework_hints);
44                    }
45                }
46                Event::ThinkingDelta { text, .. } => {
47                    // Style preferences
48                    if text.contains("refactor") {
49                        style_hints.push("refactoring-oriented".to_string());
50                    }
51                    if text.contains("test") || text.contains("TDD") {
52                        style_hints.push("test-driven".to_string());
53                    }
54                    if text.contains("async") || text.contains("await") {
55                        style_hints.push("async-first".to_string());
56                    }
57                    // Explicit user preferences mentioned in thinking
58                    detect_preferences(text, &mut pref_hints);
59                }
60                Event::Message { text, role, .. } if role == "user" => {
61                    // User messages often contain preferences
62                    detect_preferences(text, &mut pref_hints);
63                    // Explicit durable directives ("remember that…", "I prefer…",
64                    // "always…", "my name is…"). Captured verbatim so memory keeps
65                    // the real instruction, not a generic hint.
66                    detect_directives(text, &mut directive_hints);
67                }
68                _ => {}
69            }
70        }
71
72        // ── 2. Deduplicate ──────────────────────────────────────────────────
73        dedup(&mut lang_hints);
74        dedup(&mut framework_hints);
75        dedup(&mut style_hints);
76        dedup(&mut pref_hints);
77        dedup(&mut convention_hints);
78        dedup(&mut directive_hints);
79
80        // ── 3. Save facts ───────────────────────────────────────────────────
81        for lang in &lang_hints {
82            facts.push(fact("user:language", lang));
83        }
84        for fw in &framework_hints {
85            facts.push(fact("user:framework", fw));
86        }
87        for style in &style_hints {
88            facts.push(fact("user:style", style));
89        }
90        for pref in &pref_hints {
91            facts.push(fact("user:preference", pref));
92        }
93        for conv in &convention_hints {
94            facts.push(fact("project:convention", conv));
95        }
96        // Each captured directive becomes its own fact keyed by a stable hash of
97        // its text, so distinct directives never collide on a shared key.
98        for d in &directive_hints {
99            facts.push(fact(&format!("user:directive:{}", short_hash(d)), d));
100        }
101        // Tools used 3+ times become learned preferences
102        for (tool, count) in &tool_usage {
103            if *count >= 3 {
104                facts.push(fact(
105                    "user:frequent_tool",
106                    &format!("uses {} frequently ({}x this session)", tool, count),
107                ));
108            }
109        }
110
111        // ── 4. Persist (skip duplicates) ────────────────────────────────────
112        // v0.9.1 fix: deduplicate on the (key, value) PAIR, not the key alone.
113        // The keys here are generic buckets (`user:preference`, `user:language`,
114        // `user:directive`…). The previous `existing_keys.contains(key)` check
115        // meant that once ONE `user:preference` existed, no further preference
116        // — with a different value — was ever stored. Memory saturated after the
117        // first run and stopped learning. Comparing the full pair lets every new
118        // distinct fact land while still skipping exact repeats.
119        let existing = memory.all_facts();
120        let existing_pairs: std::collections::HashSet<(String, String)> = existing
121            .iter()
122            .map(|f| (f.key.clone(), f.value.clone()))
123            .collect();
124        let mut saved = 0;
125
126        for fact in &facts {
127            if !existing_pairs.contains(&(fact.key.clone(), fact.value.clone())) {
128                let _ = memory.remember(fact.clone());
129                saved += 1;
130            }
131        }
132
133        if saved > 0 {
134            tracing::info!(
135                "Distiller: extracted {} facts ({} new) from task: {}",
136                facts.len(),
137                saved,
138                &task_description[..task_description.len().min(60)]
139            );
140        }
141    }
142}
143
144// ─── Distiller helpers ─────────────────────────────────────────────────────────
145
/// Records the language implied by `path`'s file extension.
///
/// The path is lowercased and compared against a fixed suffix table; the
/// label of the first matching entry is appended to `hints`. Paths that match
/// no entry leave `hints` untouched.
fn detect_languages(path: &str, hints: &mut Vec<String>) {
    const EXTENSION_LANGUAGES: &[(&str, &str)] = &[
        (".rs", "Rust"),
        (".ts", "TypeScript"),
        (".tsx", "TypeScript/React"),
        (".py", "Python"),
        (".go", "Go"),
        (".js", "JavaScript"),
        (".jsx", "JavaScript/React"),
        (".java", "Java"),
        (".rb", "Ruby"),
        (".css", "CSS"),
        (".html", "HTML"),
        (".sql", "SQL"),
        (".tf", "Terraform"),
        (".yml", "YAML"),
        (".yaml", "YAML"),
        (".toml", "TOML"),
        (".json", "JSON"),
        (".md", "Markdown"),
        (".sh", "Shell"),
    ];

    let normalized = path.to_lowercase();
    let first_match = EXTENSION_LANGUAGES
        .iter()
        .find(|entry| normalized.ends_with(entry.0));
    if let Some(&(_, language)) = first_match {
        hints.push(language.to_owned());
    }
}
176
/// Records build-system / tooling labels whose marker file name appears in
/// `content`.
///
/// Matching is a case-sensitive substring search. Every matching table entry
/// contributes one label, in table order.
fn detect_frameworks(content: &str, hints: &mut Vec<String>) {
    const FRAMEWORK_MARKERS: &[(&str, &str)] = &[
        ("Cargo.toml", "Rust/Cargo"),
        ("package.json", "Node.js"),
        ("go.mod", "Go modules"),
        ("requirements.txt", "Python/pip"),
        ("pyproject.toml", "Python/poetry"),
        ("Dockerfile", "Docker"),
        ("docker-compose", "Docker Compose"),
        ("Makefile", "Make"),
        ("CMakeLists.txt", "CMake"),
        ("pom.xml", "Java/Maven"),
        ("build.gradle", "Java/Gradle"),
    ];

    hints.extend(
        FRAMEWORK_MARKERS
            .iter()
            .filter(|(needle, _)| content.contains(*needle))
            .map(|(_, label)| (*label).to_owned()),
    );
}
197
/// Records generic preference hints for phrases found in `text`.
///
/// `text` is lowercased before matching, so every pattern in the table MUST
/// itself be lowercase. Each matching pattern appends its hint to `hints`, in
/// table order; matching is a plain substring search.
fn detect_preferences(text: &str, hints: &mut Vec<String>) {
    let pref_patterns: &[(&str, &str)] = &[
        ("prefer async", "prefers async/await"),
        ("prefer sync", "prefers synchronous code"),
        ("use tabs", "uses tabs for indentation"),
        ("use spaces", "uses spaces for indentation"),
        (
            "prefer unwrap",
            "prefers .unwrap() over proper error handling",
        ),
        ("prefer anyhow", "prefers anyhow for error handling"),
        ("instead of", "has strong opinions about alternatives"),
        ("don't use", "has explicit dislikes"),
        ("always use", "has explicit preferences"),
        // These two were previously spelled with an uppercase "I" and could
        // therefore never match the lowercased text.
        ("i like", "expressed a personal preference"),
        ("i want", "expressed a desire"),
    ];
    let lower = text.to_lowercase();
    hints.extend(
        pref_patterns
            .iter()
            .filter(|(pattern, _)| lower.contains(*pattern))
            .map(|(_, hint)| (*hint).to_string()),
    );
}
222
/// Collects explicit, durable instructions from a user message.
///
/// Every line of `text` that contains one of the English/French intent
/// markers (matched case-insensitively) is trimmed, capped at 220 characters
/// and appended to `hints` verbatim, so the stored fact is the user's own
/// wording rather than a generic label. If `hints` is still empty afterwards
/// and the message as a whole contains a marker, the trimmed, capped message
/// is stored instead.
fn detect_directives(text: &str, hints: &mut Vec<String>) {
    const MARKERS: &[&str] = &[
        // English
        "remember that",
        "remember to",
        "don't forget",
        "keep in mind",
        "note that",
        "from now on",
        "always ",
        "never ",
        "my name is",
        "i prefer",
        "i want you to",
        "make sure to",
        "going forward",
        // French
        "retiens",
        "souviens-toi",
        "souviens toi",
        "n'oublie pas",
        "note que",
        "désormais",
        "dorénavant",
        "toujours ",
        "jamais ",
        "je m'appelle",
        "je préfère",
        "je veux que",
        "à partir de maintenant",
        "rappelle-toi",
    ];
    // Upper bound, in characters, on a stored directive.
    const MAX_CHARS: usize = 220;
    // Lower bound, in bytes, below which a candidate is discarded.
    const MIN_BYTES: usize = 4;

    let mentions_marker =
        |haystack: &str| MARKERS.iter().any(|marker| haystack.contains(*marker));

    let whole_lower = text.to_lowercase();

    for raw_line in text.lines() {
        if !mentions_marker(&raw_line.to_lowercase()) {
            continue;
        }
        let trimmed = raw_line.trim();
        if trimmed.len() < MIN_BYTES {
            continue;
        }
        // Cap to keep a fact compact; preserve the real instruction.
        hints.push(trimmed.chars().take(MAX_CHARS).collect());
    }

    // Fallback: nothing collected so far, yet the message as a whole carries a marker.
    if hints.is_empty() && mentions_marker(&whole_lower) {
        let snippet: String = text.trim().chars().take(MAX_CHARS).collect();
        if snippet.len() >= MIN_BYTES {
            hints.push(snippet);
        }
    }
}
279
/// Short hex digest (8 lowercase hex digits) used to derive fact keys from text.
///
/// The input is trimmed and lowercased first, so digests ignore surrounding
/// whitespace and letter case. The digest is the low 32 bits of 64-bit FNV-1a,
/// a fixed, fully specified algorithm: the value is persisted inside fact
/// keys, so it must not change between builds. `DefaultHasher`, used here
/// previously, documents that its output may change across Rust releases.
///
/// With only 32 bits, collisions are unlikely but not impossible.
///
/// NOTE(review): digests differ from those produced by the earlier
/// `DefaultHasher`-based version, so directives stored under old keys will be
/// stored once more under the new key if they are seen again.
fn short_hash(s: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let digest = s
        .trim()
        .to_lowercase()
        .bytes()
        .fold(FNV_OFFSET_BASIS, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        });
    format!("{:08x}", digest & 0xffff_ffff)
}
287
/// Records project-layout conventions implied by fragments of `path`.
///
/// The path is lowercased and searched for each fragment in the table; every
/// fragment found appends its description to `hints`, in table order. The
/// fragments are written in lowercase so they compare directly against the
/// lowercased path.
fn detect_conventions(path: &str, hints: &mut Vec<String>) {
    const PATH_CONVENTIONS: &[(&str, &str)] = &[
        ("src/main.rs", "Rust binary project structure"),
        ("src/lib.rs", "Rust library project structure"),
        ("src/index.ts", "TypeScript entry point convention"),
        ("src/app.py", "Python app entry point"),
        ("tests/", "has a test directory"),
        ("spec/", "has a spec directory"),
        ("docs/", "maintains documentation"),
        (".github/workflows/", "uses GitHub Actions CI"),
        (".gitignore", "has gitignore"),
    ];

    let haystack = path.to_lowercase();
    hints.extend(PATH_CONVENTIONS.iter().filter_map(|&(fragment, description)| {
        haystack.contains(fragment).then(|| description.to_owned())
    }));
}
307
/// Leaves `v` sorted in ascending order with duplicate strings removed.
fn dedup(v: &mut Vec<String>) {
    // An ordered set both sorts and removes duplicates in one pass.
    let unique: std::collections::BTreeSet<String> = v.drain(..).collect();
    v.extend(unique);
}
312
313fn fact(key: &str, value: &str) -> Fact {
314    Fact {
315        id: uuid::Uuid::new_v4().to_string(),
316        key: key.to_string(),
317        value: value.to_string(),
318        created_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
319        updated_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
320    }
321}
322
#[cfg(test)]
mod distiller_tests {
    use super::*;

    /// Runs `detect_directives` on `text` with a fresh, empty accumulator.
    fn directives_in(text: &str) -> Vec<String> {
        let mut found = Vec::new();
        detect_directives(text, &mut found);
        found
    }

    #[test]
    fn detect_directives_captures_english_durable_instructions() {
        let remembered = directives_in("Remember that I want reports in ./artifacts");
        assert!(remembered
            .iter()
            .any(|s| s.contains("reports in ./artifacts")));

        let from_now_on = directives_in("From now on, always run the tests first");
        assert_eq!(from_now_on.len(), 1);
    }

    #[test]
    fn detect_directives_captures_french_durable_instructions() {
        let named = directives_in("Retiens que je m'appelle Abdou");
        assert!(named.iter().any(|s| s.contains("Abdou")));

        let henceforth = directives_in("Désormais, range les livrables dans ./out");
        assert_eq!(henceforth.len(), 1);
    }

    #[test]
    fn detect_directives_ignores_ordinary_text() {
        assert!(directives_in("Can you fix the bug in main.rs please?").is_empty());
    }

    #[test]
    fn short_hash_is_stable_and_distinct() {
        // Same digest regardless of case and surrounding whitespace.
        let plain = short_hash("use ./artifacts");
        let noisy = short_hash("  USE ./artifacts  ");
        assert_eq!(plain, noisy);

        // Different content yields different digests for these inputs.
        assert_ne!(short_hash("prefer async"), short_hash("prefer sync"));

        // Always rendered as eight hex digits.
        assert_eq!(short_hash("x").len(), 8);
    }
}
368
369// ─── Lightweight deterministic embeddings ──────────────────────────────────────
370
371/// Lightweight semantic embeddings for repo memory.
372/// §3.8: "optional embeddings (per project)"
373#[derive(Debug, Clone)]
374pub struct Embeddings {
375    /// Stored text + normalized hashing-vector embedding.
376    pub vectors: Vec<(String, Vec<f64>)>,
377    dimensions: usize,
378}
379
380impl Embeddings {
381    pub const DEFAULT_DIMENSIONS: usize = 512;
382
383    pub fn new() -> Self {
384        Self {
385            vectors: Vec::new(),
386            dimensions: Self::DEFAULT_DIMENSIONS,
387        }
388    }
389
390    pub fn with_dimensions(dimensions: usize) -> Self {
391        Self {
392            vectors: Vec::new(),
393            dimensions: dimensions.max(16),
394        }
395    }
396
397    /// Build a deterministic hashing-vector embedding from text.
398    ///
399    /// This is intentionally local-first: no model/API key required, fixed
400    /// dimensions across documents, stable across sessions, and good enough for
401    /// lexical semantic recall in memory. It uses signed feature hashing with
402    /// unigram + adjacent bigram features, sublinear term frequency, and L2
403    /// normalization.
404    pub fn embed(&self, text: &str) -> Vec<f64> {
405        embed_with_dimensions(text, self.dimensions)
406    }
407
408    pub fn add(&mut self, text: &str) {
409        let clean = text.trim();
410        if clean.is_empty() {
411            return;
412        }
413        self.vectors.push((clean.to_string(), self.embed(clean)));
414    }
415
416    pub fn add_many<I, S>(&mut self, texts: I)
417    where
418        I: IntoIterator<Item = S>,
419        S: AsRef<str>,
420    {
421        for text in texts {
422            self.add(text.as_ref());
423        }
424    }
425
426    /// Find the most similar stored text to the query
427    pub fn search(&self, query: &str, k: usize) -> Vec<String> {
428        self.search_scored(query, k)
429            .into_iter()
430            .map(|(_, text)| text)
431            .collect()
432    }
433
434    pub fn search_scored(&self, query: &str, k: usize) -> Vec<(f64, String)> {
435        if k == 0 {
436            return Vec::new();
437        }
438        let q_embed = self.embed(query);
439        let mut scored: Vec<(f64, usize, &str)> = self
440            .vectors
441            .iter()
442            .enumerate()
443            .map(|(idx, (text, emb))| (cosine_sim(&q_embed, emb), idx, text.as_str()))
444            .collect();
445        scored.sort_by(|a, b| {
446            b.0.partial_cmp(&a.0)
447                .unwrap_or(std::cmp::Ordering::Equal)
448                .then(a.1.cmp(&b.1))
449        });
450        scored
451            .into_iter()
452            .take(k)
453            .filter(|(score, _, _)| *score > 0.0)
454            .map(|(score, _, text)| (score, text.to_string()))
455            .collect()
456    }
457
458    pub fn save_to_path(&self, path: impl AsRef<std::path::Path>) -> anyhow::Result<()> {
459        let snapshot = EmbeddingsSnapshot {
460            dimensions: self.dimensions,
461            texts: self.vectors.iter().map(|(text, _)| text.clone()).collect(),
462        };
463        let json = serde_json::to_string_pretty(&snapshot)?;
464        if let Some(parent) = path.as_ref().parent() {
465            std::fs::create_dir_all(parent)?;
466        }
467        std::fs::write(path, json)?;
468        Ok(())
469    }
470
471    pub fn load_from_path(path: impl AsRef<std::path::Path>) -> anyhow::Result<Self> {
472        let json = std::fs::read_to_string(path)?;
473        let snapshot: EmbeddingsSnapshot = serde_json::from_str(&json)?;
474        let mut index = Self::with_dimensions(snapshot.dimensions);
475        index.add_many(snapshot.texts);
476        Ok(index)
477    }
478}
479
480impl Default for Embeddings {
481    fn default() -> Self {
482        Self::new()
483    }
484}
485
/// Serialized form of an `Embeddings` index, written as JSON by
/// `Embeddings::save_to_path` and read back by `Embeddings::load_from_path`.
///
/// Only the dimension count and the raw texts are stored; the vectors are
/// recomputed from the texts on load.
#[derive(serde::Serialize, serde::Deserialize)]
struct EmbeddingsSnapshot {
    /// Embedding width of the index that was saved.
    dimensions: usize,
    /// Stored texts, in index order.
    texts: Vec<String>,
}
491
492fn embed_with_dimensions(text: &str, dimensions: usize) -> Vec<f64> {
493    let mut vector = vec![0.0; dimensions.max(16)];
494    let tokens = tokenize(text);
495    for token in &tokens {
496        add_feature(&mut vector, token, 1.0);
497    }
498    for pair in tokens.windows(2) {
499        add_feature(&mut vector, &format!("{}__{}", pair[0], pair[1]), 1.35);
500    }
501    for value in &mut vector {
502        if *value != 0.0 {
503            *value = value.signum() * value.abs().ln_1p();
504        }
505    }
506    normalize(&mut vector);
507    vector
508}
509
/// Splits `text` into lowercase alphanumeric tokens.
///
/// Any non-alphanumeric character acts as a separator; empty runs between
/// separators produce no token. Lowercasing is applied character by
/// character.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|ch: char| !ch.is_alphanumeric())
        .filter(|run| !run.is_empty())
        .map(|run| run.chars().flat_map(char::to_lowercase).collect())
        .collect()
}
525
526fn add_feature(vector: &mut [f64], feature: &str, weight: f64) {
527    let hash = fnv1a64(feature.as_bytes());
528    let idx = (hash as usize) % vector.len();
529    let sign = if hash & (1 << 63) == 0 { 1.0 } else { -1.0 };
530    vector[idx] += sign * weight;
531}
532
/// 64-bit FNV-1a hash of `bytes`: starting from the offset basis, each byte
/// is XORed into the state, which is then multiplied (wrapping) by the prime.
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;

    bytes.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
541
/// Scales `vector` in place to unit Euclidean length.
/// A vector whose length is not strictly positive is left unchanged.
fn normalize(vector: &mut [f64]) {
    let squared_sum: f64 = vector.iter().map(|x| x * x).sum();
    let length = squared_sum.sqrt();
    if length > 0.0 {
        vector.iter_mut().for_each(|x| *x /= length);
    }
}
550
/// Cosine similarity of `a` and `b` over their common prefix.
///
/// Only the first `min(a.len(), b.len())` components of each slice are used.
/// Returns 0.0 when that prefix is empty or either prefix has zero magnitude.
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let shared = a.len().min(b.len());
    if shared == 0 {
        return 0.0;
    }

    let (lhs, rhs) = (&a[..shared], &b[..shared]);
    let magnitude = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();

    let dot: f64 = lhs.iter().zip(rhs.iter()).map(|(x, y)| x * y).sum();
    let (mag_a, mag_b) = (magnitude(lhs), magnitude(rhs));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    dot / (mag_a * mag_b)
}
565
566// ─── Replay re-execute ──────────────────────────────────────────────────────────
567
568/// Re-execute a transcript against a chosen model.
569/// §3.15: "can re-execute against a chosen model"
570pub struct ReExecuter {
571    engine: Arc<Engine>,
572}
573
574impl ReExecuter {
575    pub fn new(engine: Arc<Engine>) -> Self {
576        Self { engine }
577    }
578
579    /// Re-execute from a transcript: send the original task to the engine
580    /// with the same parameters.
581    pub async fn re_execute(
582        &self,
583        transcript: &crate::runtime::recorder::Transcript,
584    ) -> anyhow::Result<crate::event::OutcomeSummary> {
585        let (tx, _rx) = mpsc::unbounded_channel::<Event>();
586        let task = Task {
587            description: transcript.inputs.task.clone(),
588            context: vec![],
589        };
590        self.engine.drive(task, tx).await
591    }
592}
593
594// ─── OAuth flow ─────────────────────────────────────────────────────────────────
595
596pub struct OAuthFlow;
597
598impl OAuthFlow {
599    /// Start a device code OAuth flow.
600    /// Accepts the endpoints and scope from the provider registry — no hardcoded list.
601    pub async fn start_device_flow(
602        device_endpoint: &str,
603        token_endpoint_hint: &str, // unused here, kept for symmetry
604        client_id: &str,
605        scope: &str,
606    ) -> anyhow::Result<(String, String, String)> {
607        let _ = token_endpoint_hint;
608        let client = reqwest::Client::new();
609        let resp: serde_json::Value = client
610            .post(device_endpoint)
611            .form(&[("client_id", client_id), ("scope", scope)])
612            .send()
613            .await?
614            .json()
615            .await?;
616
617        let verification_uri = resp["verification_uri"]
618            .as_str()
619            .or_else(|| resp["verification_url"].as_str())
620            .unwrap_or("")
621            .to_string();
622        let user_code = resp["user_code"].as_str().unwrap_or("").to_string();
623        let device_code = resp["device_code"].as_str().unwrap_or("").to_string();
624
625        if device_code.is_empty() {
626            anyhow::bail!("Device flow start failed — provider response: {}", resp);
627        }
628
629        Ok((verification_uri, user_code, device_code))
630    }
631
632    /// Poll for token completion using the provider token endpoint.
633    pub async fn poll_token(
634        token_endpoint: &str,
635        client_id: &str,
636        device_code: &str,
637        timeout_secs: u64,
638    ) -> anyhow::Result<String> {
639        let client = reqwest::Client::new();
640        let start = std::time::Instant::now();
641
642        loop {
643            if start.elapsed().as_secs() > timeout_secs {
644                anyhow::bail!("OAuth timed out after {}s", timeout_secs);
645            }
646
647            let resp: serde_json::Value = client
648                .post(token_endpoint)
649                .form(&[
650                    ("client_id", client_id),
651                    ("device_code", device_code),
652                    ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
653                ])
654                .send()
655                .await?
656                .json()
657                .await?;
658
659            if let Some(token) = resp["access_token"].as_str() {
660                return Ok(token.to_string());
661            }
662
663            match resp["error"].as_str() {
664                Some("authorization_pending") | Some("slow_down") => {
665                    tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
666                    continue;
667                }
668                Some(e) => anyhow::bail!("OAuth error: {}", e),
669                None => {
670                    tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
671                }
672            }
673        }
674    }
675}
676
677// ─── IBM Plex Mono reference ────────────────────────────────────────────────────
678
/// §9.3: "IBM Plex Mono everywhere (TUI authenticity + web)"
/// The font is not embedded in the binary; users install it system-wide.
/// This constant holds only the download URL of the font archive; the
/// per-platform install steps come from `ibm_plex_install_instructions`.
pub const IBM_PLEX_MONO_URL: &str =
    "https://github.com/IBM/plex/releases/latest/download/IBM-Plex-Mono.zip";
684
/// Returns the multi-line, human-readable install instructions for the
/// IBM Plex Mono font (Linux, macOS and Windows), ending with a newline.
pub fn ibm_plex_install_instructions() -> String {
    const LINES: [&str; 8] = [
        "IBM Plex Mono — recommended font for Sparrow TUI.",
        "",
        "Install:",
        "  Linux:   sudo apt install fonts-ibm-plex",
        "  macOS:   brew install font-ibm-plex",
        "  Windows: Download from https://github.com/IBM/plex/releases",
        "",
        "Then update your terminal to use \"IBM Plex Mono\" as the font.",
    ];

    let mut text = String::new();
    for line in LINES.iter() {
        text.push_str(line);
        text.push('\n');
    }
    text
}
697
698// ─── Chat mode ──────────────────────────────────────────────────────────────────
699
700/// Interactive multi-turn chat loop.
701/// §4: "sparrow chat — interactive multi-turn (TUI/inline)"
702pub struct ChatSession {
703    engine: Arc<Engine>,
704    history: Vec<crate::provider::Msg>,
705    running: bool,
706}
707
708impl ChatSession {
709    pub fn new(engine: Arc<Engine>) -> Self {
710        Self {
711            engine,
712            history: Vec::new(),
713            running: true,
714        }
715    }
716
717    pub async fn run_interactive(&mut self) -> anyhow::Result<()> {
718        use std::io::{self, Write};
719
720        println!("═══ Sparrow Chat ═══");
721        println!("Type your message and press Enter. Type /exit to quit.");
722        println!();
723
724        while self.running {
725            print!("◆ you › ");
726            io::stdout().flush()?;
727
728            let mut input = String::new();
729            io::stdin().read_line(&mut input)?;
730            let input = input.trim().to_string();
731
732            if input.is_empty() {
733                continue;
734            }
735            if input == "/exit" || input == "/quit" {
736                break;
737            }
738
739            self.history.push(crate::provider::Msg {
740                role: "user".into(),
741                content: vec![crate::provider::ContentBlock::Text {
742                    text: input.clone(),
743                }],
744            });
745
746            let (tx, mut rx) = mpsc::unbounded_channel::<Event>();
747            let task = Task {
748                description: input.clone(),
749                context: self.history.clone(),
750            };
751
752            let engine = self.engine.clone();
753            let handle = tokio::spawn(async move { engine.drive(task, tx).await });
754
755            while let Some(event) = rx.recv().await {
756                match &event {
757                    Event::ThinkingDelta { text, .. } => {
758                        print!("{}", text);
759                        io::stdout().flush()?;
760                    }
761                    Event::RunFinished { outcome, .. } => {
762                        println!(
763                            "\n── {} | ${:.4} {}──",
764                            outcome.status,
765                            outcome.cost_usd,
766                            crate::cost::format_comparison_oneliner(
767                                outcome.cost_usd,
768                                &outcome.tokens
769                            )
770                        );
771                    }
772                    Event::Error { message, .. } => {
773                        eprintln!("\nError: {}", message);
774                    }
775                    _ => {}
776                }
777            }
778
779            match handle.await? {
780                Ok(outcome) => {
781                    self.history.push(crate::provider::Msg {
782                        role: "assistant".into(),
783                        content: vec![crate::provider::ContentBlock::Text {
784                            text: format!("[{}]", outcome.status),
785                        }],
786                    });
787                }
788                Err(e) => {
789                    eprintln!("Chat error: {}", e);
790                }
791            }
792            println!();
793        }
794
795        Ok(())
796    }
797}
798
799// ─── Configurable pipeline ──────────────────────────────────────────────────────
800
801/// Allow users to define custom swarm pipeline graphs.
802/// §3.11: "Configurable: number of agents, which model per role, pipeline graph."
803#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
804pub struct PipelineConfig {
805    pub name: String,
806    pub steps: Vec<PipelineStep>,
807    pub max_reworks: u32,
808}
809
810#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
811pub struct PipelineStep {
812    pub role: String,
813    pub model_preference: Option<String>,
814    pub prompt_override: Option<String>,
815    pub depends_on: Vec<String>,
816}
817
818impl PipelineConfig {
819    pub fn default_pipeline() -> Self {
820        Self {
821            name: "planner-coder-verifier".into(),
822            steps: vec![
823                PipelineStep {
824                    role: "planner".into(),
825                    model_preference: None,
826                    prompt_override: None,
827                    depends_on: vec![],
828                },
829                PipelineStep {
830                    role: "coder".into(),
831                    model_preference: None,
832                    prompt_override: None,
833                    depends_on: vec!["planner".into()],
834                },
835                PipelineStep {
836                    role: "verifier".into(),
837                    model_preference: None,
838                    prompt_override: None,
839                    depends_on: vec!["coder".into()],
840                },
841            ],
842            max_reworks: 3,
843        }
844    }
845
846    pub fn validate(&self) -> anyhow::Result<()> {
847        if self.steps.is_empty() {
848            anyhow::bail!("Pipeline must have at least one step");
849        }
850        for step in &self.steps {
851            for dep in &step.depends_on {
852                if !self.steps.iter().any(|s| s.role == *dep) {
853                    anyhow::bail!("Step '{}' depends on unknown role '{}'", step.role, dep);
854                }
855            }
856        }
857        Ok(())
858    }
859
860    pub fn from_toml(content: &str) -> anyhow::Result<Self> {
861        Ok(toml::from_str(content)?)
862    }
863
864    pub fn to_toml(&self) -> anyhow::Result<String> {
865        Ok(toml::to_string_pretty(self)?)
866    }
867}
868
869// ─── Profile isolation ──────────────────────────────────────────────────────────
870
871/// Full profile isolation: separate config, memory, agents per profile.
872/// §4: "sparrow profile <create|list|use> — multi-instance profiles"
873pub struct Profile {
874    pub name: String,
875    pub config_dir: std::path::PathBuf,
876    pub state_dir: std::path::PathBuf,
877    pub config: crate::config::Config,
878    pub memory: Arc<dyn Memory>,
879}
880
881impl Profile {
882    pub fn load(name: &str) -> anyhow::Result<Self> {
883        let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
884        let base_state = dirs::state_dir().unwrap_or_default().join("sparrow");
885
886        let config_dir = base_config.join("profiles").join(name);
887        let state_dir = base_state.join("profiles").join(name);
888
889        std::fs::create_dir_all(&config_dir)?;
890        std::fs::create_dir_all(&state_dir)?;
891
892        let config = if config_dir.join("config.toml").exists() {
893            let content = std::fs::read_to_string(config_dir.join("config.toml"))?;
894            toml::from_str(&content)?
895        } else {
896            // Inherit from default config if available
897            let default = base_config.join("config.toml");
898            if default.exists() {
899                let content = std::fs::read_to_string(&default)?;
900                toml::from_str(&content)?
901            } else {
902                crate::config::Config {
903                    defaults: Default::default(),
904                    routing: Default::default(),
905                    budget: Default::default(),
906                    providers: Default::default(),
907                    surfaces: Default::default(),
908                    experience: Default::default(),
909                    skills: Default::default(),
910                    intel: Default::default(),
911                    permissions: Default::default(),
912                    hooks: Default::default(),
913                    theme: "captain".into(),
914                    config_dir: config_dir.clone(),
915                    state_dir: state_dir.clone(),
916                    forced_model: None,
917                }
918            }
919        };
920
921        let memory: Arc<dyn Memory> = Arc::new(crate::memory::SqliteMemory::open(
922            &state_dir.join("profile.db"),
923        )?);
924
925        Ok(Self {
926            name: name.to_string(),
927            config_dir,
928            state_dir,
929            config,
930            memory,
931        })
932    }
933
934    pub fn create(name: &str) -> anyhow::Result<()> {
935        let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
936        let config_dir = base_config.join("profiles").join(name);
937        std::fs::create_dir_all(&config_dir)?;
938
939        // Copy default config
940        let default = base_config.join("config.toml");
941        if default.exists() {
942            std::fs::copy(&default, config_dir.join("config.toml"))?;
943        }
944
945        let base_state = dirs::state_dir().unwrap_or_default().join("sparrow");
946        std::fs::create_dir_all(base_state.join("profiles").join(name))?;
947
948        Ok(())
949    }
950
951    pub fn list() -> Vec<String> {
952        let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
953        let profiles_dir = base_config.join("profiles");
954        let mut names = Vec::new();
955        if let Ok(entries) = std::fs::read_dir(&profiles_dir) {
956            for entry in entries.flatten() {
957                if entry.path().is_dir() {
958                    if let Some(name) = entry.file_name().to_str() {
959                        names.push(name.to_string());
960                    }
961                }
962            }
963        }
964        names.sort();
965        names
966    }
967}