Skip to main content

zeph_orchestration/
adaptorch.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `AdaptOrch` — bandit-driven topology advisor for the LLM planner.
5//!
6//! [`TopologyAdvisor`] runs before [`crate::planner::LlmPlanner`] and injects a
7//! soft topology hint into the planner system prompt. A 16-arm Thompson Beta-bandit
8//! (4 task classes × 4 topology hints) learns which hint works best for each class.
9//!
10//! State is persisted on shutdown alongside the Thompson router state; `record_outcome`
11//! is synchronous and never spawns a task.
12
13use std::collections::HashMap;
14use std::io;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::time::Duration;
19
20use tracing::Instrument as _;
21
22use parking_lot::Mutex;
23use rand::SeedableRng as _;
24use rand_distr::{Beta, Distribution};
25use serde::{Deserialize, Serialize};
26use zeph_llm::any::AnyProvider;
27use zeph_llm::provider::{LlmProvider, Message, Role};
28
/// Task decomposition shape inferred from the user goal text.
///
/// `Unknown` absorbs all unclassified cases and defaults the hint to [`TopologyHint::Hybrid`].
///
/// Variants serialize in `snake_case` (for example `independent_batch`), which are the
/// same labels the classifier prompt asks the model to answer with.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskClass {
    /// Fan-out work with no cross-dependencies (research, comparisons, multi-source queries).
    IndependentBatch,
    /// Strict ordering: build → test → deploy, ETL pipelines.
    SequentialPipeline,
    /// Tree decomposition: subgoal expansion, recursive analysis.
    HierarchicalDecomp,
    /// Unknown / fallback; defaults hint to `Hybrid`.
    Unknown,
}
45
/// Soft topology hint injected into the planner system prompt.
///
/// Advisory only — `TopologyClassifier::analyze` still runs on the produced graph.
///
/// Variants serialize in `snake_case`; the text actually injected for each variant is
/// produced by [`TopologyHint::prompt_sentence`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TopologyHint {
    /// Maximize independent tasks; avoid unnecessary `depends_on` chains.
    Parallel,
    /// Prefer a strict linear chain unless impossible.
    Sequential,
    /// Decompose into subgoals; expect 2–3 levels of depth.
    Hierarchical,
    /// No constraint (free planning). Default for `Unknown` class.
    Hybrid,
}
62
63impl TopologyHint {
64    /// One-sentence injection appended to the planner system prompt.
65    /// Returns `None` for `Hybrid` (no injection).
66    #[must_use]
67    pub fn prompt_sentence(self) -> Option<&'static str> {
68        match self {
69            Self::Parallel => {
70                Some("Prefer maximizing parallel tasks; avoid unnecessary `depends_on` chains.")
71            }
72            Self::Sequential => Some(
73                "This goal is naturally a pipeline; produce a strict linear chain unless \
74                 impossible.",
75            ),
76            Self::Hierarchical => {
77                Some("Decompose this goal into subgoals; expect 2–3 levels of depth.")
78            }
79            Self::Hybrid => None,
80        }
81    }
82}
83
/// Result of a `TopologyAdvisor::recommend` call.
#[derive(Debug, Clone)]
pub struct AdvisorVerdict {
    /// Inferred task class for the goal.
    pub class: TaskClass,
    /// Sampled topology hint.
    pub hint: TopologyHint,
    /// `true` if Thompson exploited the best-known arm (vs. explored).
    ///
    /// As computed by `sample_arm`: `true` when the winning sampled score lies within
    /// `0.15` of that arm's posterior mean `alpha / (alpha + beta)`. Always `false` for
    /// [`TaskClass::Unknown`], which bypasses sampling.
    pub exploit: bool,
    /// `true` whenever `class` is [`TaskClass::Unknown`], in which case `Hybrid` is used.
    ///
    /// This covers classifier timeouts, provider errors and unparseable replies, but also
    /// the case where the classifier itself answered `unknown`.
    pub fallback: bool,
}
96
/// Per-(class, hint) arm for the Beta-Thompson bandit.
///
/// The default value is the uniform prior `Beta(1, 1)`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BetaDist {
    /// Success pseudo-count; `record_outcome` adds `1.0` when the reward is `>= 1.0`.
    pub alpha: f64,
    /// Failure pseudo-count; `record_outcome` adds `1.0` for every other reward.
    pub beta: f64,
}
103
104impl Default for BetaDist {
105    fn default() -> Self {
106        Self {
107            alpha: 1.0,
108            beta: 1.0,
109        }
110    }
111}
112
113impl BetaDist {
114    fn sample<R: rand::Rng>(&self, rng: &mut R) -> f64 {
115        let a = self.alpha.max(1e-6);
116        let b = self.beta.max(1e-6);
117        // Safety: a and b are clamped to ≥1e-6, so Beta::new never fails.
118        Beta::new(a, b)
119            .expect("clamped values ≥1e-6 are always valid Beta params")
120            .sample(rng)
121    }
122}
123
/// Versioned on-disk format for `AdaptOrch` state.
#[derive(Debug, Serialize, Deserialize)]
struct PersistState {
    /// Format version; `save` writes `1` and `load_arms` rejects anything else.
    version: u32,
    /// Arm table keyed by the `"{class}:{hint}"` strings produced by `arm_key`.
    arms: HashMap<String, BetaDist>,
}
130
/// Session-level metrics for `AdaptOrch` (atomic, not persisted).
#[derive(Debug, Default)]
pub struct AdaptOrchMetrics {
    /// Total classify calls (incremented once per `recommend` call).
    pub classify_calls: AtomicU64,
    /// Classifier calls that exceeded the timeout and fell back to `Unknown`.
    ///
    /// NOTE(review): provider errors inside `classify` also fall back to `Unknown` but do
    /// not increment this counter; only the timeout path in `recommend` does.
    pub classify_timeouts: AtomicU64,
    /// Hint distribution: number of `recommend` calls that returned `Parallel`.
    pub hint_parallel: AtomicU64,
    /// Number of `recommend` calls that returned `Sequential`.
    pub hint_sequential: AtomicU64,
    /// Number of `recommend` calls that returned `Hierarchical`.
    pub hint_hierarchical: AtomicU64,
    /// Number of `recommend` calls that returned `Hybrid`.
    pub hint_hybrid: AtomicU64,
    /// Times `record_outcome` was called.
    pub outcomes_recorded: AtomicU64,
}
146
147fn arm_key(class: TaskClass, hint: TopologyHint) -> String {
148    let c = match class {
149        TaskClass::IndependentBatch => "independent_batch",
150        TaskClass::SequentialPipeline => "sequential_pipeline",
151        TaskClass::HierarchicalDecomp => "hierarchical_decomp",
152        TaskClass::Unknown => "unknown",
153    };
154    let h = match hint {
155        TopologyHint::Parallel => "parallel",
156        TopologyHint::Sequential => "sequential",
157        TopologyHint::Hierarchical => "hierarchical",
158        TopologyHint::Hybrid => "hybrid",
159    };
160    format!("{c}:{h}")
161}
162
/// Every [`TopologyHint`] variant, in the order arms are created and sampled.
const ALL_HINTS: [TopologyHint; 4] = [
    TopologyHint::Parallel,
    TopologyHint::Sequential,
    TopologyHint::Hierarchical,
    TopologyHint::Hybrid,
];
169
/// Bandit-driven topology advisor.
///
/// Classifies the user goal into a [`TaskClass`] via a cheap LLM call, samples
/// the best [`TopologyHint`] for that class via Thompson sampling, and injects
/// one sentence into the planner system prompt. Outcomes are recorded synchronously
/// and persisted once on shutdown.
pub struct TopologyAdvisor {
    /// Provider used for the goal-classification chat call.
    classifier: Arc<AnyProvider>,
    /// In-memory Beta arm table, keyed by `(class, hint)`.
    arms: Arc<Mutex<HashMap<(TaskClass, TopologyHint), BetaDist>>>,
    /// File the arm table is loaded from in `new` and written to in `save`.
    state_path: PathBuf,
    /// Upper bound on how long `recommend` waits for classification.
    classify_timeout: Duration,
    /// Session counters updated by `recommend` and `record_outcome`.
    pub metrics: Arc<AdaptOrchMetrics>,
    /// Random source for Thompson sampling; locked separately from `arms`.
    rng: Arc<Mutex<rand::rngs::SmallRng>>,
}
184
185impl std::fmt::Debug for TopologyAdvisor {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        f.debug_struct("TopologyAdvisor")
188            .field("state_path", &self.state_path)
189            .field("classify_timeout", &self.classify_timeout)
190            .finish_non_exhaustive()
191    }
192}
193
194impl TopologyAdvisor {
195    /// Construct a new advisor. Loads persisted state from `state_path` if present.
196    ///
197    /// When `state_path` is an empty string, the default path
198    /// `~/.zeph/adaptorch_state.json` is used.
199    #[must_use]
200    pub fn new(
201        classifier: Arc<AnyProvider>,
202        state_path: impl Into<PathBuf>,
203        classify_timeout: Duration,
204    ) -> Self {
205        let path: PathBuf = {
206            let p = state_path.into();
207            if p.as_os_str().is_empty() {
208                Self::default_path()
209            } else {
210                p
211            }
212        };
213        let arms = load_arms(&path);
214        Self {
215            classifier,
216            arms: Arc::new(Mutex::new(arms)),
217            state_path: path,
218            classify_timeout,
219            metrics: Arc::new(AdaptOrchMetrics::default()),
220            rng: Arc::new(Mutex::new(rand::rngs::SmallRng::from_rng(&mut rand::rng()))),
221        }
222    }
223
224    /// Default persistence path: `~/.zeph/adaptorch_state.json`.
225    #[must_use]
226    pub fn default_path() -> PathBuf {
227        dirs::home_dir()
228            .unwrap_or_else(|| PathBuf::from("."))
229            .join(".zeph")
230            .join("adaptorch_state.json")
231    }
232
233    /// Classify the goal and sample the best topology hint for this turn.
234    ///
235    /// Classification failures fall back to `TaskClass::Unknown` + `TopologyHint::Hybrid`.
236    pub async fn recommend(&self, goal: &str) -> AdvisorVerdict {
237        async move {
238            self.metrics.classify_calls.fetch_add(1, Ordering::Relaxed);
239
240            let class = tokio::time::timeout(self.classify_timeout, self.classify(goal))
241                .await
242                .unwrap_or_else(|_| {
243                    self.metrics
244                        .classify_timeouts
245                        .fetch_add(1, Ordering::Relaxed);
246                    TaskClass::Unknown
247                });
248
249            let fallback = class == TaskClass::Unknown;
250            let (hint, exploit) = self.sample_arm(class);
251
252            match hint {
253                TopologyHint::Parallel => {
254                    self.metrics.hint_parallel.fetch_add(1, Ordering::Relaxed);
255                }
256                TopologyHint::Sequential => {
257                    self.metrics.hint_sequential.fetch_add(1, Ordering::Relaxed);
258                }
259                TopologyHint::Hierarchical => {
260                    self.metrics
261                        .hint_hierarchical
262                        .fetch_add(1, Ordering::Relaxed);
263                }
264                TopologyHint::Hybrid => {
265                    self.metrics.hint_hybrid.fetch_add(1, Ordering::Relaxed);
266                }
267            }
268
269            AdvisorVerdict {
270                class,
271                hint,
272                exploit,
273                fallback,
274            }
275        }
276        .instrument(tracing::info_span!(
277            "orchestration.adaptorch.recommend",
278            goal_len = goal.len(),
279        ))
280        .await
281    }
282
283    /// Record the binary outcome of a plan guided by `(class, hint)`.
284    ///
285    /// **Synchronous** — acquires the in-memory `Mutex`, updates two `f64` counters, drops
286    /// the guard. Never spawns. Never persists. Persistence happens in [`save`](Self::save).
287    pub fn record_outcome(&self, class: TaskClass, hint: TopologyHint, reward: f64) {
288        self.metrics
289            .outcomes_recorded
290            .fetch_add(1, Ordering::Relaxed);
291        let key = (class, hint);
292        let mut arms = self.arms.lock();
293        let arm = arms.entry(key).or_default();
294        if reward >= 1.0 {
295            arm.alpha += 1.0;
296        } else {
297            arm.beta += 1.0;
298        }
299    }
300
301    /// Persist the Beta-arm table to `state_path` atomically.
302    ///
303    /// Called from the agent shutdown hook (once per process), mirroring
304    /// `AnyProvider::save_router_state`. Failures are logged and swallowed.
305    ///
306    /// # Errors
307    ///
308    /// Returns `io::Error` when the write fails.
309    pub fn save(&self) -> io::Result<()> {
310        let arms_map: HashMap<String, BetaDist> = self
311            .arms
312            .lock()
313            .iter()
314            .map(|((class, hint), dist)| (arm_key(*class, *hint), dist.clone()))
315            .collect();
316
317        let state = PersistState {
318            version: 1,
319            arms: arms_map,
320        };
321
322        let json = serde_json::to_string_pretty(&state).map_err(io::Error::other)?;
323
324        if let Some(parent) = self.state_path.parent() {
325            std::fs::create_dir_all(parent)?;
326        }
327
328        atomic_write(&self.state_path, json.as_bytes())?;
329        Ok(())
330    }
331
332    // ─── private helpers ─────────────────────────────────────────────────────
333
334    async fn classify(&self, goal: &str) -> TaskClass {
335        async move {
336            let truncated: String = goal.chars().take(400).collect();
337            let system = "\
338You classify task decomposition patterns. Read the goal and answer with one of:\n\
339- independent_batch  — fan-out work with no cross-deps (research, comparisons, multi-source queries)\n\
340- sequential_pipeline — strict ordering (build → test → deploy, ETL)\n\
341- hierarchical_decomp — tree of subgoals, divide-and-conquer\n\
342- unknown            — does not clearly fit any of the above\n\n\
343Respond with a single JSON object:\n\
344{\"class\":\"...\",\"reason\":\"<one sentence>\"}";
345
346            let messages = vec![
347                Message::from_legacy(Role::System, system),
348                Message::from_legacy(Role::User, format!("Goal:\n{truncated}")),
349            ];
350
351            let raw = match self.classifier.chat(&messages).await {
352                Ok(r) => r,
353                Err(e) => {
354                    tracing::warn!(error = %e, "adaptorch: classify call failed");
355                    return TaskClass::Unknown;
356                }
357            };
358
359            parse_class(&raw)
360        }
361        .instrument(tracing::info_span!(
362            "orchestration.adaptorch.classify",
363            goal_len = goal.len(),
364        ))
365        .await
366    }
367
368    fn sample_arm(&self, class: TaskClass) -> (TopologyHint, bool) {
369        if class == TaskClass::Unknown {
370            return (TopologyHint::Hybrid, false);
371        }
372        // Clone arm entries under arms lock, then release before acquiring rng lock.
373        let arm_entries: Vec<(TopologyHint, BetaDist)> = {
374            let arms = self.arms.lock();
375            ALL_HINTS
376                .iter()
377                .map(|hint| {
378                    (
379                        *hint,
380                        arms.get(&(class, *hint)).cloned().unwrap_or_default(),
381                    )
382                })
383                .collect()
384        };
385        let mut rng = self.rng.lock();
386        let scores: Vec<(TopologyHint, f64)> = arm_entries
387            .iter()
388            .map(|(hint, dist)| (*hint, dist.sample(&mut *rng)))
389            .collect();
390
391        let (hint, score) = scores
392            .iter()
393            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
394            .map_or((TopologyHint::Hybrid, 0.0), |(h, s)| (*h, *s));
395
396        // "exploit" = the arm's mean (alpha / (alpha+beta)) aligns with the sampled score
397        let arm = arm_entries
398            .iter()
399            .find(|(h, _)| *h == hint)
400            .map(|(_, d)| d.clone())
401            .unwrap_or_default();
402        let mean = arm.alpha / (arm.alpha + arm.beta);
403        let exploit = (score - mean).abs() < 0.15;
404
405        (hint, exploit)
406    }
407}
408
409/// Parse the classifier's JSON response into a [`TaskClass`].
410fn parse_class(raw: &str) -> TaskClass {
411    // Try direct JSON parse first.
412    if let Ok(val) = serde_json::from_str::<serde_json::Value>(raw)
413        && let Some(class) = val.get("class").and_then(|c| c.as_str())
414    {
415        return str_to_class(class);
416    }
417    // Extract first {...} substring.
418    if let Some(start) = raw.find('{')
419        && let Some(end) = raw[start..].find('}')
420    {
421        let chunk = &raw[start..=start + end];
422        if let Ok(val) = serde_json::from_str::<serde_json::Value>(chunk)
423            && let Some(class) = val.get("class").and_then(|c| c.as_str())
424        {
425            return str_to_class(class);
426        }
427    }
428    // Substring scan.
429    for variant in &[
430        "independent_batch",
431        "sequential_pipeline",
432        "hierarchical_decomp",
433        "unknown",
434    ] {
435        if raw.contains(variant) {
436            return str_to_class(variant);
437        }
438    }
439    TaskClass::Unknown
440}
441
442fn str_to_class(s: &str) -> TaskClass {
443    match s {
444        "independent_batch" => TaskClass::IndependentBatch,
445        "sequential_pipeline" => TaskClass::SequentialPipeline,
446        "hierarchical_decomp" => TaskClass::HierarchicalDecomp,
447        _ => TaskClass::Unknown,
448    }
449}
450
451fn load_arms(path: &std::path::Path) -> HashMap<(TaskClass, TopologyHint), BetaDist> {
452    let mut arms = default_arms();
453    let Ok(data) = std::fs::read_to_string(path) else {
454        return arms;
455    };
456    let Ok(state) = serde_json::from_str::<PersistState>(&data) else {
457        tracing::warn!(path = %path.display(), "adaptorch: failed to parse state file, using defaults");
458        return arms;
459    };
460    if state.version != 1 {
461        tracing::warn!(
462            version = state.version,
463            "adaptorch: unknown state version, using defaults"
464        );
465        return arms;
466    }
467    for (key_str, dist) in state.arms {
468        let mut parts = key_str.splitn(2, ':');
469        let (Some(c), Some(h)) = (parts.next(), parts.next()) else {
470            continue;
471        };
472        let class = str_to_class(c);
473        let hint = match h {
474            "parallel" => TopologyHint::Parallel,
475            "sequential" => TopologyHint::Sequential,
476            "hierarchical" => TopologyHint::Hierarchical,
477            "hybrid" => TopologyHint::Hybrid,
478            _ => continue,
479        };
480        arms.insert((class, hint), dist);
481    }
482    arms
483}
484
485fn default_arms() -> HashMap<(TaskClass, TopologyHint), BetaDist> {
486    let classes = [
487        TaskClass::IndependentBatch,
488        TaskClass::SequentialPipeline,
489        TaskClass::HierarchicalDecomp,
490        TaskClass::Unknown,
491    ];
492    let mut map = HashMap::new();
493    for class in classes {
494        for hint in ALL_HINTS {
495            map.insert((class, hint), BetaDist::default());
496        }
497    }
498    map
499}
500
/// Write `data` to `path` by delegating to `zeph_common::fs_secure::atomic_write_private`.
///
/// NOTE(review): judging by the helper's name, the write is atomic and the file is
/// created with private permissions — confirm against `zeph_common`.
fn atomic_write(path: &std::path::Path, data: &[u8]) -> io::Result<()> {
    zeph_common::fs_secure::atomic_write_private(path, data)
}
504
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Build an advisor whose state file lives inside `dir`.
    ///
    /// Tests must never pass an empty path: that resolves to the user's real
    /// `~/.zeph/adaptorch_state.json`, so previously learned arms would leak into the
    /// assertions below.
    fn advisor_in(
        dir: &tempfile::TempDir,
        mock: MockProvider,
        classify_timeout: Duration,
    ) -> TopologyAdvisor {
        TopologyAdvisor::new(
            Arc::new(AnyProvider::Mock(mock)),
            dir.path().join("state.json"),
            classify_timeout,
        )
    }

    #[test]
    fn parse_class_direct_json() {
        let json = r#"{"class":"independent_batch","reason":"fan-out"}"#;
        assert_eq!(parse_class(json), TaskClass::IndependentBatch);
    }

    #[test]
    fn parse_class_fallback_substring() {
        assert_eq!(
            parse_class("  sequential_pipeline "),
            TaskClass::SequentialPipeline
        );
    }

    #[test]
    fn parse_class_unknown_for_garbage() {
        assert_eq!(parse_class("no idea"), TaskClass::Unknown);
    }

    #[test]
    fn topology_hint_sentence_hybrid_is_none() {
        assert!(TopologyHint::Hybrid.prompt_sentence().is_none());
    }

    #[test]
    fn record_outcome_updates_alpha_beta() {
        let dir = tempfile::tempdir().unwrap();
        let advisor = advisor_in(&dir, MockProvider::default(), Duration::from_secs(4));
        advisor.record_outcome(TaskClass::IndependentBatch, TopologyHint::Parallel, 1.0);
        advisor.record_outcome(TaskClass::IndependentBatch, TopologyHint::Parallel, 0.0);
        let arms = advisor.arms.lock();
        let arm = arms
            .get(&(TaskClass::IndependentBatch, TopologyHint::Parallel))
            .unwrap();
        // Default prior (1, 1) plus one success and one failure.
        assert!((arm.alpha - 2.0).abs() < f64::EPSILON);
        assert!((arm.beta - 2.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn recommend_with_valid_json_returns_correct_class() {
        let dir = tempfile::tempdir().unwrap();
        let mock = MockProvider::with_responses(vec![
            r#"{"class":"sequential_pipeline","reason":"strict ordering"}"#.into(),
        ]);
        let advisor = advisor_in(&dir, mock, Duration::from_secs(4));
        let verdict = advisor
            .recommend("Build, test, then deploy the service")
            .await;
        assert_eq!(verdict.class, TaskClass::SequentialPipeline);
        assert_eq!(advisor.metrics.classify_timeouts.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn recommend_timeout_returns_unknown_and_increments_metric() {
        let dir = tempfile::tempdir().unwrap();
        // Delay longer than classify_timeout so the call times out.
        let mut mock = MockProvider::default();
        mock.delay_ms = 200;
        mock.default_response = r#"{"class":"sequential_pipeline","reason":"x"}"#.into();
        let advisor = advisor_in(&dir, mock, Duration::from_millis(50)); // short timeout
        let verdict = advisor.recommend("any goal").await;
        assert_eq!(verdict.class, TaskClass::Unknown);
        assert_eq!(advisor.metrics.classify_timeouts.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn sample_arm_favours_reinforced_hint() {
        let dir = tempfile::tempdir().unwrap();
        let advisor = advisor_in(&dir, MockProvider::default(), Duration::from_secs(4));
        // Reinforce Sequential 20 times for SequentialPipeline class.
        for _ in 0..20 {
            advisor.record_outcome(TaskClass::SequentialPipeline, TopologyHint::Sequential, 1.0);
        }
        // Sample 50 times and verify Sequential wins most often.
        let mut counts = std::collections::HashMap::new();
        for _ in 0..50 {
            let (hint, _) = advisor.sample_arm(TaskClass::SequentialPipeline);
            *counts.entry(hint).or_insert(0u32) += 1;
        }
        let sequential_count = counts.get(&TopologyHint::Sequential).copied().unwrap_or(0);
        assert!(
            sequential_count > 30,
            "expected Sequential to win >30/50 times after reinforcement, got {sequential_count}"
        );
    }

    #[test]
    fn persistence_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        {
            let advisor = advisor_in(&dir, MockProvider::default(), Duration::from_secs(4));
            advisor.record_outcome(TaskClass::SequentialPipeline, TopologyHint::Sequential, 1.0);
            advisor.save().unwrap();
        }
        {
            // A second advisor on the same directory must see the saved arm.
            let advisor = advisor_in(&dir, MockProvider::default(), Duration::from_secs(4));
            let arms = advisor.arms.lock();
            let arm = arms
                .get(&(TaskClass::SequentialPipeline, TopologyHint::Sequential))
                .unwrap();
            // alpha was 1.0 (default) + 1 success = 2.0
            assert!((arm.alpha - 2.0).abs() < f64::EPSILON);
        }
    }
}