1use crate::engine::{Engine, Task};
2use crate::event::Event;
3use crate::memory::{Fact, Memory};
4use std::sync::Arc;
5use tokio::sync::mpsc;
6
7pub struct Distiller;
13
14impl Distiller {
15 pub async fn distill(memory: &Arc<dyn Memory>, events: &[Event], task_description: &str) {
18 let mut facts = Vec::new();
19
20 let mut lang_hints: Vec<String> = Vec::new();
22 let mut framework_hints: Vec<String> = Vec::new();
23 let mut tool_usage: std::collections::HashMap<String, u32> =
24 std::collections::HashMap::new();
25 let mut style_hints: Vec<String> = Vec::new();
26 let mut pref_hints: Vec<String> = Vec::new();
27 let mut convention_hints: Vec<String> = Vec::new();
28 let mut directive_hints: Vec<String> = Vec::new();
29
30 for event in events {
31 match event {
32 Event::ToolUseProposed { name, args, .. } => {
33 *tool_usage.entry(name.clone()).or_insert(0) += 1;
35
36 if let Some(path) = args.get("path").and_then(|v| v.as_str()) {
38 detect_languages(path, &mut lang_hints);
39 detect_conventions(path, &mut convention_hints);
40 }
41 if let Some(content) = args.get("content").and_then(|v| v.as_str()) {
43 detect_frameworks(content, &mut framework_hints);
44 }
45 }
46 Event::ThinkingDelta { text, .. } => {
47 if text.contains("refactor") {
49 style_hints.push("refactoring-oriented".to_string());
50 }
51 if text.contains("test") || text.contains("TDD") {
52 style_hints.push("test-driven".to_string());
53 }
54 if text.contains("async") || text.contains("await") {
55 style_hints.push("async-first".to_string());
56 }
57 detect_preferences(text, &mut pref_hints);
59 }
60 Event::Message { text, role, .. } if role == "user" => {
61 detect_preferences(text, &mut pref_hints);
63 detect_directives(text, &mut directive_hints);
67 }
68 _ => {}
69 }
70 }
71
72 dedup(&mut lang_hints);
74 dedup(&mut framework_hints);
75 dedup(&mut style_hints);
76 dedup(&mut pref_hints);
77 dedup(&mut convention_hints);
78 dedup(&mut directive_hints);
79
80 for lang in &lang_hints {
82 facts.push(fact("user:language", lang));
83 }
84 for fw in &framework_hints {
85 facts.push(fact("user:framework", fw));
86 }
87 for style in &style_hints {
88 facts.push(fact("user:style", style));
89 }
90 for pref in &pref_hints {
91 facts.push(fact("user:preference", pref));
92 }
93 for conv in &convention_hints {
94 facts.push(fact("project:convention", conv));
95 }
96 for d in &directive_hints {
99 facts.push(fact(&format!("user:directive:{}", short_hash(d)), d));
100 }
101 for (tool, count) in &tool_usage {
103 if *count >= 3 {
104 facts.push(fact(
105 "user:frequent_tool",
106 &format!("uses {} frequently ({}x this session)", tool, count),
107 ));
108 }
109 }
110
111 let existing = memory.all_facts();
120 let existing_pairs: std::collections::HashSet<(String, String)> = existing
121 .iter()
122 .map(|f| (f.key.clone(), f.value.clone()))
123 .collect();
124 let mut saved = 0;
125
126 for fact in &facts {
127 if !existing_pairs.contains(&(fact.key.clone(), fact.value.clone())) {
128 let _ = memory.remember(fact.clone());
129 saved += 1;
130 }
131 }
132
133 if saved > 0 {
134 tracing::info!(
135 "Distiller: extracted {} facts ({} new) from task: {}",
136 facts.len(),
137 saved,
138 &task_description[..task_description.len().min(60)]
139 );
140 }
141 }
142}
143
/// Appends the language implied by `path`'s file suffix (case-insensitive) to `hints`.
///
/// The table is scanned in order and only the first matching suffix is used, so at
/// most one hint is pushed per call. Paths with no known suffix add nothing.
fn detect_languages(path: &str, hints: &mut Vec<String>) {
    const SUFFIX_TO_LANGUAGE: &[(&str, &str)] = &[
        (".rs", "Rust"),
        (".ts", "TypeScript"),
        (".tsx", "TypeScript/React"),
        (".py", "Python"),
        (".go", "Go"),
        (".js", "JavaScript"),
        (".jsx", "JavaScript/React"),
        (".java", "Java"),
        (".rb", "Ruby"),
        (".css", "CSS"),
        (".html", "HTML"),
        (".sql", "SQL"),
        (".tf", "Terraform"),
        (".yml", "YAML"),
        (".yaml", "YAML"),
        (".toml", "TOML"),
        (".json", "JSON"),
        (".md", "Markdown"),
        (".sh", "Shell"),
    ];

    let lowered = path.to_lowercase();
    let first_match = SUFFIX_TO_LANGUAGE
        .iter()
        .find(|&&(suffix, _)| lowered.ends_with(suffix));
    if let Some(&(_, language)) = first_match {
        hints.push(language.to_string());
    }
}
176
/// Appends a hint for every build-tool / framework file name mentioned anywhere in
/// `content` (case-sensitive substring match), in table order.
fn detect_frameworks(content: &str, hints: &mut Vec<String>) {
    const FILE_MARKERS: &[(&str, &str)] = &[
        ("Cargo.toml", "Rust/Cargo"),
        ("package.json", "Node.js"),
        ("go.mod", "Go modules"),
        ("requirements.txt", "Python/pip"),
        ("pyproject.toml", "Python/poetry"),
        ("Dockerfile", "Docker"),
        ("docker-compose", "Docker Compose"),
        ("Makefile", "Make"),
        ("CMakeLists.txt", "CMake"),
        ("pom.xml", "Java/Maven"),
        ("build.gradle", "Java/Gradle"),
    ];

    hints.extend(
        FILE_MARKERS
            .iter()
            .filter_map(|&(needle, framework)| content.contains(needle).then(|| framework.to_owned())),
    );
}
197
/// Appends a hint for every preference phrase found in `text` (case-insensitive).
///
/// Matching is done against a lowercased copy of `text`, so every pattern in the table
/// MUST be lowercase. Previously `"I like"` / `"I want"` were capitalized and could
/// never match.
fn detect_preferences(text: &str, hints: &mut Vec<String>) {
    const PREF_PATTERNS: &[(&str, &str)] = &[
        ("prefer async", "prefers async/await"),
        ("prefer sync", "prefers synchronous code"),
        ("use tabs", "uses tabs for indentation"),
        ("use spaces", "uses spaces for indentation"),
        (
            "prefer unwrap",
            "prefers .unwrap() over proper error handling",
        ),
        ("prefer anyhow", "prefers anyhow for error handling"),
        ("instead of", "has strong opinions about alternatives"),
        ("don't use", "has explicit dislikes"),
        ("always use", "has explicit preferences"),
        ("i like", "expressed a personal preference"),
        ("i want", "expressed a desire"),
    ];
    let lower = text.to_lowercase();
    hints.extend(
        PREF_PATTERNS
            .iter()
            .filter(|&&(pattern, _)| lower.contains(pattern))
            .map(|&(_, hint)| hint.to_string()),
    );
}
222
/// Extracts durable, user-stated instructions ("remember that …", "from now on …",
/// "désormais …", English or French) from `text` and appends them to `hints`.
///
/// - Every line containing a marker is captured on its own: trimmed, capped at 220
///   characters, and only kept if it is at least 4 bytes long.
/// - A marker only counts when it starts at a word boundary, so e.g. "whenever you can"
///   does not trigger the `"never "` marker, nor "denote that" the `"note that"` one.
/// - If no line of this `text` matched but the text as a whole does, the whole trimmed
///   text is captured instead (same cap).
fn detect_directives(text: &str, hints: &mut Vec<String>) {
    const MARKERS: &[&str] = &[
        "remember that",
        "remember to",
        "don't forget",
        "keep in mind",
        "note that",
        "from now on",
        "always ",
        "never ",
        "my name is",
        "i prefer",
        "i want you to",
        "make sure to",
        "going forward",
        "retiens",
        "souviens-toi",
        "souviens toi",
        "n'oublie pas",
        "note que",
        "désormais",
        "dorénavant",
        "toujours ",
        "jamais ",
        "je m'appelle",
        "je préfère",
        "je veux que",
        "à partir de maintenant",
        "rappelle-toi",
    ];
    const MAX_CHARS: usize = 220;

    // True if any marker occurs in the (already lowercased) haystack at a position that
    // is not preceded by an alphanumeric character.
    fn has_marker(haystack: &str) -> bool {
        MARKERS.iter().any(|marker| {
            haystack.match_indices(marker).any(|(idx, _)| {
                // `idx` comes from `match_indices`, so it is always a char boundary.
                !matches!(haystack[..idx].chars().next_back(), Some(c) if c.is_alphanumeric())
            })
        })
    }

    // `hints` is a caller-owned accumulator that may already hold directives from
    // earlier messages, so "nothing matched" must be judged relative to this call.
    let before = hints.len();
    for line in text.lines() {
        if has_marker(&line.to_lowercase()) {
            let cleaned = line.trim();
            if cleaned.len() >= 4 {
                hints.push(cleaned.chars().take(MAX_CHARS).collect());
            }
        }
    }
    if hints.len() == before && has_marker(&text.to_lowercase()) {
        let capped: String = text.trim().chars().take(MAX_CHARS).collect();
        if capped.len() >= 4 {
            hints.push(capped);
        }
    }
}
279
/// Returns an 8-hex-digit fingerprint of `s`, ignoring surrounding whitespace and case.
///
/// The value becomes part of persisted memory keys (`user:directive:<hash>`), so it must
/// be stable across builds and toolchains. `std`'s `DefaultHasher` explicitly does not
/// guarantee that, so this uses 32-bit FNV-1a over the normalized UTF-8 bytes.
///
/// NOTE(review): keys written by the previous `DefaultHasher`-based version will not
/// match; such directives get stored once more under their new key.
fn short_hash(s: &str) -> String {
    const FNV32_OFFSET_BASIS: u32 = 0x811c_9dc5;
    const FNV32_PRIME: u32 = 0x0100_0193;

    let normalized = s.trim().to_lowercase();
    let hash = normalized
        .bytes()
        .fold(FNV32_OFFSET_BASIS, |acc, byte| {
            (acc ^ u32::from(byte)).wrapping_mul(FNV32_PRIME)
        });
    format!("{:08x}", hash)
}
287
/// Appends a hint for every project-layout convention visible in `path`.
///
/// Matching is a case-insensitive substring test. Backslashes are normalized to `/`
/// first so Windows-style paths (`src\main.rs`) match the same patterns. Patterns must
/// therefore be lowercase and `/`-separated.
fn detect_conventions(path: &str, hints: &mut Vec<String>) {
    const CONV_PATTERNS: &[(&str, &str)] = &[
        ("src/main.rs", "Rust binary project structure"),
        ("src/lib.rs", "Rust library project structure"),
        ("src/index.ts", "TypeScript entry point convention"),
        ("src/app.py", "Python app entry point"),
        ("tests/", "has a test directory"),
        ("spec/", "has a spec directory"),
        ("docs/", "maintains documentation"),
        (".github/workflows/", "uses GitHub Actions CI"),
        (".gitignore", "has gitignore"),
    ];
    let normalized = path.to_lowercase().replace('\\', "/");
    hints.extend(
        CONV_PATTERNS
            .iter()
            .filter(|&&(pattern, _)| normalized.contains(pattern))
            .map(|&(_, hint)| hint.to_string()),
    );
}
307
/// Sorts `v` and removes duplicate entries in place.
///
/// Equal `String`s are indistinguishable, so an unstable sort yields the same result
/// as a stable one.
fn dedup(v: &mut Vec<String>) {
    v.sort_unstable();
    v.dedup();
}
312
313fn fact(key: &str, value: &str) -> Fact {
314 Fact {
315 id: uuid::Uuid::new_v4().to_string(),
316 key: key.to_string(),
317 value: value.to_string(),
318 created_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
319 updated_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
320 }
321}
322
#[cfg(test)]
mod distiller_tests {
    use super::*;

    /// Runs `detect_directives` on `text` with a fresh accumulator and returns it.
    fn directives_of(text: &str) -> Vec<String> {
        let mut found = Vec::new();
        detect_directives(text, &mut found);
        found
    }

    #[test]
    fn detect_directives_captures_english_durable_instructions() {
        let remembered = directives_of("Remember that I want reports in ./artifacts");
        assert!(remembered
            .iter()
            .any(|d| d.contains("reports in ./artifacts")));

        let standing_order = directives_of("From now on, always run the tests first");
        assert_eq!(standing_order.len(), 1);
    }

    #[test]
    fn detect_directives_captures_french_durable_instructions() {
        let named = directives_of("Retiens que je m'appelle Abdou");
        assert!(named.iter().any(|d| d.contains("Abdou")));

        let standing_order = directives_of("Désormais, range les livrables dans ./out");
        assert_eq!(standing_order.len(), 1);
    }

    #[test]
    fn detect_directives_ignores_ordinary_text() {
        assert!(directives_of("Can you fix the bug in main.rs please?").is_empty());
    }

    #[test]
    fn short_hash_is_stable_and_distinct() {
        let plain = short_hash("use ./artifacts");
        let noisy = short_hash(" USE ./artifacts ");
        assert_eq!(plain, noisy);
        assert_ne!(short_hash("prefer async"), short_hash("prefer sync"));
        assert_eq!(short_hash("x").len(), 8);
    }
}
368
369#[derive(Debug, Clone)]
374pub struct Embeddings {
375 pub vectors: Vec<(String, Vec<f64>)>,
377 dimensions: usize,
378}
379
380impl Embeddings {
381 pub const DEFAULT_DIMENSIONS: usize = 512;
382
383 pub fn new() -> Self {
384 Self {
385 vectors: Vec::new(),
386 dimensions: Self::DEFAULT_DIMENSIONS,
387 }
388 }
389
390 pub fn with_dimensions(dimensions: usize) -> Self {
391 Self {
392 vectors: Vec::new(),
393 dimensions: dimensions.max(16),
394 }
395 }
396
397 pub fn embed(&self, text: &str) -> Vec<f64> {
405 embed_with_dimensions(text, self.dimensions)
406 }
407
408 pub fn add(&mut self, text: &str) {
409 let clean = text.trim();
410 if clean.is_empty() {
411 return;
412 }
413 self.vectors.push((clean.to_string(), self.embed(clean)));
414 }
415
416 pub fn add_many<I, S>(&mut self, texts: I)
417 where
418 I: IntoIterator<Item = S>,
419 S: AsRef<str>,
420 {
421 for text in texts {
422 self.add(text.as_ref());
423 }
424 }
425
426 pub fn search(&self, query: &str, k: usize) -> Vec<String> {
428 self.search_scored(query, k)
429 .into_iter()
430 .map(|(_, text)| text)
431 .collect()
432 }
433
434 pub fn search_scored(&self, query: &str, k: usize) -> Vec<(f64, String)> {
435 if k == 0 {
436 return Vec::new();
437 }
438 let q_embed = self.embed(query);
439 let mut scored: Vec<(f64, usize, &str)> = self
440 .vectors
441 .iter()
442 .enumerate()
443 .map(|(idx, (text, emb))| (cosine_sim(&q_embed, emb), idx, text.as_str()))
444 .collect();
445 scored.sort_by(|a, b| {
446 b.0.partial_cmp(&a.0)
447 .unwrap_or(std::cmp::Ordering::Equal)
448 .then(a.1.cmp(&b.1))
449 });
450 scored
451 .into_iter()
452 .take(k)
453 .filter(|(score, _, _)| *score > 0.0)
454 .map(|(score, _, text)| (score, text.to_string()))
455 .collect()
456 }
457
458 pub fn save_to_path(&self, path: impl AsRef<std::path::Path>) -> anyhow::Result<()> {
459 let snapshot = EmbeddingsSnapshot {
460 dimensions: self.dimensions,
461 texts: self.vectors.iter().map(|(text, _)| text.clone()).collect(),
462 };
463 let json = serde_json::to_string_pretty(&snapshot)?;
464 if let Some(parent) = path.as_ref().parent() {
465 std::fs::create_dir_all(parent)?;
466 }
467 std::fs::write(path, json)?;
468 Ok(())
469 }
470
471 pub fn load_from_path(path: impl AsRef<std::path::Path>) -> anyhow::Result<Self> {
472 let json = std::fs::read_to_string(path)?;
473 let snapshot: EmbeddingsSnapshot = serde_json::from_str(&json)?;
474 let mut index = Self::with_dimensions(snapshot.dimensions);
475 index.add_many(snapshot.texts);
476 Ok(index)
477 }
478}
479
480impl Default for Embeddings {
481 fn default() -> Self {
482 Self::new()
483 }
484}
485
486#[derive(serde::Serialize, serde::Deserialize)]
487struct EmbeddingsSnapshot {
488 dimensions: usize,
489 texts: Vec<String>,
490}
491
492fn embed_with_dimensions(text: &str, dimensions: usize) -> Vec<f64> {
493 let mut vector = vec![0.0; dimensions.max(16)];
494 let tokens = tokenize(text);
495 for token in &tokens {
496 add_feature(&mut vector, token, 1.0);
497 }
498 for pair in tokens.windows(2) {
499 add_feature(&mut vector, &format!("{}__{}", pair[0], pair[1]), 1.35);
500 }
501 for value in &mut vector {
502 if *value != 0.0 {
503 *value = value.signum() * value.abs().ln_1p();
504 }
505 }
506 normalize(&mut vector);
507 vector
508}
509
/// Splits `text` into maximal runs of alphanumeric characters, lowercasing each one.
///
/// Lowercasing is done per `char` (not with `str::to_lowercase`), so context-sensitive
/// rules such as Greek final sigma are deliberately not applied.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_alphanumeric())
        .filter(|word| !word.is_empty())
        .map(|word| word.chars().flat_map(char::to_lowercase).collect())
        .collect()
}
525
526fn add_feature(vector: &mut [f64], feature: &str, weight: f64) {
527 let hash = fnv1a64(feature.as_bytes());
528 let idx = (hash as usize) % vector.len();
529 let sign = if hash & (1 << 63) == 0 { 1.0 } else { -1.0 };
530 vector[idx] += sign * weight;
531}
532
/// 64-bit FNV-1a hash of `bytes`: deterministic across runs and platforms.
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |hash, &byte| (hash ^ u64::from(byte)).wrapping_mul(PRIME))
}
541
/// Scales `vector` in place to unit Euclidean length; leaves it untouched when its
/// length is zero (or not a positive number).
fn normalize(vector: &mut [f64]) {
    let magnitude = vector
        .iter()
        .map(|component| component * component)
        .sum::<f64>()
        .sqrt();
    if magnitude > 0.0 {
        vector
            .iter_mut()
            .for_each(|component| *component /= magnitude);
    }
}
550
/// Cosine similarity of `a` and `b`, computed over their common prefix only.
///
/// Returns 0.0 when the common prefix is empty or either prefix has zero magnitude.
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let shared = a.len().min(b.len());
    if shared == 0 {
        return 0.0;
    }
    let (a, b) = (&a[..shared], &b[..shared]);

    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let mag_a = a.iter().map(|x| x * x).sum::<f64>().sqrt();
    let mag_b = b.iter().map(|y| y * y).sum::<f64>().sqrt();
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    dot / (mag_a * mag_b)
}
565
566pub struct ReExecuter {
571 engine: Arc<Engine>,
572}
573
574impl ReExecuter {
575 pub fn new(engine: Arc<Engine>) -> Self {
576 Self { engine }
577 }
578
579 pub async fn re_execute(
582 &self,
583 transcript: &crate::runtime::recorder::Transcript,
584 ) -> anyhow::Result<crate::event::OutcomeSummary> {
585 let (tx, _rx) = mpsc::unbounded_channel::<Event>();
586 let task = Task {
587 description: transcript.inputs.task.clone(),
588 context: vec![],
589 };
590 self.engine.drive(task, tx).await
591 }
592}
593
594pub struct OAuthFlow;
597
598impl OAuthFlow {
599 pub async fn start_device_flow(
602 device_endpoint: &str,
603 token_endpoint_hint: &str, client_id: &str,
605 scope: &str,
606 ) -> anyhow::Result<(String, String, String)> {
607 let _ = token_endpoint_hint;
608 let client = reqwest::Client::new();
609 let resp: serde_json::Value = client
610 .post(device_endpoint)
611 .form(&[("client_id", client_id), ("scope", scope)])
612 .send()
613 .await?
614 .json()
615 .await?;
616
617 let verification_uri = resp["verification_uri"]
618 .as_str()
619 .or_else(|| resp["verification_url"].as_str())
620 .unwrap_or("")
621 .to_string();
622 let user_code = resp["user_code"].as_str().unwrap_or("").to_string();
623 let device_code = resp["device_code"].as_str().unwrap_or("").to_string();
624
625 if device_code.is_empty() {
626 anyhow::bail!("Device flow start failed — provider response: {}", resp);
627 }
628
629 Ok((verification_uri, user_code, device_code))
630 }
631
632 pub async fn poll_token(
634 token_endpoint: &str,
635 client_id: &str,
636 device_code: &str,
637 timeout_secs: u64,
638 ) -> anyhow::Result<String> {
639 let client = reqwest::Client::new();
640 let start = std::time::Instant::now();
641
642 loop {
643 if start.elapsed().as_secs() > timeout_secs {
644 anyhow::bail!("OAuth timed out after {}s", timeout_secs);
645 }
646
647 let resp: serde_json::Value = client
648 .post(token_endpoint)
649 .form(&[
650 ("client_id", client_id),
651 ("device_code", device_code),
652 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
653 ])
654 .send()
655 .await?
656 .json()
657 .await?;
658
659 if let Some(token) = resp["access_token"].as_str() {
660 return Ok(token.to_string());
661 }
662
663 match resp["error"].as_str() {
664 Some("authorization_pending") | Some("slow_down") => {
665 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
666 continue;
667 }
668 Some(e) => anyhow::bail!("OAuth error: {}", e),
669 None => {
670 tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
671 }
672 }
673 }
674 }
675}
676
/// Download URL for the IBM Plex Mono archive from the latest GitHub release of IBM/plex.
///
/// NOTE(review): a `releases/latest/download/<asset>` URL only resolves if the latest
/// release publishes an asset with exactly this file name — confirm.
pub const IBM_PLEX_MONO_URL: &str =
    "https://github.com/IBM/plex/releases/latest/download/IBM-Plex-Mono.zip";

/// Returns human-readable instructions for installing IBM Plex Mono (the font
/// recommended for the Sparrow TUI) on Linux, macOS and Windows.
pub fn ibm_plex_install_instructions() -> String {
    r#"IBM Plex Mono — recommended font for Sparrow TUI.

Install:
  Linux: sudo apt install fonts-ibm-plex
  macOS: brew install font-ibm-plex
  Windows: Download from https://github.com/IBM/plex/releases

Then update your terminal to use "IBM Plex Mono" as the font.
"#
    .to_string()
}
697
698pub struct ChatSession {
703 engine: Arc<Engine>,
704 history: Vec<crate::provider::Msg>,
705 running: bool,
706}
707
708impl ChatSession {
709 pub fn new(engine: Arc<Engine>) -> Self {
710 Self {
711 engine,
712 history: Vec::new(),
713 running: true,
714 }
715 }
716
717 pub async fn run_interactive(&mut self) -> anyhow::Result<()> {
718 use std::io::{self, Write};
719
720 println!("═══ Sparrow Chat ═══");
721 println!("Type your message and press Enter. Type /exit to quit.");
722 println!();
723
724 while self.running {
725 print!("◆ you › ");
726 io::stdout().flush()?;
727
728 let mut input = String::new();
729 io::stdin().read_line(&mut input)?;
730 let input = input.trim().to_string();
731
732 if input.is_empty() {
733 continue;
734 }
735 if input == "/exit" || input == "/quit" {
736 break;
737 }
738
739 self.history.push(crate::provider::Msg {
740 role: "user".into(),
741 content: vec![crate::provider::ContentBlock::Text {
742 text: input.clone(),
743 }],
744 });
745
746 let (tx, mut rx) = mpsc::unbounded_channel::<Event>();
747 let task = Task {
748 description: input.clone(),
749 context: self.history.clone(),
750 };
751
752 let engine = self.engine.clone();
753 let handle = tokio::spawn(async move { engine.drive(task, tx).await });
754
755 while let Some(event) = rx.recv().await {
756 match &event {
757 Event::ThinkingDelta { text, .. } => {
758 print!("{}", text);
759 io::stdout().flush()?;
760 }
761 Event::RunFinished { outcome, .. } => {
762 println!(
763 "\n── {} | ${:.4} {}──",
764 outcome.status,
765 outcome.cost_usd,
766 crate::cost::format_comparison_oneliner(
767 outcome.cost_usd,
768 &outcome.tokens
769 )
770 );
771 }
772 Event::Error { message, .. } => {
773 eprintln!("\nError: {}", message);
774 }
775 _ => {}
776 }
777 }
778
779 match handle.await? {
780 Ok(outcome) => {
781 self.history.push(crate::provider::Msg {
782 role: "assistant".into(),
783 content: vec![crate::provider::ContentBlock::Text {
784 text: format!("[{}]", outcome.status),
785 }],
786 });
787 }
788 Err(e) => {
789 eprintln!("Chat error: {}", e);
790 }
791 }
792 println!();
793 }
794
795 Ok(())
796 }
797}
798
799#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
804pub struct PipelineConfig {
805 pub name: String,
806 pub steps: Vec<PipelineStep>,
807 pub max_reworks: u32,
808}
809
810#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
811pub struct PipelineStep {
812 pub role: String,
813 pub model_preference: Option<String>,
814 pub prompt_override: Option<String>,
815 pub depends_on: Vec<String>,
816}
817
818impl PipelineConfig {
819 pub fn default_pipeline() -> Self {
820 Self {
821 name: "planner-coder-verifier".into(),
822 steps: vec![
823 PipelineStep {
824 role: "planner".into(),
825 model_preference: None,
826 prompt_override: None,
827 depends_on: vec![],
828 },
829 PipelineStep {
830 role: "coder".into(),
831 model_preference: None,
832 prompt_override: None,
833 depends_on: vec!["planner".into()],
834 },
835 PipelineStep {
836 role: "verifier".into(),
837 model_preference: None,
838 prompt_override: None,
839 depends_on: vec!["coder".into()],
840 },
841 ],
842 max_reworks: 3,
843 }
844 }
845
846 pub fn validate(&self) -> anyhow::Result<()> {
847 if self.steps.is_empty() {
848 anyhow::bail!("Pipeline must have at least one step");
849 }
850 for step in &self.steps {
851 for dep in &step.depends_on {
852 if !self.steps.iter().any(|s| s.role == *dep) {
853 anyhow::bail!("Step '{}' depends on unknown role '{}'", step.role, dep);
854 }
855 }
856 }
857 Ok(())
858 }
859
860 pub fn from_toml(content: &str) -> anyhow::Result<Self> {
861 Ok(toml::from_str(content)?)
862 }
863
864 pub fn to_toml(&self) -> anyhow::Result<String> {
865 Ok(toml::to_string_pretty(self)?)
866 }
867}
868
869pub struct Profile {
874 pub name: String,
875 pub config_dir: std::path::PathBuf,
876 pub state_dir: std::path::PathBuf,
877 pub config: crate::config::Config,
878 pub memory: Arc<dyn Memory>,
879}
880
881impl Profile {
882 pub fn load(name: &str) -> anyhow::Result<Self> {
883 let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
884 let base_state = dirs::state_dir().unwrap_or_default().join("sparrow");
885
886 let config_dir = base_config.join("profiles").join(name);
887 let state_dir = base_state.join("profiles").join(name);
888
889 std::fs::create_dir_all(&config_dir)?;
890 std::fs::create_dir_all(&state_dir)?;
891
892 let config = if config_dir.join("config.toml").exists() {
893 let content = std::fs::read_to_string(config_dir.join("config.toml"))?;
894 toml::from_str(&content)?
895 } else {
896 let default = base_config.join("config.toml");
898 if default.exists() {
899 let content = std::fs::read_to_string(&default)?;
900 toml::from_str(&content)?
901 } else {
902 crate::config::Config {
903 defaults: Default::default(),
904 routing: Default::default(),
905 budget: Default::default(),
906 providers: Default::default(),
907 surfaces: Default::default(),
908 experience: Default::default(),
909 skills: Default::default(),
910 intel: Default::default(),
911 permissions: Default::default(),
912 hooks: Default::default(),
913 theme: "captain".into(),
914 config_dir: config_dir.clone(),
915 state_dir: state_dir.clone(),
916 forced_model: None,
917 }
918 }
919 };
920
921 let memory: Arc<dyn Memory> = Arc::new(crate::memory::SqliteMemory::open(
922 &state_dir.join("profile.db"),
923 )?);
924
925 Ok(Self {
926 name: name.to_string(),
927 config_dir,
928 state_dir,
929 config,
930 memory,
931 })
932 }
933
934 pub fn create(name: &str) -> anyhow::Result<()> {
935 let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
936 let config_dir = base_config.join("profiles").join(name);
937 std::fs::create_dir_all(&config_dir)?;
938
939 let default = base_config.join("config.toml");
941 if default.exists() {
942 std::fs::copy(&default, config_dir.join("config.toml"))?;
943 }
944
945 let base_state = dirs::state_dir().unwrap_or_default().join("sparrow");
946 std::fs::create_dir_all(base_state.join("profiles").join(name))?;
947
948 Ok(())
949 }
950
951 pub fn list() -> Vec<String> {
952 let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
953 let profiles_dir = base_config.join("profiles");
954 let mut names = Vec::new();
955 if let Ok(entries) = std::fs::read_dir(&profiles_dir) {
956 for entry in entries.flatten() {
957 if entry.path().is_dir() {
958 if let Some(name) = entry.file_name().to_str() {
959 names.push(name.to_string());
960 }
961 }
962 }
963 }
964 names.sort();
965 names
966 }
967}