1use std::collections::HashMap;
14use std::io;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::time::Duration;
19
20use tracing::Instrument as _;
21
22use parking_lot::Mutex;
23use rand::SeedableRng as _;
24use rand_distr::{Beta, Distribution};
25use serde::{Deserialize, Serialize};
26use zeph_llm::any::AnyProvider;
27use zeph_llm::provider::{LlmProvider, Message, Role};
28
29#[non_exhaustive]
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
34#[serde(rename_all = "snake_case")]
35pub enum TaskClass {
36 IndependentBatch,
38 SequentialPipeline,
40 HierarchicalDecomp,
42 Unknown,
44}
45
46#[non_exhaustive]
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
51#[serde(rename_all = "snake_case")]
52pub enum TopologyHint {
53 Parallel,
55 Sequential,
57 Hierarchical,
59 Hybrid,
61}
62
63impl TopologyHint {
64 #[must_use]
67 pub fn prompt_sentence(self) -> Option<&'static str> {
68 match self {
69 Self::Parallel => {
70 Some("Prefer maximizing parallel tasks; avoid unnecessary `depends_on` chains.")
71 }
72 Self::Sequential => Some(
73 "This goal is naturally a pipeline; produce a strict linear chain unless \
74 impossible.",
75 ),
76 Self::Hierarchical => {
77 Some("Decompose this goal into subgoals; expect 2–3 levels of depth.")
78 }
79 Self::Hybrid => None,
80 }
81 }
82}
83
84#[derive(Debug, Clone)]
86pub struct AdvisorVerdict {
87 pub class: TaskClass,
89 pub hint: TopologyHint,
91 pub exploit: bool,
93 pub fallback: bool,
95}
96
97#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
99pub struct BetaDist {
100 pub alpha: f64,
101 pub beta: f64,
102}
103
104impl Default for BetaDist {
105 fn default() -> Self {
106 Self {
107 alpha: 1.0,
108 beta: 1.0,
109 }
110 }
111}
112
113impl BetaDist {
114 fn sample<R: rand::Rng>(&self, rng: &mut R) -> f64 {
115 let a = self.alpha.max(1e-6);
116 let b = self.beta.max(1e-6);
117 Beta::new(a, b)
119 .expect("clamped values ≥1e-6 are always valid Beta params")
120 .sample(rng)
121 }
122}
123
124#[derive(Debug, Serialize, Deserialize)]
126struct PersistState {
127 version: u32,
128 arms: HashMap<String, BetaDist>,
129}
130
/// Activity counters for the topology advisor.
///
/// The advisor only ever increments these, using relaxed atomic ordering; they are
/// independent tallies, not synchronization points.
#[derive(Default, Debug)]
pub struct AdaptOrchMetrics {
    /// Number of `recommend` calls.
    pub classify_calls: AtomicU64,
    /// Classifications that exceeded the configured timeout.
    pub classify_timeouts: AtomicU64,
    /// Verdicts that returned `TopologyHint::Parallel`.
    pub hint_parallel: AtomicU64,
    /// Verdicts that returned `TopologyHint::Sequential`.
    pub hint_sequential: AtomicU64,
    /// Verdicts that returned `TopologyHint::Hierarchical`.
    pub hint_hierarchical: AtomicU64,
    /// Verdicts that returned `TopologyHint::Hybrid`.
    pub hint_hybrid: AtomicU64,
    /// Number of `record_outcome` calls.
    pub outcomes_recorded: AtomicU64,
}
146
147fn arm_key(class: TaskClass, hint: TopologyHint) -> String {
148 let c = match class {
149 TaskClass::IndependentBatch => "independent_batch",
150 TaskClass::SequentialPipeline => "sequential_pipeline",
151 TaskClass::HierarchicalDecomp => "hierarchical_decomp",
152 TaskClass::Unknown => "unknown",
153 };
154 let h = match hint {
155 TopologyHint::Parallel => "parallel",
156 TopologyHint::Sequential => "sequential",
157 TopologyHint::Hierarchical => "hierarchical",
158 TopologyHint::Hybrid => "hybrid",
159 };
160 format!("{c}:{h}")
161}
162
163const ALL_HINTS: [TopologyHint; 4] = [
164 TopologyHint::Parallel,
165 TopologyHint::Sequential,
166 TopologyHint::Hierarchical,
167 TopologyHint::Hybrid,
168];
169
170pub struct TopologyAdvisor {
177 classifier: Arc<AnyProvider>,
178 arms: Arc<Mutex<HashMap<(TaskClass, TopologyHint), BetaDist>>>,
179 state_path: PathBuf,
180 classify_timeout: Duration,
181 pub metrics: Arc<AdaptOrchMetrics>,
182 rng: Arc<Mutex<rand::rngs::SmallRng>>,
183}
184
185impl std::fmt::Debug for TopologyAdvisor {
186 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187 f.debug_struct("TopologyAdvisor")
188 .field("state_path", &self.state_path)
189 .field("classify_timeout", &self.classify_timeout)
190 .finish_non_exhaustive()
191 }
192}
193
194impl TopologyAdvisor {
195 #[must_use]
200 pub fn new(
201 classifier: Arc<AnyProvider>,
202 state_path: impl Into<PathBuf>,
203 classify_timeout: Duration,
204 ) -> Self {
205 let path: PathBuf = {
206 let p = state_path.into();
207 if p.as_os_str().is_empty() {
208 Self::default_path()
209 } else {
210 p
211 }
212 };
213 let arms = load_arms(&path);
214 Self {
215 classifier,
216 arms: Arc::new(Mutex::new(arms)),
217 state_path: path,
218 classify_timeout,
219 metrics: Arc::new(AdaptOrchMetrics::default()),
220 rng: Arc::new(Mutex::new(rand::rngs::SmallRng::from_rng(&mut rand::rng()))),
221 }
222 }
223
224 #[must_use]
226 pub fn default_path() -> PathBuf {
227 dirs::home_dir()
228 .unwrap_or_else(|| PathBuf::from("."))
229 .join(".zeph")
230 .join("adaptorch_state.json")
231 }
232
233 pub async fn recommend(&self, goal: &str) -> AdvisorVerdict {
237 async move {
238 self.metrics.classify_calls.fetch_add(1, Ordering::Relaxed);
239
240 let class = tokio::time::timeout(self.classify_timeout, self.classify(goal))
241 .await
242 .unwrap_or_else(|_| {
243 self.metrics
244 .classify_timeouts
245 .fetch_add(1, Ordering::Relaxed);
246 TaskClass::Unknown
247 });
248
249 let fallback = class == TaskClass::Unknown;
250 let (hint, exploit) = self.sample_arm(class);
251
252 match hint {
253 TopologyHint::Parallel => {
254 self.metrics.hint_parallel.fetch_add(1, Ordering::Relaxed);
255 }
256 TopologyHint::Sequential => {
257 self.metrics.hint_sequential.fetch_add(1, Ordering::Relaxed);
258 }
259 TopologyHint::Hierarchical => {
260 self.metrics
261 .hint_hierarchical
262 .fetch_add(1, Ordering::Relaxed);
263 }
264 TopologyHint::Hybrid => {
265 self.metrics.hint_hybrid.fetch_add(1, Ordering::Relaxed);
266 }
267 }
268
269 AdvisorVerdict {
270 class,
271 hint,
272 exploit,
273 fallback,
274 }
275 }
276 .instrument(tracing::info_span!(
277 "orchestration.adaptorch.recommend",
278 goal_len = goal.len(),
279 ))
280 .await
281 }
282
283 pub fn record_outcome(&self, class: TaskClass, hint: TopologyHint, reward: f64) {
288 self.metrics
289 .outcomes_recorded
290 .fetch_add(1, Ordering::Relaxed);
291 let key = (class, hint);
292 let mut arms = self.arms.lock();
293 let arm = arms.entry(key).or_default();
294 if reward >= 1.0 {
295 arm.alpha += 1.0;
296 } else {
297 arm.beta += 1.0;
298 }
299 }
300
301 pub fn save(&self) -> io::Result<()> {
310 let arms_map: HashMap<String, BetaDist> = self
311 .arms
312 .lock()
313 .iter()
314 .map(|((class, hint), dist)| (arm_key(*class, *hint), dist.clone()))
315 .collect();
316
317 let state = PersistState {
318 version: 1,
319 arms: arms_map,
320 };
321
322 let json = serde_json::to_string_pretty(&state).map_err(io::Error::other)?;
323
324 if let Some(parent) = self.state_path.parent() {
325 std::fs::create_dir_all(parent)?;
326 }
327
328 atomic_write(&self.state_path, json.as_bytes())?;
329 Ok(())
330 }
331
332 async fn classify(&self, goal: &str) -> TaskClass {
335 async move {
336 let truncated: String = goal.chars().take(400).collect();
337 let system = "\
338You classify task decomposition patterns. Read the goal and answer with one of:\n\
339- independent_batch — fan-out work with no cross-deps (research, comparisons, multi-source queries)\n\
340- sequential_pipeline — strict ordering (build → test → deploy, ETL)\n\
341- hierarchical_decomp — tree of subgoals, divide-and-conquer\n\
342- unknown — does not clearly fit any of the above\n\n\
343Respond with a single JSON object:\n\
344{\"class\":\"...\",\"reason\":\"<one sentence>\"}";
345
346 let messages = vec![
347 Message::from_legacy(Role::System, system),
348 Message::from_legacy(Role::User, format!("Goal:\n{truncated}")),
349 ];
350
351 let raw = match self.classifier.chat(&messages).await {
352 Ok(r) => r,
353 Err(e) => {
354 tracing::warn!(error = %e, "adaptorch: classify call failed");
355 return TaskClass::Unknown;
356 }
357 };
358
359 parse_class(&raw)
360 }
361 .instrument(tracing::info_span!(
362 "orchestration.adaptorch.classify",
363 goal_len = goal.len(),
364 ))
365 .await
366 }
367
368 fn sample_arm(&self, class: TaskClass) -> (TopologyHint, bool) {
369 if class == TaskClass::Unknown {
370 return (TopologyHint::Hybrid, false);
371 }
372 let arm_entries: Vec<(TopologyHint, BetaDist)> = {
374 let arms = self.arms.lock();
375 ALL_HINTS
376 .iter()
377 .map(|hint| {
378 (
379 *hint,
380 arms.get(&(class, *hint)).cloned().unwrap_or_default(),
381 )
382 })
383 .collect()
384 };
385 let mut rng = self.rng.lock();
386 let scores: Vec<(TopologyHint, f64)> = arm_entries
387 .iter()
388 .map(|(hint, dist)| (*hint, dist.sample(&mut *rng)))
389 .collect();
390
391 let (hint, score) = scores
392 .iter()
393 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
394 .map_or((TopologyHint::Hybrid, 0.0), |(h, s)| (*h, *s));
395
396 let arm = arm_entries
398 .iter()
399 .find(|(h, _)| *h == hint)
400 .map(|(_, d)| d.clone())
401 .unwrap_or_default();
402 let mean = arm.alpha / (arm.alpha + arm.beta);
403 let exploit = (score - mean).abs() < 0.15;
404
405 (hint, exploit)
406 }
407}
408
409fn parse_class(raw: &str) -> TaskClass {
411 if let Ok(val) = serde_json::from_str::<serde_json::Value>(raw)
413 && let Some(class) = val.get("class").and_then(|c| c.as_str())
414 {
415 return str_to_class(class);
416 }
417 if let Some(start) = raw.find('{')
419 && let Some(end) = raw[start..].find('}')
420 {
421 let chunk = &raw[start..=start + end];
422 if let Ok(val) = serde_json::from_str::<serde_json::Value>(chunk)
423 && let Some(class) = val.get("class").and_then(|c| c.as_str())
424 {
425 return str_to_class(class);
426 }
427 }
428 for variant in &[
430 "independent_batch",
431 "sequential_pipeline",
432 "hierarchical_decomp",
433 "unknown",
434 ] {
435 if raw.contains(variant) {
436 return str_to_class(variant);
437 }
438 }
439 TaskClass::Unknown
440}
441
442fn str_to_class(s: &str) -> TaskClass {
443 match s {
444 "independent_batch" => TaskClass::IndependentBatch,
445 "sequential_pipeline" => TaskClass::SequentialPipeline,
446 "hierarchical_decomp" => TaskClass::HierarchicalDecomp,
447 _ => TaskClass::Unknown,
448 }
449}
450
451fn load_arms(path: &std::path::Path) -> HashMap<(TaskClass, TopologyHint), BetaDist> {
452 let mut arms = default_arms();
453 let Ok(data) = std::fs::read_to_string(path) else {
454 return arms;
455 };
456 let Ok(state) = serde_json::from_str::<PersistState>(&data) else {
457 tracing::warn!(path = %path.display(), "adaptorch: failed to parse state file, using defaults");
458 return arms;
459 };
460 if state.version != 1 {
461 tracing::warn!(
462 version = state.version,
463 "adaptorch: unknown state version, using defaults"
464 );
465 return arms;
466 }
467 for (key_str, dist) in state.arms {
468 let mut parts = key_str.splitn(2, ':');
469 let (Some(c), Some(h)) = (parts.next(), parts.next()) else {
470 continue;
471 };
472 let class = str_to_class(c);
473 let hint = match h {
474 "parallel" => TopologyHint::Parallel,
475 "sequential" => TopologyHint::Sequential,
476 "hierarchical" => TopologyHint::Hierarchical,
477 "hybrid" => TopologyHint::Hybrid,
478 _ => continue,
479 };
480 arms.insert((class, hint), dist);
481 }
482 arms
483}
484
485fn default_arms() -> HashMap<(TaskClass, TopologyHint), BetaDist> {
486 let classes = [
487 TaskClass::IndependentBatch,
488 TaskClass::SequentialPipeline,
489 TaskClass::HierarchicalDecomp,
490 TaskClass::Unknown,
491 ];
492 let mut map = HashMap::new();
493 for class in classes {
494 for hint in ALL_HINTS {
495 map.insert((class, hint), BetaDist::default());
496 }
497 }
498 map
499}
500
501fn atomic_write(path: &std::path::Path, data: &[u8]) -> io::Result<()> {
502 zeph_common::fs_secure::atomic_write_private(path, data)
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::mock::MockProvider;

    /// Builds an advisor backed by `mock` whose state file lives inside `dir`.
    ///
    /// Tests must never pass an empty path. `TopologyAdvisor::new` maps an empty path to
    /// `TopologyAdvisor::default_path()` (`~/.zeph/adaptorch_state.json`). Assertions about
    /// fresh `Beta(1, 1)` priors would then depend on the developer's real learned state.
    fn advisor_in(
        dir: &tempfile::TempDir,
        mock: MockProvider,
        classify_timeout: Duration,
    ) -> TopologyAdvisor {
        TopologyAdvisor::new(
            Arc::new(AnyProvider::Mock(mock)),
            dir.path().join("state.json"),
            classify_timeout,
        )
    }

    #[test]
    fn parse_class_direct_json() {
        let json = r#"{"class":"independent_batch","reason":"fan-out"}"#;
        assert_eq!(parse_class(json), TaskClass::IndependentBatch);
    }

    #[test]
    fn parse_class_json_embedded_in_prose() {
        let raw = r#"Sure: {"class":"hierarchical_decomp","reason":"tree"} hope that helps"#;
        assert_eq!(parse_class(raw), TaskClass::HierarchicalDecomp);
    }

    #[test]
    fn parse_class_fallback_substring() {
        assert_eq!(
            parse_class(" sequential_pipeline "),
            TaskClass::SequentialPipeline
        );
    }

    #[test]
    fn parse_class_unknown_for_garbage() {
        assert_eq!(parse_class("no idea"), TaskClass::Unknown);
    }

    #[test]
    fn topology_hint_sentence_hybrid_is_none() {
        assert!(TopologyHint::Hybrid.prompt_sentence().is_none());
    }

    #[test]
    fn load_arms_missing_file_yields_uniform_priors() {
        let dir = tempfile::tempdir().unwrap();
        let arms = load_arms(&dir.path().join("missing.json"));
        assert_eq!(arms.len(), 16);
        assert!(arms.values().all(|d| *d == BetaDist::default()));
    }

    #[test]
    fn record_outcome_updates_alpha_beta() {
        let dir = tempfile::tempdir().unwrap();
        let advisor = advisor_in(&dir, MockProvider::default(), Duration::from_secs(4));
        advisor.record_outcome(TaskClass::IndependentBatch, TopologyHint::Parallel, 1.0);
        advisor.record_outcome(TaskClass::IndependentBatch, TopologyHint::Parallel, 0.0);
        let arms = advisor.arms.lock();
        let arm = arms
            .get(&(TaskClass::IndependentBatch, TopologyHint::Parallel))
            .unwrap();
        assert!((arm.alpha - 2.0).abs() < f64::EPSILON);
        assert!((arm.beta - 2.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn recommend_with_valid_json_returns_correct_class() {
        let dir = tempfile::tempdir().unwrap();
        let mock = MockProvider::with_responses(vec![
            r#"{"class":"sequential_pipeline","reason":"strict ordering"}"#.into(),
        ]);
        let advisor = advisor_in(&dir, mock, Duration::from_secs(4));
        let verdict = advisor
            .recommend("Build, test, then deploy the service")
            .await;
        assert_eq!(verdict.class, TaskClass::SequentialPipeline);
        assert!(!verdict.fallback);
        assert_eq!(advisor.metrics.classify_timeouts.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn recommend_timeout_returns_unknown_and_increments_metric() {
        let dir = tempfile::tempdir().unwrap();
        let mut mock = MockProvider::default();
        // Make the provider slower than the 50 ms classify timeout below.
        mock.delay_ms = 200;
        mock.default_response = r#"{"class":"sequential_pipeline","reason":"x"}"#.into();
        let advisor = advisor_in(&dir, mock, Duration::from_millis(50));
        let verdict = advisor.recommend("any goal").await;
        assert_eq!(verdict.class, TaskClass::Unknown);
        assert!(verdict.fallback);
        assert_eq!(verdict.hint, TopologyHint::Hybrid);
        assert_eq!(advisor.metrics.classify_timeouts.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn sample_arm_favours_reinforced_hint() {
        let dir = tempfile::tempdir().unwrap();
        let advisor = advisor_in(&dir, MockProvider::default(), Duration::from_secs(4));
        // Push the Sequential arm to Beta(21, 1); the other arms stay at Beta(1, 1).
        for _ in 0..20 {
            advisor.record_outcome(TaskClass::SequentialPipeline, TopologyHint::Sequential, 1.0);
        }
        let mut counts = std::collections::HashMap::new();
        for _ in 0..50 {
            let (hint, _) = advisor.sample_arm(TaskClass::SequentialPipeline);
            *counts.entry(hint).or_insert(0u32) += 1;
        }
        let sequential_count = counts.get(&TopologyHint::Sequential).copied().unwrap_or(0);
        assert!(
            sequential_count > 30,
            "expected Sequential to win >30/50 times after reinforcement, got {sequential_count}"
        );
    }

    #[test]
    fn persistence_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        {
            let advisor = advisor_in(&dir, MockProvider::default(), Duration::from_secs(4));
            advisor.record_outcome(TaskClass::SequentialPipeline, TopologyHint::Sequential, 1.0);
            advisor.save().unwrap();
        }
        let advisor = advisor_in(&dir, MockProvider::default(), Duration::from_secs(4));
        let arms = advisor.arms.lock();
        let arm = arms
            .get(&(TaskClass::SequentialPipeline, TopologyHint::Sequential))
            .unwrap();
        assert!((arm.alpha - 2.0).abs() < f64::EPSILON);
        assert!((arm.beta - 1.0).abs() < f64::EPSILON);
    }
}