1use crate::engine::{Engine, Task};
2use crate::event::Event;
3use crate::memory::{Fact, Memory};
4use std::sync::Arc;
5use tokio::sync::mpsc;
6
/// Heuristic extractor that turns a run's event stream into durable
/// [`Fact`]s about the user and project.
///
/// Carries no state; all work happens in `Distiller::distill`.
pub struct Distiller;
13
14impl Distiller {
15 pub async fn distill(memory: &Arc<dyn Memory>, events: &[Event], task_description: &str) {
18 let mut facts = Vec::new();
19
20 let mut lang_hints: Vec<String> = Vec::new();
22 let mut framework_hints: Vec<String> = Vec::new();
23 let mut tool_usage: std::collections::HashMap<String, u32> =
24 std::collections::HashMap::new();
25 let mut style_hints: Vec<String> = Vec::new();
26 let mut pref_hints: Vec<String> = Vec::new();
27 let mut convention_hints: Vec<String> = Vec::new();
28
29 for event in events {
30 match event {
31 Event::ToolUseProposed { name, args, .. } => {
32 *tool_usage.entry(name.clone()).or_insert(0) += 1;
34
35 if let Some(path) = args.get("path").and_then(|v| v.as_str()) {
37 detect_languages(path, &mut lang_hints);
38 detect_conventions(path, &mut convention_hints);
39 }
40 if let Some(content) = args.get("content").and_then(|v| v.as_str()) {
42 detect_frameworks(content, &mut framework_hints);
43 }
44 }
45 Event::ThinkingDelta { text, .. } => {
46 if text.contains("refactor") {
48 style_hints.push("refactoring-oriented".to_string());
49 }
50 if text.contains("test") || text.contains("TDD") {
51 style_hints.push("test-driven".to_string());
52 }
53 if text.contains("async") || text.contains("await") {
54 style_hints.push("async-first".to_string());
55 }
56 detect_preferences(text, &mut pref_hints);
58 }
59 Event::Message { text, role, .. } if role == "user" => {
60 detect_preferences(text, &mut pref_hints);
62 }
63 _ => {}
64 }
65 }
66
67 dedup(&mut lang_hints);
69 dedup(&mut framework_hints);
70 dedup(&mut style_hints);
71 dedup(&mut pref_hints);
72 dedup(&mut convention_hints);
73
74 for lang in &lang_hints {
76 facts.push(fact("user:language", lang));
77 }
78 for fw in &framework_hints {
79 facts.push(fact("user:framework", fw));
80 }
81 for style in &style_hints {
82 facts.push(fact("user:style", style));
83 }
84 for pref in &pref_hints {
85 facts.push(fact("user:preference", pref));
86 }
87 for conv in &convention_hints {
88 facts.push(fact("project:convention", conv));
89 }
90 for (tool, count) in &tool_usage {
92 if *count >= 3 {
93 facts.push(fact(
94 "user:frequent_tool",
95 &format!("uses {} frequently ({}x this session)", tool, count),
96 ));
97 }
98 }
99
100 let existing = memory.all_facts();
102 let existing_keys: Vec<&str> = existing.iter().map(|f| f.key.as_str()).collect();
103 let mut saved = 0;
104
105 for fact in &facts {
106 if !existing_keys.contains(&fact.key.as_str()) {
107 let _ = memory.remember(fact.clone());
108 saved += 1;
109 }
110 }
111
112 if saved > 0 {
113 tracing::info!(
114 "Distiller: extracted {} facts ({} new) from task: {}",
115 facts.len(),
116 saved,
117 &task_description[..task_description.len().min(60)]
118 );
119 }
120 }
121}
122
/// Push the language implied by `path`'s suffix onto `hints`.
///
/// The comparison is case-insensitive. The first table entry whose suffix
/// matches wins, and paths with an unknown suffix add nothing.
fn detect_languages(path: &str, hints: &mut Vec<String>) {
    const BY_SUFFIX: &[(&str, &str)] = &[
        (".rs", "Rust"),
        (".ts", "TypeScript"),
        (".tsx", "TypeScript/React"),
        (".py", "Python"),
        (".go", "Go"),
        (".js", "JavaScript"),
        (".jsx", "JavaScript/React"),
        (".java", "Java"),
        (".rb", "Ruby"),
        (".css", "CSS"),
        (".html", "HTML"),
        (".sql", "SQL"),
        (".tf", "Terraform"),
        (".yml", "YAML"),
        (".yaml", "YAML"),
        (".toml", "TOML"),
        (".json", "JSON"),
        (".md", "Markdown"),
        (".sh", "Shell"),
    ];
    let lowered = path.to_lowercase();
    let matched = BY_SUFFIX
        .iter()
        .find(|(suffix, _)| lowered.ends_with(suffix));
    if let Some((_, language)) = matched {
        hints.push((*language).to_string());
    }
}
155
/// Append every build tool or framework whose marker file name appears
/// anywhere in `content`.
///
/// Matching is case-sensitive substring search. Several markers may match at
/// once; duplicates are left for the caller to remove.
fn detect_frameworks(content: &str, hints: &mut Vec<String>) {
    const MARKERS: &[(&str, &str)] = &[
        ("Cargo.toml", "Rust/Cargo"),
        ("package.json", "Node.js"),
        ("go.mod", "Go modules"),
        ("requirements.txt", "Python/pip"),
        ("pyproject.toml", "Python/poetry"),
        ("Dockerfile", "Docker"),
        ("docker-compose", "Docker Compose"),
        ("Makefile", "Make"),
        ("CMakeLists.txt", "CMake"),
        ("pom.xml", "Java/Maven"),
        ("build.gradle", "Java/Gradle"),
    ];
    hints.extend(
        MARKERS
            .iter()
            .filter(|(marker, _)| content.contains(marker))
            .map(|(_, framework)| framework.to_string()),
    );
}
176
/// Record explicit preference phrases found in `text`.
///
/// Matching is case-insensitive: both `text` and each trigger phrase are
/// lowercased before comparison. Every trigger that occurs contributes its
/// hint, in table order.
fn detect_preferences(text: &str, hints: &mut Vec<String>) {
    const TRIGGERS: &[(&str, &str)] = &[
        ("prefer async", "prefers async/await"),
        ("prefer sync", "prefers synchronous code"),
        ("use tabs", "uses tabs for indentation"),
        ("use spaces", "uses spaces for indentation"),
        (
            "prefer unwrap",
            "prefers .unwrap() over proper error handling",
        ),
        ("prefer anyhow", "prefers anyhow for error handling"),
        ("instead of", "has strong opinions about alternatives"),
        ("don't use", "has explicit dislikes"),
        ("always use", "has explicit preferences"),
        ("I like", "expressed a personal preference"),
        ("I want", "expressed a desire"),
    ];
    let haystack = text.to_lowercase();
    for (trigger, hint) in TRIGGERS.iter() {
        // The trigger must be lowercased too: entries such as "I like" contain
        // uppercase and could never match the lowercased text otherwise.
        if haystack.contains(&trigger.to_lowercase()) {
            hints.push(hint.to_string());
        }
    }
}
201
/// Record project-layout conventions implied by `path`.
///
/// The path is lowercased, and Windows `\` separators are normalized to `/`
/// so that `src\main.rs` is recognized just like `src/main.rs`. Every marker
/// that occurs as a substring contributes its hint, in table order.
fn detect_conventions(path: &str, hints: &mut Vec<String>) {
    const LAYOUT_MARKERS: &[(&str, &str)] = &[
        ("src/main.rs", "Rust binary project structure"),
        ("src/lib.rs", "Rust library project structure"),
        ("src/index.ts", "TypeScript entry point convention"),
        ("src/app.py", "Python app entry point"),
        ("tests/", "has a test directory"),
        ("spec/", "has a spec directory"),
        ("docs/", "maintains documentation"),
        (".github/workflows/", "uses GitHub Actions CI"),
        (".gitignore", "has gitignore"),
    ];
    // All markers use '/', so map Windows separators onto it before matching.
    let normalized = path.to_lowercase().replace('\\', "/");
    for (marker, hint) in LAYOUT_MARKERS.iter() {
        if normalized.contains(&marker.to_lowercase()) {
            hints.push(hint.to_string());
        }
    }
}
221
/// Sort `v` in ascending order and collapse repeated entries, so each
/// distinct string remains exactly once.
fn dedup(v: &mut Vec<String>) {
    // Equal strings are indistinguishable, so an unstable sort gives the
    // same result as a stable one.
    v.sort_unstable();
    v.dedup();
}
226
227fn fact(key: &str, value: &str) -> Fact {
228 Fact {
229 id: uuid::Uuid::new_v4().to_string(),
230 key: key.to_string(),
231 value: value.to_string(),
232 created_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
233 updated_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
234 }
235}
236
/// In-memory similarity index over short texts.
///
/// Each entry stores its text together with a fixed-width embedding vector.
#[derive(Debug, Clone)]
pub struct Embeddings {
    /// Indexed entries as `(text, embedding)` pairs, in insertion order.
    pub vectors: Vec<(String, Vec<f64>)>,
    /// Width of every embedding this index produces.
    dimensions: usize,
}
247
248impl Embeddings {
249 pub const DEFAULT_DIMENSIONS: usize = 512;
250
251 pub fn new() -> Self {
252 Self {
253 vectors: Vec::new(),
254 dimensions: Self::DEFAULT_DIMENSIONS,
255 }
256 }
257
258 pub fn with_dimensions(dimensions: usize) -> Self {
259 Self {
260 vectors: Vec::new(),
261 dimensions: dimensions.max(16),
262 }
263 }
264
265 pub fn embed(&self, text: &str) -> Vec<f64> {
273 embed_with_dimensions(text, self.dimensions)
274 }
275
276 pub fn add(&mut self, text: &str) {
277 let clean = text.trim();
278 if clean.is_empty() {
279 return;
280 }
281 self.vectors.push((clean.to_string(), self.embed(clean)));
282 }
283
284 pub fn add_many<I, S>(&mut self, texts: I)
285 where
286 I: IntoIterator<Item = S>,
287 S: AsRef<str>,
288 {
289 for text in texts {
290 self.add(text.as_ref());
291 }
292 }
293
294 pub fn search(&self, query: &str, k: usize) -> Vec<String> {
296 self.search_scored(query, k)
297 .into_iter()
298 .map(|(_, text)| text)
299 .collect()
300 }
301
302 pub fn search_scored(&self, query: &str, k: usize) -> Vec<(f64, String)> {
303 if k == 0 {
304 return Vec::new();
305 }
306 let q_embed = self.embed(query);
307 let mut scored: Vec<(f64, usize, &str)> = self
308 .vectors
309 .iter()
310 .enumerate()
311 .map(|(idx, (text, emb))| (cosine_sim(&q_embed, emb), idx, text.as_str()))
312 .collect();
313 scored.sort_by(|a, b| {
314 b.0.partial_cmp(&a.0)
315 .unwrap_or(std::cmp::Ordering::Equal)
316 .then(a.1.cmp(&b.1))
317 });
318 scored
319 .into_iter()
320 .take(k)
321 .filter(|(score, _, _)| *score > 0.0)
322 .map(|(score, _, text)| (score, text.to_string()))
323 .collect()
324 }
325
326 pub fn save_to_path(&self, path: impl AsRef<std::path::Path>) -> anyhow::Result<()> {
327 let snapshot = EmbeddingsSnapshot {
328 dimensions: self.dimensions,
329 texts: self.vectors.iter().map(|(text, _)| text.clone()).collect(),
330 };
331 let json = serde_json::to_string_pretty(&snapshot)?;
332 if let Some(parent) = path.as_ref().parent() {
333 std::fs::create_dir_all(parent)?;
334 }
335 std::fs::write(path, json)?;
336 Ok(())
337 }
338
339 pub fn load_from_path(path: impl AsRef<std::path::Path>) -> anyhow::Result<Self> {
340 let json = std::fs::read_to_string(path)?;
341 let snapshot: EmbeddingsSnapshot = serde_json::from_str(&json)?;
342 let mut index = Self::with_dimensions(snapshot.dimensions);
343 index.add_many(snapshot.texts);
344 Ok(index)
345 }
346}
347
348impl Default for Embeddings {
349 fn default() -> Self {
350 Self::new()
351 }
352}
353
354#[derive(serde::Serialize, serde::Deserialize)]
355struct EmbeddingsSnapshot {
356 dimensions: usize,
357 texts: Vec<String>,
358}
359
360fn embed_with_dimensions(text: &str, dimensions: usize) -> Vec<f64> {
361 let mut vector = vec![0.0; dimensions.max(16)];
362 let tokens = tokenize(text);
363 for token in &tokens {
364 add_feature(&mut vector, token, 1.0);
365 }
366 for pair in tokens.windows(2) {
367 add_feature(&mut vector, &format!("{}__{}", pair[0], pair[1]), 1.35);
368 }
369 for value in &mut vector {
370 if *value != 0.0 {
371 *value = value.signum() * value.abs().ln_1p();
372 }
373 }
374 normalize(&mut vector);
375 vector
376}
377
/// Split `text` into tokens: maximal runs of alphanumeric characters.
///
/// Each character is lowercased individually; every other character acts as
/// a separator and is dropped.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_alphanumeric())
        .filter(|piece| !piece.is_empty())
        // Per-char lowercasing (not `str::to_lowercase`), which has no
        // context-sensitive final-sigma rule.
        .map(|piece| piece.chars().flat_map(char::to_lowercase).collect::<String>())
        .collect()
}
393
394fn add_feature(vector: &mut [f64], feature: &str, weight: f64) {
395 let hash = fnv1a64(feature.as_bytes());
396 let idx = (hash as usize) % vector.len();
397 let sign = if hash & (1 << 63) == 0 { 1.0 } else { -1.0 };
398 vector[idx] += sign * weight;
399}
400
/// 64-bit FNV-1a hash of `bytes`.
///
/// Uses offset basis `0xcbf29ce484222325` and prime `0x100000001b3`. It is a
/// pure function, so results are identical across runs and platforms.
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |acc, &b| (acc ^ u64::from(b)).wrapping_mul(PRIME))
}
409
/// Scale `vector` in place to unit Euclidean length.
///
/// A vector whose length is not strictly positive (e.g. all zeros) is left
/// unchanged.
fn normalize(vector: &mut [f64]) {
    let length: f64 = vector.iter().map(|x| x * x).sum::<f64>().sqrt();
    if length > 0.0 {
        vector.iter_mut().for_each(|x| *x /= length);
    }
}
418
/// Cosine similarity of `a` and `b` over their common prefix, i.e. the first
/// `min(a.len(), b.len())` elements.
///
/// Returns 0.0 when that prefix is empty or when either prefix has zero
/// magnitude.
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let shared = a.len().min(b.len());
    if shared == 0 {
        return 0.0;
    }
    let (a, b) = (&a[..shared], &b[..shared]);
    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let magnitude = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        0.0
    } else {
        dot / (mag_a * mag_b)
    }
}
433
434pub struct ReExecuter {
439 engine: Arc<Engine>,
440}
441
442impl ReExecuter {
443 pub fn new(engine: Arc<Engine>) -> Self {
444 Self { engine }
445 }
446
447 pub async fn re_execute(
450 &self,
451 transcript: &crate::runtime::recorder::Transcript,
452 ) -> anyhow::Result<crate::event::OutcomeSummary> {
453 let (tx, _rx) = mpsc::unbounded_channel::<Event>();
454 let task = Task {
455 description: transcript.inputs.task.clone(),
456 context: vec![],
457 };
458 self.engine.drive(task, tx).await
459 }
460}
461
/// Namespace for the OAuth 2.0 device-authorization helpers
/// (`start_device_flow` and `poll_token`).
pub struct OAuthFlow;
465
466impl OAuthFlow {
467 pub async fn start_device_flow(
470 device_endpoint: &str,
471 token_endpoint_hint: &str, client_id: &str,
473 scope: &str,
474 ) -> anyhow::Result<(String, String, String)> {
475 let _ = token_endpoint_hint;
476 let client = reqwest::Client::new();
477 let resp: serde_json::Value = client
478 .post(device_endpoint)
479 .form(&[("client_id", client_id), ("scope", scope)])
480 .send()
481 .await?
482 .json()
483 .await?;
484
485 let verification_uri = resp["verification_uri"]
486 .as_str()
487 .or_else(|| resp["verification_url"].as_str())
488 .unwrap_or("")
489 .to_string();
490 let user_code = resp["user_code"].as_str().unwrap_or("").to_string();
491 let device_code = resp["device_code"].as_str().unwrap_or("").to_string();
492
493 if device_code.is_empty() {
494 anyhow::bail!("Device flow start failed — provider response: {}", resp);
495 }
496
497 Ok((verification_uri, user_code, device_code))
498 }
499
500 pub async fn poll_token(
502 token_endpoint: &str,
503 client_id: &str,
504 device_code: &str,
505 timeout_secs: u64,
506 ) -> anyhow::Result<String> {
507 let client = reqwest::Client::new();
508 let start = std::time::Instant::now();
509
510 loop {
511 if start.elapsed().as_secs() > timeout_secs {
512 anyhow::bail!("OAuth timed out after {}s", timeout_secs);
513 }
514
515 let resp: serde_json::Value = client
516 .post(token_endpoint)
517 .form(&[
518 ("client_id", client_id),
519 ("device_code", device_code),
520 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
521 ])
522 .send()
523 .await?
524 .json()
525 .await?;
526
527 if let Some(token) = resp["access_token"].as_str() {
528 return Ok(token.to_string());
529 }
530
531 match resp["error"].as_str() {
532 Some("authorization_pending") | Some("slow_down") => {
533 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
534 continue;
535 }
536 Some(e) => anyhow::bail!("OAuth error: {}", e),
537 None => {
538 tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
539 }
540 }
541 }
542 }
543}
544
/// Download URL of the latest IBM Plex Mono release archive (zip).
pub const IBM_PLEX_MONO_URL: &str =
    "https://github.com/IBM/plex/releases/latest/download/IBM-Plex-Mono.zip";
552
/// Return user-facing instructions for installing the IBM Plex Mono font
/// (apt on Linux, Homebrew on macOS, manual download on Windows) and
/// selecting it as the terminal font.
pub fn ibm_plex_install_instructions() -> String {
    r#"IBM Plex Mono — recommended font for Sparrow TUI.

Install:
 Linux: sudo apt install fonts-ibm-plex
 macOS: brew install font-ibm-plex
 Windows: Download from https://github.com/IBM/plex/releases

Then update your terminal to use "IBM Plex Mono" as the font.
"#
    .to_string()
}
565
566pub struct ChatSession {
571 engine: Arc<Engine>,
572 history: Vec<crate::provider::Msg>,
573 running: bool,
574}
575
576impl ChatSession {
577 pub fn new(engine: Arc<Engine>) -> Self {
578 Self {
579 engine,
580 history: Vec::new(),
581 running: true,
582 }
583 }
584
585 pub async fn run_interactive(&mut self) -> anyhow::Result<()> {
586 use std::io::{self, Write};
587
588 println!("═══ Sparrow Chat ═══");
589 println!("Type your message and press Enter. Type /exit to quit.");
590 println!();
591
592 while self.running {
593 print!("◆ you › ");
594 io::stdout().flush()?;
595
596 let mut input = String::new();
597 io::stdin().read_line(&mut input)?;
598 let input = input.trim().to_string();
599
600 if input.is_empty() {
601 continue;
602 }
603 if input == "/exit" || input == "/quit" {
604 break;
605 }
606
607 self.history.push(crate::provider::Msg {
608 role: "user".into(),
609 content: vec![crate::provider::ContentBlock::Text {
610 text: input.clone(),
611 }],
612 });
613
614 let (tx, mut rx) = mpsc::unbounded_channel::<Event>();
615 let task = Task {
616 description: input.clone(),
617 context: self.history.clone(),
618 };
619
620 let engine = self.engine.clone();
621 let handle = tokio::spawn(async move { engine.drive(task, tx).await });
622
623 while let Some(event) = rx.recv().await {
624 match &event {
625 Event::ThinkingDelta { text, .. } => {
626 print!("{}", text);
627 io::stdout().flush()?;
628 }
629 Event::RunFinished { outcome, .. } => {
630 println!(
631 "\n── {} | ${:.4} {}──",
632 outcome.status,
633 outcome.cost_usd,
634 crate::cost::format_comparison_oneliner(
635 outcome.cost_usd,
636 &outcome.tokens
637 )
638 );
639 }
640 Event::Error { message, .. } => {
641 eprintln!("\nError: {}", message);
642 }
643 _ => {}
644 }
645 }
646
647 match handle.await? {
648 Ok(outcome) => {
649 self.history.push(crate::provider::Msg {
650 role: "assistant".into(),
651 content: vec![crate::provider::ContentBlock::Text {
652 text: format!("[{}]", outcome.status),
653 }],
654 });
655 }
656 Err(e) => {
657 eprintln!("Chat error: {}", e);
658 }
659 }
660 println!();
661 }
662
663 Ok(())
664 }
665}
666
667#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
672pub struct PipelineConfig {
673 pub name: String,
674 pub steps: Vec<PipelineStep>,
675 pub max_reworks: u32,
676}
677
678#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
679pub struct PipelineStep {
680 pub role: String,
681 pub model_preference: Option<String>,
682 pub prompt_override: Option<String>,
683 pub depends_on: Vec<String>,
684}
685
686impl PipelineConfig {
687 pub fn default_pipeline() -> Self {
688 Self {
689 name: "planner-coder-verifier".into(),
690 steps: vec![
691 PipelineStep {
692 role: "planner".into(),
693 model_preference: None,
694 prompt_override: None,
695 depends_on: vec![],
696 },
697 PipelineStep {
698 role: "coder".into(),
699 model_preference: None,
700 prompt_override: None,
701 depends_on: vec!["planner".into()],
702 },
703 PipelineStep {
704 role: "verifier".into(),
705 model_preference: None,
706 prompt_override: None,
707 depends_on: vec!["coder".into()],
708 },
709 ],
710 max_reworks: 3,
711 }
712 }
713
714 pub fn validate(&self) -> anyhow::Result<()> {
715 if self.steps.is_empty() {
716 anyhow::bail!("Pipeline must have at least one step");
717 }
718 for step in &self.steps {
719 for dep in &step.depends_on {
720 if !self.steps.iter().any(|s| s.role == *dep) {
721 anyhow::bail!("Step '{}' depends on unknown role '{}'", step.role, dep);
722 }
723 }
724 }
725 Ok(())
726 }
727
728 pub fn from_toml(content: &str) -> anyhow::Result<Self> {
729 Ok(toml::from_str(content)?)
730 }
731
732 pub fn to_toml(&self) -> anyhow::Result<String> {
733 Ok(toml::to_string_pretty(self)?)
734 }
735}
736
737pub struct Profile {
742 pub name: String,
743 pub config_dir: std::path::PathBuf,
744 pub state_dir: std::path::PathBuf,
745 pub config: crate::config::Config,
746 pub memory: Arc<dyn Memory>,
747}
748
749impl Profile {
750 pub fn load(name: &str) -> anyhow::Result<Self> {
751 let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
752 let base_state = dirs::state_dir().unwrap_or_default().join("sparrow");
753
754 let config_dir = base_config.join("profiles").join(name);
755 let state_dir = base_state.join("profiles").join(name);
756
757 std::fs::create_dir_all(&config_dir)?;
758 std::fs::create_dir_all(&state_dir)?;
759
760 let config = if config_dir.join("config.toml").exists() {
761 let content = std::fs::read_to_string(config_dir.join("config.toml"))?;
762 toml::from_str(&content)?
763 } else {
764 let default = base_config.join("config.toml");
766 if default.exists() {
767 let content = std::fs::read_to_string(&default)?;
768 toml::from_str(&content)?
769 } else {
770 crate::config::Config {
771 defaults: Default::default(),
772 routing: Default::default(),
773 budget: Default::default(),
774 providers: Default::default(),
775 surfaces: Default::default(),
776 skills: Default::default(),
777 permissions: Default::default(),
778 hooks: Default::default(),
779 theme: "captain".into(),
780 config_dir: config_dir.clone(),
781 state_dir: state_dir.clone(),
782 forced_model: None,
783 }
784 }
785 };
786
787 let memory: Arc<dyn Memory> = Arc::new(crate::memory::SqliteMemory::open(
788 &state_dir.join("profile.db"),
789 )?);
790
791 Ok(Self {
792 name: name.to_string(),
793 config_dir,
794 state_dir,
795 config,
796 memory,
797 })
798 }
799
800 pub fn create(name: &str) -> anyhow::Result<()> {
801 let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
802 let config_dir = base_config.join("profiles").join(name);
803 std::fs::create_dir_all(&config_dir)?;
804
805 let default = base_config.join("config.toml");
807 if default.exists() {
808 std::fs::copy(&default, config_dir.join("config.toml"))?;
809 }
810
811 let base_state = dirs::state_dir().unwrap_or_default().join("sparrow");
812 std::fs::create_dir_all(base_state.join("profiles").join(name))?;
813
814 Ok(())
815 }
816
817 pub fn list() -> Vec<String> {
818 let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
819 let profiles_dir = base_config.join("profiles");
820 let mut names = Vec::new();
821 if let Ok(entries) = std::fs::read_dir(&profiles_dir) {
822 for entry in entries.flatten() {
823 if entry.path().is_dir() {
824 if let Some(name) = entry.file_name().to_str() {
825 names.push(name.to_string());
826 }
827 }
828 }
829 }
830 names.sort();
831 names
832 }
833}