Skip to main content

zeph_orchestration/
adaptorch.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `AdaptOrch` — bandit-driven topology advisor for the LLM planner.
5//!
6//! [`TopologyAdvisor`] runs before [`crate::planner::LlmPlanner`] and injects a
7//! soft topology hint into the planner system prompt. A 16-arm Thompson Beta-bandit
8//! (4 task classes × 4 topology hints) learns which hint works best for each class.
9//!
10//! State is persisted on shutdown alongside the Thompson router state; `record_outcome`
11//! is synchronous and never spawns a task.
12
13use std::collections::HashMap;
14use std::io;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::time::Duration;
19
20use parking_lot::Mutex;
21use rand::SeedableRng as _;
22use rand_distr::{Beta, Distribution};
23use serde::{Deserialize, Serialize};
24use zeph_llm::any::AnyProvider;
25use zeph_llm::provider::{LlmProvider, Message, Role};
26
/// Task decomposition shape inferred from the user goal text.
///
/// `Unknown` absorbs all unclassified cases and defaults the hint to [`TopologyHint::Hybrid`].
///
/// Variants serialize with `snake_case` names (`independent_batch`, `sequential_pipeline`,
/// `hierarchical_decomp`, `unknown`), which are the same labels the classifier prompt asks
/// the model to answer with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskClass {
    /// Fan-out work with no cross-dependencies (research, comparisons, multi-source queries).
    IndependentBatch,
    /// Strict ordering: build → test → deploy, ETL pipelines.
    SequentialPipeline,
    /// Tree decomposition: subgoal expansion, recursive analysis.
    HierarchicalDecomp,
    /// Unknown / fallback; defaults hint to `Hybrid`.
    Unknown,
}
42
/// Soft topology hint injected into the planner system prompt.
///
/// Advisory only — `TopologyClassifier::analyze` still runs on the produced graph.
///
/// Variants serialize with `snake_case` names. The text that goes into the prompt for each
/// variant comes from [`TopologyHint::prompt_sentence`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TopologyHint {
    /// Maximize independent tasks; avoid unnecessary `depends_on` chains.
    Parallel,
    /// Prefer a strict linear chain unless impossible.
    Sequential,
    /// Decompose into subgoals; expect 2–3 levels of depth.
    Hierarchical,
    /// No constraint (free planning). Default for `Unknown` class.
    /// `prompt_sentence` returns `None` for this variant.
    Hybrid,
}
58
59impl TopologyHint {
60    /// One-sentence injection appended to the planner system prompt.
61    /// Returns `None` for `Hybrid` (no injection).
62    #[must_use]
63    pub fn prompt_sentence(self) -> Option<&'static str> {
64        match self {
65            Self::Parallel => {
66                Some("Prefer maximizing parallel tasks; avoid unnecessary `depends_on` chains.")
67            }
68            Self::Sequential => Some(
69                "This goal is naturally a pipeline; produce a strict linear chain unless \
70                 impossible.",
71            ),
72            Self::Hierarchical => {
73                Some("Decompose this goal into subgoals; expect 2–3 levels of depth.")
74            }
75            Self::Hybrid => None,
76        }
77    }
78}
79
/// Result of a `TopologyAdvisor::recommend` call.
///
/// Carries both the inferred class and the chosen hint so the caller can later pass the
/// same pair back to `TopologyAdvisor::record_outcome`.
#[derive(Debug, Clone)]
pub struct AdvisorVerdict {
    /// Inferred task class for the goal.
    pub class: TaskClass,
    /// Sampled topology hint. Always `TopologyHint::Hybrid` when `class` is
    /// `TaskClass::Unknown`, because the bandit is not consulted for that class.
    pub hint: TopologyHint,
    /// `true` if Thompson exploited the best-known arm (vs. explored).
    /// Always `false` when `class` is `TaskClass::Unknown`.
    pub exploit: bool,
    /// `true` if classification failed and `Hybrid` was used as the default.
    ///
    /// Set whenever `class` is `TaskClass::Unknown`; that covers a classifier timeout, a
    /// failed classifier call, an unparseable answer, and an explicit `unknown` answer.
    pub fallback: bool,
}
92
/// Per-(class, hint) arm for the Beta-Thompson bandit.
///
/// Stores the two shape parameters of a Beta distribution. Both start at `1.0` (see the
/// `Default` impl) and exactly one of them grows by `1.0` for every outcome recorded
/// through `TopologyAdvisor::record_outcome`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BetaDist {
    /// Success pseudo-count; incremented when a recorded reward is `>= 1.0`.
    pub alpha: f64,
    /// Failure pseudo-count; incremented for every other recorded reward.
    pub beta: f64,
}
99
100impl Default for BetaDist {
101    fn default() -> Self {
102        Self {
103            alpha: 1.0,
104            beta: 1.0,
105        }
106    }
107}
108
109impl BetaDist {
110    fn sample<R: rand::Rng>(&self, rng: &mut R) -> f64 {
111        let a = self.alpha.max(1e-6);
112        let b = self.beta.max(1e-6);
113        // Safety: a and b are clamped to ≥1e-6, so Beta::new never fails.
114        Beta::new(a, b)
115            .expect("clamped values ≥1e-6 are always valid Beta params")
116            .sample(rng)
117    }
118}
119
/// Versioned on-disk format for `AdaptOrch` state.
///
/// Written as pretty-printed JSON by `TopologyAdvisor::save` and read back by `load_arms`.
#[derive(Debug, Serialize, Deserialize)]
struct PersistState {
    /// Format version. `save` writes `1`; `load_arms` ignores files with any other value.
    version: u32,
    /// Arm table keyed by the `"<class>:<hint>"` strings produced by `arm_key`.
    arms: HashMap<String, BetaDist>,
}
126
/// Session-level metrics for `AdaptOrch` (atomic, not persisted).
///
/// Every counter in this file is bumped with `Ordering::Relaxed`; they are independent
/// tallies and do not order any other memory access.
#[derive(Debug, Default)]
pub struct AdaptOrchMetrics {
    /// Total classify calls (one per `TopologyAdvisor::recommend` invocation).
    pub classify_calls: AtomicU64,
    /// Calls that timed out or failed — fell back to `Unknown`.
    pub classify_timeouts: AtomicU64,
    /// Hint distribution: verdicts that carried `TopologyHint::Parallel`.
    pub hint_parallel: AtomicU64,
    /// Verdicts that carried `TopologyHint::Sequential`.
    pub hint_sequential: AtomicU64,
    /// Verdicts that carried `TopologyHint::Hierarchical`.
    pub hint_hierarchical: AtomicU64,
    /// Verdicts that carried `TopologyHint::Hybrid` (includes every `Unknown`-class verdict).
    pub hint_hybrid: AtomicU64,
    /// Times `record_outcome` was called.
    pub outcomes_recorded: AtomicU64,
}
142
143fn arm_key(class: TaskClass, hint: TopologyHint) -> String {
144    let c = match class {
145        TaskClass::IndependentBatch => "independent_batch",
146        TaskClass::SequentialPipeline => "sequential_pipeline",
147        TaskClass::HierarchicalDecomp => "hierarchical_decomp",
148        TaskClass::Unknown => "unknown",
149    };
150    let h = match hint {
151        TopologyHint::Parallel => "parallel",
152        TopologyHint::Sequential => "sequential",
153        TopologyHint::Hierarchical => "hierarchical",
154        TopologyHint::Hybrid => "hybrid",
155    };
156    format!("{c}:{h}")
157}
158
/// Every [`TopologyHint`] variant, in declaration order.
///
/// `default_arms` uses it to seed one arm per hint for each class, and
/// `TopologyAdvisor::sample_arm` iterates it to draw one sample per hint.
const ALL_HINTS: [TopologyHint; 4] = [
    TopologyHint::Parallel,
    TopologyHint::Sequential,
    TopologyHint::Hierarchical,
    TopologyHint::Hybrid,
];
165
/// Bandit-driven topology advisor.
///
/// Classifies the user goal into a [`TaskClass`] via a cheap LLM call, samples
/// the best [`TopologyHint`] for that class via Thompson sampling, and injects
/// one sentence into the planner system prompt. Outcomes are recorded synchronously
/// and persisted once on shutdown.
pub struct TopologyAdvisor {
    /// Provider that receives the classification chat request.
    classifier: Arc<AnyProvider>,
    /// Bandit table: one Beta arm per `(class, hint)` pair, guarded by a mutex.
    arms: Arc<Mutex<HashMap<(TaskClass, TopologyHint), BetaDist>>>,
    /// File the arm table is loaded from in `new` and written to by `save`.
    state_path: PathBuf,
    /// Upper bound applied to the classification call in `recommend`.
    classify_timeout: Duration,
    /// Session counters; see [`AdaptOrchMetrics`].
    pub metrics: Arc<AdaptOrchMetrics>,
    /// Random source for Thompson sampling, seeded once in `new`.
    rng: Arc<Mutex<rand::rngs::SmallRng>>,
}
180
181impl std::fmt::Debug for TopologyAdvisor {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        f.debug_struct("TopologyAdvisor")
184            .field("state_path", &self.state_path)
185            .field("classify_timeout", &self.classify_timeout)
186            .finish_non_exhaustive()
187    }
188}
189
190impl TopologyAdvisor {
191    /// Construct a new advisor. Loads persisted state from `state_path` if present.
192    ///
193    /// When `state_path` is an empty string, the default path
194    /// `~/.zeph/adaptorch_state.json` is used.
195    #[must_use]
196    pub fn new(
197        classifier: Arc<AnyProvider>,
198        state_path: impl Into<PathBuf>,
199        classify_timeout: Duration,
200    ) -> Self {
201        let path: PathBuf = {
202            let p = state_path.into();
203            if p.as_os_str().is_empty() {
204                Self::default_path()
205            } else {
206                p
207            }
208        };
209        let arms = load_arms(&path);
210        Self {
211            classifier,
212            arms: Arc::new(Mutex::new(arms)),
213            state_path: path,
214            classify_timeout,
215            metrics: Arc::new(AdaptOrchMetrics::default()),
216            rng: Arc::new(Mutex::new(rand::rngs::SmallRng::from_rng(&mut rand::rng()))),
217        }
218    }
219
220    /// Default persistence path: `~/.zeph/adaptorch_state.json`.
221    #[must_use]
222    pub fn default_path() -> PathBuf {
223        dirs::home_dir()
224            .unwrap_or_else(|| PathBuf::from("."))
225            .join(".zeph")
226            .join("adaptorch_state.json")
227    }
228
229    /// Classify the goal and sample the best topology hint for this turn.
230    ///
231    /// Classification failures fall back to `TaskClass::Unknown` + `TopologyHint::Hybrid`.
232    pub async fn recommend(&self, goal: &str) -> AdvisorVerdict {
233        self.metrics.classify_calls.fetch_add(1, Ordering::Relaxed);
234
235        let class = tokio::time::timeout(self.classify_timeout, self.classify(goal))
236            .await
237            .unwrap_or_else(|_| {
238                self.metrics
239                    .classify_timeouts
240                    .fetch_add(1, Ordering::Relaxed);
241                TaskClass::Unknown
242            });
243
244        let fallback = class == TaskClass::Unknown;
245        let (hint, exploit) = self.sample_arm(class);
246
247        match hint {
248            TopologyHint::Parallel => {
249                self.metrics.hint_parallel.fetch_add(1, Ordering::Relaxed);
250            }
251            TopologyHint::Sequential => {
252                self.metrics.hint_sequential.fetch_add(1, Ordering::Relaxed);
253            }
254            TopologyHint::Hierarchical => {
255                self.metrics
256                    .hint_hierarchical
257                    .fetch_add(1, Ordering::Relaxed);
258            }
259            TopologyHint::Hybrid => {
260                self.metrics.hint_hybrid.fetch_add(1, Ordering::Relaxed);
261            }
262        }
263
264        AdvisorVerdict {
265            class,
266            hint,
267            exploit,
268            fallback,
269        }
270    }
271
272    /// Record the binary outcome of a plan guided by `(class, hint)`.
273    ///
274    /// **Synchronous** — acquires the in-memory `Mutex`, updates two `f64` counters, drops
275    /// the guard. Never spawns. Never persists. Persistence happens in [`save`](Self::save).
276    pub fn record_outcome(&self, class: TaskClass, hint: TopologyHint, reward: f64) {
277        self.metrics
278            .outcomes_recorded
279            .fetch_add(1, Ordering::Relaxed);
280        let key = (class, hint);
281        let mut arms = self.arms.lock();
282        let arm = arms.entry(key).or_default();
283        if reward >= 1.0 {
284            arm.alpha += 1.0;
285        } else {
286            arm.beta += 1.0;
287        }
288    }
289
290    /// Persist the Beta-arm table to `state_path` atomically.
291    ///
292    /// Called from the agent shutdown hook (once per process), mirroring
293    /// `AnyProvider::save_router_state`. Failures are logged and swallowed.
294    ///
295    /// # Errors
296    ///
297    /// Returns `io::Error` when the write fails.
298    pub fn save(&self) -> io::Result<()> {
299        let arms_map: HashMap<String, BetaDist> = self
300            .arms
301            .lock()
302            .iter()
303            .map(|((class, hint), dist)| (arm_key(*class, *hint), dist.clone()))
304            .collect();
305
306        let state = PersistState {
307            version: 1,
308            arms: arms_map,
309        };
310
311        let json = serde_json::to_string_pretty(&state).map_err(io::Error::other)?;
312
313        if let Some(parent) = self.state_path.parent() {
314            std::fs::create_dir_all(parent)?;
315        }
316
317        atomic_write(&self.state_path, json.as_bytes())?;
318        Ok(())
319    }
320
321    // ─── private helpers ─────────────────────────────────────────────────────
322
323    async fn classify(&self, goal: &str) -> TaskClass {
324        let truncated: String = goal.chars().take(400).collect();
325        let system = "\
326You classify task decomposition patterns. Read the goal and answer with one of:\n\
327- independent_batch  — fan-out work with no cross-deps (research, comparisons, multi-source queries)\n\
328- sequential_pipeline — strict ordering (build → test → deploy, ETL)\n\
329- hierarchical_decomp — tree of subgoals, divide-and-conquer\n\
330- unknown            — does not clearly fit any of the above\n\n\
331Respond with a single JSON object:\n\
332{\"class\":\"...\",\"reason\":\"<one sentence>\"}";
333
334        let messages = vec![
335            Message::from_legacy(Role::System, system),
336            Message::from_legacy(Role::User, format!("Goal:\n{truncated}")),
337        ];
338
339        let raw = match self.classifier.chat(&messages).await {
340            Ok(r) => r,
341            Err(e) => {
342                tracing::warn!(error = %e, "adaptorch: classify call failed");
343                return TaskClass::Unknown;
344            }
345        };
346
347        parse_class(&raw)
348    }
349
350    fn sample_arm(&self, class: TaskClass) -> (TopologyHint, bool) {
351        if class == TaskClass::Unknown {
352            return (TopologyHint::Hybrid, false);
353        }
354        // Clone arm entries under arms lock, then release before acquiring rng lock.
355        let arm_entries: Vec<(TopologyHint, BetaDist)> = {
356            let arms = self.arms.lock();
357            ALL_HINTS
358                .iter()
359                .map(|hint| {
360                    (
361                        *hint,
362                        arms.get(&(class, *hint)).cloned().unwrap_or_default(),
363                    )
364                })
365                .collect()
366        };
367        let mut rng = self.rng.lock();
368        let scores: Vec<(TopologyHint, f64)> = arm_entries
369            .iter()
370            .map(|(hint, dist)| (*hint, dist.sample(&mut *rng)))
371            .collect();
372
373        let (hint, score) = scores
374            .iter()
375            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
376            .map_or((TopologyHint::Hybrid, 0.0), |(h, s)| (*h, *s));
377
378        // "exploit" = the arm's mean (alpha / (alpha+beta)) aligns with the sampled score
379        let arm = arm_entries
380            .iter()
381            .find(|(h, _)| *h == hint)
382            .map(|(_, d)| d.clone())
383            .unwrap_or_default();
384        let mean = arm.alpha / (arm.alpha + arm.beta);
385        let exploit = (score - mean).abs() < 0.15;
386
387        (hint, exploit)
388    }
389}
390
391/// Parse the classifier's JSON response into a [`TaskClass`].
392fn parse_class(raw: &str) -> TaskClass {
393    // Try direct JSON parse first.
394    if let Ok(val) = serde_json::from_str::<serde_json::Value>(raw)
395        && let Some(class) = val.get("class").and_then(|c| c.as_str())
396    {
397        return str_to_class(class);
398    }
399    // Extract first {...} substring.
400    if let Some(start) = raw.find('{')
401        && let Some(end) = raw[start..].find('}')
402    {
403        let chunk = &raw[start..=start + end];
404        if let Ok(val) = serde_json::from_str::<serde_json::Value>(chunk)
405            && let Some(class) = val.get("class").and_then(|c| c.as_str())
406        {
407            return str_to_class(class);
408        }
409    }
410    // Substring scan.
411    for variant in &[
412        "independent_batch",
413        "sequential_pipeline",
414        "hierarchical_decomp",
415        "unknown",
416    ] {
417        if raw.contains(variant) {
418            return str_to_class(variant);
419        }
420    }
421    TaskClass::Unknown
422}
423
424fn str_to_class(s: &str) -> TaskClass {
425    match s {
426        "independent_batch" => TaskClass::IndependentBatch,
427        "sequential_pipeline" => TaskClass::SequentialPipeline,
428        "hierarchical_decomp" => TaskClass::HierarchicalDecomp,
429        _ => TaskClass::Unknown,
430    }
431}
432
433fn load_arms(path: &std::path::Path) -> HashMap<(TaskClass, TopologyHint), BetaDist> {
434    let mut arms = default_arms();
435    let Ok(data) = std::fs::read_to_string(path) else {
436        return arms;
437    };
438    let Ok(state) = serde_json::from_str::<PersistState>(&data) else {
439        tracing::warn!(path = %path.display(), "adaptorch: failed to parse state file, using defaults");
440        return arms;
441    };
442    if state.version != 1 {
443        tracing::warn!(
444            version = state.version,
445            "adaptorch: unknown state version, using defaults"
446        );
447        return arms;
448    }
449    for (key_str, dist) in state.arms {
450        let mut parts = key_str.splitn(2, ':');
451        let (Some(c), Some(h)) = (parts.next(), parts.next()) else {
452            continue;
453        };
454        let class = str_to_class(c);
455        let hint = match h {
456            "parallel" => TopologyHint::Parallel,
457            "sequential" => TopologyHint::Sequential,
458            "hierarchical" => TopologyHint::Hierarchical,
459            "hybrid" => TopologyHint::Hybrid,
460            _ => continue,
461        };
462        arms.insert((class, hint), dist);
463    }
464    arms
465}
466
467fn default_arms() -> HashMap<(TaskClass, TopologyHint), BetaDist> {
468    let classes = [
469        TaskClass::IndependentBatch,
470        TaskClass::SequentialPipeline,
471        TaskClass::HierarchicalDecomp,
472        TaskClass::Unknown,
473    ];
474    let mut map = HashMap::new();
475    for class in classes {
476        for hint in ALL_HINTS {
477            map.insert((class, hint), BetaDist::default());
478        }
479    }
480    map
481}
482
/// Write `data` to `path` by delegating to `zeph_common::fs_secure::atomic_write_private`.
///
/// NOTE(review): judging by the helper's name the write is atomic and the file gets
/// private permissions — confirm against `zeph_common`.
///
/// Any `io::Error` from the helper is returned unchanged.
fn atomic_write(path: &std::path::Path, data: &[u8]) -> io::Result<()> {
    zeph_common::fs_secure::atomic_write_private(path, data)
}
486
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::mock::MockProvider;

    /// Build an advisor whose state file lives in a fresh temporary directory.
    ///
    /// An empty path makes `TopologyAdvisor::new` fall back to
    /// `~/.zeph/adaptorch_state.json`, so a test using it would load whatever bandit state
    /// exists on the machine running the tests. Every test therefore goes through this
    /// helper. The `TempDir` is returned so the directory outlives the advisor.
    fn isolated_advisor(
        mock: MockProvider,
        classify_timeout: Duration,
    ) -> (tempfile::TempDir, TopologyAdvisor) {
        let dir = tempfile::tempdir().unwrap();
        let advisor = TopologyAdvisor::new(
            Arc::new(AnyProvider::Mock(mock)),
            dir.path().join("state.json"),
            classify_timeout,
        );
        (dir, advisor)
    }

    #[test]
    fn parse_class_direct_json() {
        let json = r#"{"class":"independent_batch","reason":"fan-out"}"#;
        assert_eq!(parse_class(json), TaskClass::IndependentBatch);
    }

    #[test]
    fn parse_class_json_embedded_in_prose() {
        let reply = r#"Sure: {"class":"hierarchical_decomp","reason":"tree"} hope that helps"#;
        assert_eq!(parse_class(reply), TaskClass::HierarchicalDecomp);
    }

    #[test]
    fn parse_class_fallback_substring() {
        assert_eq!(
            parse_class("  sequential_pipeline "),
            TaskClass::SequentialPipeline
        );
    }

    #[test]
    fn parse_class_unknown_for_garbage() {
        assert_eq!(parse_class("no idea"), TaskClass::Unknown);
    }

    #[test]
    fn topology_hint_sentence_hybrid_is_none() {
        assert!(TopologyHint::Hybrid.prompt_sentence().is_none());
    }

    #[test]
    fn record_outcome_updates_alpha_beta() {
        let (_dir, advisor) = isolated_advisor(MockProvider::default(), Duration::from_secs(4));
        advisor.record_outcome(TaskClass::IndependentBatch, TopologyHint::Parallel, 1.0);
        advisor.record_outcome(TaskClass::IndependentBatch, TopologyHint::Parallel, 0.0);
        let arms = advisor.arms.lock();
        let arm = arms
            .get(&(TaskClass::IndependentBatch, TopologyHint::Parallel))
            .unwrap();
        // Default prior (1, 1) plus one success and one failure.
        assert!((arm.alpha - 2.0).abs() < f64::EPSILON);
        assert!((arm.beta - 2.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn recommend_with_valid_json_returns_correct_class() {
        let mock = MockProvider::with_responses(vec![
            r#"{"class":"sequential_pipeline","reason":"strict ordering"}"#.into(),
        ]);
        let (_dir, advisor) = isolated_advisor(mock, Duration::from_secs(4));
        let verdict = advisor
            .recommend("Build, test, then deploy the service")
            .await;
        assert_eq!(verdict.class, TaskClass::SequentialPipeline);
        assert_eq!(advisor.metrics.classify_timeouts.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn recommend_timeout_returns_unknown_and_increments_metric() {
        // Delay longer than classify_timeout so the call times out.
        let mut mock = MockProvider::default();
        mock.delay_ms = 200;
        mock.default_response = r#"{"class":"sequential_pipeline","reason":"x"}"#.into();
        let (_dir, advisor) = isolated_advisor(mock, Duration::from_millis(50));
        let verdict = advisor.recommend("any goal").await;
        assert_eq!(verdict.class, TaskClass::Unknown);
        assert_eq!(advisor.metrics.classify_timeouts.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn sample_arm_unknown_class_is_hybrid() {
        let (_dir, advisor) = isolated_advisor(MockProvider::default(), Duration::from_secs(4));
        assert_eq!(
            advisor.sample_arm(TaskClass::Unknown),
            (TopologyHint::Hybrid, false)
        );
    }

    #[test]
    fn sample_arm_favours_reinforced_hint() {
        let (_dir, advisor) = isolated_advisor(MockProvider::default(), Duration::from_secs(4));
        // Reinforce Sequential 20 times for SequentialPipeline class.
        for _ in 0..20 {
            advisor.record_outcome(TaskClass::SequentialPipeline, TopologyHint::Sequential, 1.0);
        }
        // Sample 50 times and verify Sequential wins most often.
        let mut counts = std::collections::HashMap::new();
        for _ in 0..50 {
            let (hint, _) = advisor.sample_arm(TaskClass::SequentialPipeline);
            *counts.entry(hint).or_insert(0u32) += 1;
        }
        let sequential_count = counts.get(&TopologyHint::Sequential).copied().unwrap_or(0);
        assert!(
            sequential_count > 30,
            "expected Sequential to win >30/50 times after reinforcement, got {sequential_count}"
        );
    }

    #[test]
    fn persistence_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("state.json");
        {
            let advisor = TopologyAdvisor::new(
                Arc::new(AnyProvider::Mock(MockProvider::default())),
                path.clone(),
                Duration::from_secs(4),
            );
            advisor.record_outcome(TaskClass::SequentialPipeline, TopologyHint::Sequential, 1.0);
            advisor.save().unwrap();
        }
        {
            let advisor = TopologyAdvisor::new(
                Arc::new(AnyProvider::Mock(MockProvider::default())),
                path.clone(),
                Duration::from_secs(4),
            );
            let arms = advisor.arms.lock();
            let arm = arms
                .get(&(TaskClass::SequentialPipeline, TopologyHint::Sequential))
                .unwrap();
            // alpha was 1.0 (default) + 1 success = 2.0
            assert!((arm.alpha - 2.0).abs() < f64::EPSILON);
        }
    }
}