Skip to main content

sparrow/router/
learned.rs

//! Per-repo routing memory — local-first learning of which task tiers
//! actually succeed in THIS repository.
//!
//! Every verified run records its outcome per tier in
//! `.sparrow/routing_memory.json` under the workspace root. When a tier keeps
//! failing or escalating here, the engine starts the next run one tier higher
//! — the router learns the repo without any telemetry leaving the machine.
//!
//! Only verification-backed outcomes are recorded: a run that "completed"
//! without a verify command proves nothing and would pollute the data.
11
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use serde::{Deserialize, Serialize};

use crate::router::TaskTier;
18
/// How a run ended, from the point of view of routing quality.
///
/// Each variant maps to exactly one counter in the per-tier statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunRoutingOutcome {
    /// The run finished and the verify command passed without escalation.
    VerifiedSuccess,
    /// The run finished, but only after escalating to a stronger model.
    Escalated,
    /// The run failed: verification never passed, or the chain errored out.
    Failed,
}
29
30#[derive(Debug, Clone, Default, Serialize, Deserialize)]
31pub struct TierStats {
32    #[serde(default)]
33    pub verified_success: u32,
34    #[serde(default)]
35    pub escalated: u32,
36    #[serde(default)]
37    pub failed: u32,
38}
39
40impl TierStats {
41    fn samples(&self) -> u32 {
42        self.verified_success + self.escalated + self.failed
43    }
44
45    /// Halve all counters so recent runs dominate — the repo (and the
46    /// available models) change over time.
47    fn decay(&mut self) {
48        self.verified_success /= 2;
49        self.escalated /= 2;
50        self.failed /= 2;
51    }
52}
53
54#[derive(Debug, Clone, Default, Serialize, Deserialize)]
55pub struct RepoRoutingMemory {
56    #[serde(default)]
57    pub tiers: HashMap<String, TierStats>,
58    #[serde(skip)]
59    path: Option<PathBuf>,
60}
61
/// Minimum verified samples for a tier before its stats may influence routing.
const MIN_SAMPLES: u32 = 4;
/// Share of (escalated + failed) runs at which the starting tier gets bumped.
const BUMP_THRESHOLD: f64 = 0.5;
/// Counters decay once a tier accumulates this many samples.
const DECAY_AT: u32 = 50;
68
69impl RepoRoutingMemory {
70    fn file_path(workspace_root: &Path) -> PathBuf {
71        workspace_root.join(".sparrow").join("routing_memory.json")
72    }
73
74    /// Load the repo's routing memory; a missing or corrupt file is an empty
75    /// memory, never an error.
76    pub fn load(workspace_root: &Path) -> Self {
77        let path = Self::file_path(workspace_root);
78        let mut mem: RepoRoutingMemory = std::fs::read_to_string(&path)
79            .ok()
80            .and_then(|s| serde_json::from_str(&s).ok())
81            .unwrap_or_default();
82        mem.path = Some(path);
83        mem
84    }
85
86    /// Record a run outcome for `tier` and persist (best-effort).
87    pub fn record(&mut self, tier: &TaskTier, outcome: RunRoutingOutcome) {
88        let stats = self.tiers.entry(tier.as_str().to_string()).or_default();
89        match outcome {
90            RunRoutingOutcome::VerifiedSuccess => stats.verified_success += 1,
91            RunRoutingOutcome::Escalated => stats.escalated += 1,
92            RunRoutingOutcome::Failed => stats.failed += 1,
93        }
94        if stats.samples() >= DECAY_AT {
95            stats.decay();
96        }
97        self.save();
98    }
99
100    fn save(&self) {
101        let Some(path) = &self.path else { return };
102        if let Some(dir) = path.parent() {
103            let _ = std::fs::create_dir_all(dir);
104        }
105        if let Ok(json) = serde_json::to_string_pretty(self) {
106            let _ = std::fs::write(path, json);
107        }
108    }
109
110    /// If this repo's history says `tier` mostly fails or escalates, return
111    /// the tier the run should START at instead.
112    pub fn suggest_bump(&self, tier: &TaskTier) -> Option<TaskTier> {
113        let stats = self.tiers.get(tier.as_str())?;
114        if stats.samples() < MIN_SAMPLES {
115            return None;
116        }
117        let bad = (stats.escalated + stats.failed) as f64;
118        if bad / stats.samples() as f64 >= BUMP_THRESHOLD {
119            next_tier_up(tier)
120        } else {
121            None
122        }
123    }
124}
125
126fn next_tier_up(tier: &TaskTier) -> Option<TaskTier> {
127    match tier {
128        TaskTier::Trivial => Some(TaskTier::Small),
129        TaskTier::Small => Some(TaskTier::Medium),
130        TaskTier::Medium => Some(TaskTier::Hard),
131        // Nothing above Hard; Vision is orthogonal, not a strength tier.
132        TaskTier::Hard | TaskTier::Vision => None,
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory instance with no backing file, so `record` never touches disk.
    fn mem() -> RepoRoutingMemory {
        RepoRoutingMemory::default()
    }

    #[test]
    fn no_bump_without_enough_samples() {
        let mut m = mem();
        for _ in 0..3 {
            m.tiers.entry("small".into()).or_default().failed += 1;
        }
        assert_eq!(m.suggest_bump(&TaskTier::Small), None);
    }

    #[test]
    fn bump_when_majority_fails() {
        let mut m = mem();
        let s = m.tiers.entry("small".into()).or_default();
        s.failed = 2;
        s.escalated = 1;
        s.verified_success = 1;
        assert_eq!(m.suggest_bump(&TaskTier::Small), Some(TaskTier::Medium));
    }

    #[test]
    fn no_bump_when_mostly_verified() {
        let mut m = mem();
        let s = m.tiers.entry("medium".into()).or_default();
        s.verified_success = 5;
        s.failed = 1;
        assert_eq!(m.suggest_bump(&TaskTier::Medium), None);
    }

    #[test]
    fn hard_has_no_higher_tier() {
        let mut m = mem();
        let s = m.tiers.entry("hard".into()).or_default();
        s.failed = 10;
        assert_eq!(m.suggest_bump(&TaskTier::Hard), None);
    }

    #[test]
    fn vision_is_never_bumped() {
        let mut m = mem();
        // Populate through `record` so the test does not depend on the key
        // string used for the Vision tier.
        for _ in 0..MIN_SAMPLES + 1 {
            m.record(&TaskTier::Vision, RunRoutingOutcome::Failed);
        }
        assert_eq!(m.suggest_bump(&TaskTier::Vision), None);
    }

    #[test]
    fn decay_halves_counters() {
        let mut s = TierStats {
            verified_success: 30,
            escalated: 10,
            failed: 10,
        };
        s.decay();
        assert_eq!(s.verified_success, 15);
        assert_eq!(s.samples(), 25);
    }

    #[test]
    fn record_decays_once_threshold_is_reached() {
        let mut m = mem();
        for _ in 0..DECAY_AT {
            m.record(&TaskTier::Small, RunRoutingOutcome::Failed);
        }
        // The DECAY_AT-th record reaches the threshold and halves the counter.
        let stats = m.tiers.get("small").expect("small stats");
        assert_eq!(stats.failed, DECAY_AT / 2);
        assert_eq!(stats.samples(), DECAY_AT / 2);
    }

    #[test]
    fn load_missing_file_is_empty() {
        let dir = std::env::temp_dir().join("sparrow-test-no-such-dir-xyz");
        let m = RepoRoutingMemory::load(&dir);
        assert!(m.tiers.is_empty());
    }

    #[test]
    fn load_corrupt_file_is_empty() {
        let dir = std::env::temp_dir().join(format!("sparrow-rm-corrupt-{}", std::process::id()));
        let store = dir.join(".sparrow");
        std::fs::create_dir_all(&store).expect("create temp workspace");
        std::fs::write(store.join("routing_memory.json"), "{ not json").expect("write fixture");
        let m = RepoRoutingMemory::load(&dir);
        assert!(m.tiers.is_empty());
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn record_and_reload_roundtrip() {
        let dir = std::env::temp_dir().join(format!("sparrow-rm-{}", std::process::id()));
        let _ = std::fs::create_dir_all(&dir);
        let mut m = RepoRoutingMemory::load(&dir);
        m.record(&TaskTier::Medium, RunRoutingOutcome::Escalated);
        m.record(&TaskTier::Medium, RunRoutingOutcome::VerifiedSuccess);
        let reloaded = RepoRoutingMemory::load(&dir);
        let stats = reloaded.tiers.get("medium").expect("medium stats");
        assert_eq!(stats.escalated, 1);
        assert_eq!(stats.verified_success, 1);
        let _ = std::fs::remove_dir_all(&dir);
    }
}