1use std::collections::HashMap;
14use std::io;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::time::Duration;
19
20use parking_lot::Mutex;
21use rand::SeedableRng as _;
22use rand_distr::{Beta, Distribution};
23use serde::{Deserialize, Serialize};
24use zeph_llm::any::AnyProvider;
25use zeph_llm::provider::{LlmProvider, Message, Role};
26
/// Coarse decomposition pattern of a goal, as reported by the LLM classifier.
///
/// Serialized in `snake_case` (e.g. `independent_batch`), which matches the
/// names the classifier prompt asks the model to answer with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskClass {
    /// Fan-out work with no cross-dependencies between subtasks.
    IndependentBatch,
    /// Work with a strict ordering between steps (e.g. build → test → deploy).
    SequentialPipeline,
    /// A tree of subgoals (divide-and-conquer).
    HierarchicalDecomp,
    /// No clear fit — also used when classification fails, times out, or the
    /// reply cannot be mapped to one of the other classes.
    Unknown,
}
42
/// Suggested task-graph shape for a goal.
///
/// [`TopologyHint::prompt_sentence`] yields the matching instruction text
/// (presumably appended to a planning prompt — confirm at the call site).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TopologyHint {
    /// Maximize parallel tasks; avoid unnecessary `depends_on` chains.
    Parallel,
    /// Produce a strict linear chain.
    Sequential,
    /// Decompose into nested subgoals.
    Hierarchical,
    /// No preference: emits no prompt sentence. Always chosen for
    /// `TaskClass::Unknown`.
    Hybrid,
}
58
59impl TopologyHint {
60 #[must_use]
63 pub fn prompt_sentence(self) -> Option<&'static str> {
64 match self {
65 Self::Parallel => {
66 Some("Prefer maximizing parallel tasks; avoid unnecessary `depends_on` chains.")
67 }
68 Self::Sequential => Some(
69 "This goal is naturally a pipeline; produce a strict linear chain unless \
70 impossible.",
71 ),
72 Self::Hierarchical => {
73 Some("Decompose this goal into subgoals; expect 2–3 levels of depth.")
74 }
75 Self::Hybrid => None,
76 }
77 }
78}
79
/// Result of [`TopologyAdvisor::recommend`].
#[derive(Debug, Clone)]
pub struct AdvisorVerdict {
    /// Class the goal was classified as (`Unknown` on failure or timeout).
    pub class: TaskClass,
    /// Hint chosen by Thompson sampling over the arms for `class`.
    pub hint: TopologyHint,
    /// `true` when the winning sample was within 0.15 of that arm's posterior
    /// mean — presumably a heuristic for "exploit" vs. "explore"; confirm with
    /// consumers before relying on it.
    pub exploit: bool,
    /// `true` when `class` is `Unknown`, i.e. `hint` is the `Hybrid` fallback.
    pub fallback: bool,
}
92
/// Beta posterior over one arm's success probability.
///
/// `alpha` grows by 1 per recorded success and `beta` by 1 per recorded
/// failure, starting from the `Beta(1, 1)` prior given by `Default`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BetaDist {
    /// Success pseudo-count.
    pub alpha: f64,
    /// Failure pseudo-count.
    pub beta: f64,
}
99
100impl Default for BetaDist {
101 fn default() -> Self {
102 Self {
103 alpha: 1.0,
104 beta: 1.0,
105 }
106 }
107}
108
109impl BetaDist {
110 fn sample<R: rand::Rng>(&self, rng: &mut R) -> f64 {
111 let a = self.alpha.max(1e-6);
112 let b = self.beta.max(1e-6);
113 Beta::new(a, b)
115 .expect("clamped values ≥1e-6 are always valid Beta params")
116 .sample(rng)
117 }
118}
119
/// On-disk JSON layout of the advisor state.
#[derive(Debug, Serialize, Deserialize)]
struct PersistState {
    /// Schema version; `save` writes `1` and `load_arms` accepts only `1`.
    version: u32,
    /// Arm posteriors keyed `"<class>:<hint>"` (see `arm_key`).
    arms: HashMap<String, BetaDist>,
}
126
/// Usage counters for [`TopologyAdvisor`]. Every update uses
/// `Ordering::Relaxed`, so counters are independent of one another.
#[derive(Debug, Default)]
pub struct AdaptOrchMetrics {
    /// Calls to `recommend` (counted before classification starts).
    pub classify_calls: AtomicU64,
    /// Classifications that exceeded the configured timeout.
    pub classify_timeouts: AtomicU64,
    /// Times `TopologyHint::Parallel` was recommended.
    pub hint_parallel: AtomicU64,
    /// Times `TopologyHint::Sequential` was recommended.
    pub hint_sequential: AtomicU64,
    /// Times `TopologyHint::Hierarchical` was recommended.
    pub hint_hierarchical: AtomicU64,
    /// Times `TopologyHint::Hybrid` was recommended.
    pub hint_hybrid: AtomicU64,
    /// Calls to `record_outcome`.
    pub outcomes_recorded: AtomicU64,
}
142
143fn arm_key(class: TaskClass, hint: TopologyHint) -> String {
144 let c = match class {
145 TaskClass::IndependentBatch => "independent_batch",
146 TaskClass::SequentialPipeline => "sequential_pipeline",
147 TaskClass::HierarchicalDecomp => "hierarchical_decomp",
148 TaskClass::Unknown => "unknown",
149 };
150 let h = match hint {
151 TopologyHint::Parallel => "parallel",
152 TopologyHint::Sequential => "sequential",
153 TopologyHint::Hierarchical => "hierarchical",
154 TopologyHint::Hybrid => "hybrid",
155 };
156 format!("{c}:{h}")
157}
158
/// Every topology hint, in the order arms are sampled in `sample_arm`.
/// On exactly equal samples the later entry wins, because `Iterator::max_by`
/// returns the last maximum.
const ALL_HINTS: [TopologyHint; 4] = [
    TopologyHint::Parallel,
    TopologyHint::Sequential,
    TopologyHint::Hierarchical,
    TopologyHint::Hybrid,
];
165
/// Recommends a [`TopologyHint`] for a goal.
///
/// It classifies the goal with an LLM, then runs Thompson sampling over
/// per-`(class, hint)` Beta posteriors. The posteriors are updated by
/// `record_outcome` and persisted as JSON by `save`.
pub struct TopologyAdvisor {
    /// LLM provider used to classify goals.
    classifier: Arc<AnyProvider>,
    /// Beta posterior for each `(class, hint)` arm.
    arms: Arc<Mutex<HashMap<(TaskClass, TopologyHint), BetaDist>>>,
    /// File that `new` loads from and `save` writes to.
    state_path: PathBuf,
    /// Upper bound on a single classification call.
    classify_timeout: Duration,
    /// Usage counters.
    pub metrics: Arc<AdaptOrchMetrics>,
    /// Random source for sampling; seeded from `rand::rng()` in `new`.
    rng: Arc<Mutex<rand::rngs::SmallRng>>,
}
180
181impl std::fmt::Debug for TopologyAdvisor {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("TopologyAdvisor")
184 .field("state_path", &self.state_path)
185 .field("classify_timeout", &self.classify_timeout)
186 .finish_non_exhaustive()
187 }
188}
189
190impl TopologyAdvisor {
191 #[must_use]
196 pub fn new(
197 classifier: Arc<AnyProvider>,
198 state_path: impl Into<PathBuf>,
199 classify_timeout: Duration,
200 ) -> Self {
201 let path: PathBuf = {
202 let p = state_path.into();
203 if p.as_os_str().is_empty() {
204 Self::default_path()
205 } else {
206 p
207 }
208 };
209 let arms = load_arms(&path);
210 Self {
211 classifier,
212 arms: Arc::new(Mutex::new(arms)),
213 state_path: path,
214 classify_timeout,
215 metrics: Arc::new(AdaptOrchMetrics::default()),
216 rng: Arc::new(Mutex::new(rand::rngs::SmallRng::from_rng(&mut rand::rng()))),
217 }
218 }
219
220 #[must_use]
222 pub fn default_path() -> PathBuf {
223 dirs::home_dir()
224 .unwrap_or_else(|| PathBuf::from("."))
225 .join(".zeph")
226 .join("adaptorch_state.json")
227 }
228
229 pub async fn recommend(&self, goal: &str) -> AdvisorVerdict {
233 self.metrics.classify_calls.fetch_add(1, Ordering::Relaxed);
234
235 let class = tokio::time::timeout(self.classify_timeout, self.classify(goal))
236 .await
237 .unwrap_or_else(|_| {
238 self.metrics
239 .classify_timeouts
240 .fetch_add(1, Ordering::Relaxed);
241 TaskClass::Unknown
242 });
243
244 let fallback = class == TaskClass::Unknown;
245 let (hint, exploit) = self.sample_arm(class);
246
247 match hint {
248 TopologyHint::Parallel => {
249 self.metrics.hint_parallel.fetch_add(1, Ordering::Relaxed);
250 }
251 TopologyHint::Sequential => {
252 self.metrics.hint_sequential.fetch_add(1, Ordering::Relaxed);
253 }
254 TopologyHint::Hierarchical => {
255 self.metrics
256 .hint_hierarchical
257 .fetch_add(1, Ordering::Relaxed);
258 }
259 TopologyHint::Hybrid => {
260 self.metrics.hint_hybrid.fetch_add(1, Ordering::Relaxed);
261 }
262 }
263
264 AdvisorVerdict {
265 class,
266 hint,
267 exploit,
268 fallback,
269 }
270 }
271
272 pub fn record_outcome(&self, class: TaskClass, hint: TopologyHint, reward: f64) {
277 self.metrics
278 .outcomes_recorded
279 .fetch_add(1, Ordering::Relaxed);
280 let key = (class, hint);
281 let mut arms = self.arms.lock();
282 let arm = arms.entry(key).or_default();
283 if reward >= 1.0 {
284 arm.alpha += 1.0;
285 } else {
286 arm.beta += 1.0;
287 }
288 }
289
290 pub fn save(&self) -> io::Result<()> {
299 let arms_map: HashMap<String, BetaDist> = self
300 .arms
301 .lock()
302 .iter()
303 .map(|((class, hint), dist)| (arm_key(*class, *hint), dist.clone()))
304 .collect();
305
306 let state = PersistState {
307 version: 1,
308 arms: arms_map,
309 };
310
311 let json = serde_json::to_string_pretty(&state).map_err(io::Error::other)?;
312
313 if let Some(parent) = self.state_path.parent() {
314 std::fs::create_dir_all(parent)?;
315 }
316
317 atomic_write(&self.state_path, json.as_bytes())?;
318 Ok(())
319 }
320
321 async fn classify(&self, goal: &str) -> TaskClass {
324 let truncated: String = goal.chars().take(400).collect();
325 let system = "\
326You classify task decomposition patterns. Read the goal and answer with one of:\n\
327- independent_batch — fan-out work with no cross-deps (research, comparisons, multi-source queries)\n\
328- sequential_pipeline — strict ordering (build → test → deploy, ETL)\n\
329- hierarchical_decomp — tree of subgoals, divide-and-conquer\n\
330- unknown — does not clearly fit any of the above\n\n\
331Respond with a single JSON object:\n\
332{\"class\":\"...\",\"reason\":\"<one sentence>\"}";
333
334 let messages = vec![
335 Message::from_legacy(Role::System, system),
336 Message::from_legacy(Role::User, format!("Goal:\n{truncated}")),
337 ];
338
339 let raw = match self.classifier.chat(&messages).await {
340 Ok(r) => r,
341 Err(e) => {
342 tracing::warn!(error = %e, "adaptorch: classify call failed");
343 return TaskClass::Unknown;
344 }
345 };
346
347 parse_class(&raw)
348 }
349
350 fn sample_arm(&self, class: TaskClass) -> (TopologyHint, bool) {
351 if class == TaskClass::Unknown {
352 return (TopologyHint::Hybrid, false);
353 }
354 let arm_entries: Vec<(TopologyHint, BetaDist)> = {
356 let arms = self.arms.lock();
357 ALL_HINTS
358 .iter()
359 .map(|hint| {
360 (
361 *hint,
362 arms.get(&(class, *hint)).cloned().unwrap_or_default(),
363 )
364 })
365 .collect()
366 };
367 let mut rng = self.rng.lock();
368 let scores: Vec<(TopologyHint, f64)> = arm_entries
369 .iter()
370 .map(|(hint, dist)| (*hint, dist.sample(&mut *rng)))
371 .collect();
372
373 let (hint, score) = scores
374 .iter()
375 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
376 .map_or((TopologyHint::Hybrid, 0.0), |(h, s)| (*h, *s));
377
378 let arm = arm_entries
380 .iter()
381 .find(|(h, _)| *h == hint)
382 .map(|(_, d)| d.clone())
383 .unwrap_or_default();
384 let mean = arm.alpha / (arm.alpha + arm.beta);
385 let exploit = (score - mean).abs() < 0.15;
386
387 (hint, exploit)
388 }
389}
390
391fn parse_class(raw: &str) -> TaskClass {
393 if let Ok(val) = serde_json::from_str::<serde_json::Value>(raw)
395 && let Some(class) = val.get("class").and_then(|c| c.as_str())
396 {
397 return str_to_class(class);
398 }
399 if let Some(start) = raw.find('{')
401 && let Some(end) = raw[start..].find('}')
402 {
403 let chunk = &raw[start..=start + end];
404 if let Ok(val) = serde_json::from_str::<serde_json::Value>(chunk)
405 && let Some(class) = val.get("class").and_then(|c| c.as_str())
406 {
407 return str_to_class(class);
408 }
409 }
410 for variant in &[
412 "independent_batch",
413 "sequential_pipeline",
414 "hierarchical_decomp",
415 "unknown",
416 ] {
417 if raw.contains(variant) {
418 return str_to_class(variant);
419 }
420 }
421 TaskClass::Unknown
422}
423
424fn str_to_class(s: &str) -> TaskClass {
425 match s {
426 "independent_batch" => TaskClass::IndependentBatch,
427 "sequential_pipeline" => TaskClass::SequentialPipeline,
428 "hierarchical_decomp" => TaskClass::HierarchicalDecomp,
429 _ => TaskClass::Unknown,
430 }
431}
432
433fn load_arms(path: &std::path::Path) -> HashMap<(TaskClass, TopologyHint), BetaDist> {
434 let mut arms = default_arms();
435 let Ok(data) = std::fs::read_to_string(path) else {
436 return arms;
437 };
438 let Ok(state) = serde_json::from_str::<PersistState>(&data) else {
439 tracing::warn!(path = %path.display(), "adaptorch: failed to parse state file, using defaults");
440 return arms;
441 };
442 if state.version != 1 {
443 tracing::warn!(
444 version = state.version,
445 "adaptorch: unknown state version, using defaults"
446 );
447 return arms;
448 }
449 for (key_str, dist) in state.arms {
450 let mut parts = key_str.splitn(2, ':');
451 let (Some(c), Some(h)) = (parts.next(), parts.next()) else {
452 continue;
453 };
454 let class = str_to_class(c);
455 let hint = match h {
456 "parallel" => TopologyHint::Parallel,
457 "sequential" => TopologyHint::Sequential,
458 "hierarchical" => TopologyHint::Hierarchical,
459 "hybrid" => TopologyHint::Hybrid,
460 _ => continue,
461 };
462 arms.insert((class, hint), dist);
463 }
464 arms
465}
466
467fn default_arms() -> HashMap<(TaskClass, TopologyHint), BetaDist> {
468 let classes = [
469 TaskClass::IndependentBatch,
470 TaskClass::SequentialPipeline,
471 TaskClass::HierarchicalDecomp,
472 TaskClass::Unknown,
473 ];
474 let mut map = HashMap::new();
475 for class in classes {
476 for hint in ALL_HINTS {
477 map.insert((class, hint), BetaDist::default());
478 }
479 }
480 map
481}
482
483fn atomic_write(path: &std::path::Path, data: &[u8]) -> io::Result<()> {
484 zeph_common::fs_secure::atomic_write_private(path, data)
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::mock::MockProvider;

    /// Builds an advisor whose state file lives in a fresh temporary directory.
    ///
    /// An empty path would make `TopologyAdvisor::new` load
    /// `~/.zeph/adaptorch_state.json`. Assertions on arm values would then
    /// depend on the developer's real learned state. Keep the returned
    /// `TempDir` alive for as long as the advisor is used.
    fn isolated_advisor(
        mock: MockProvider,
        classify_timeout: Duration,
    ) -> (tempfile::TempDir, TopologyAdvisor) {
        let dir = tempfile::tempdir().unwrap();
        let advisor = TopologyAdvisor::new(
            Arc::new(AnyProvider::Mock(mock)),
            dir.path().join("state.json"),
            classify_timeout,
        );
        (dir, advisor)
    }

    #[test]
    fn parse_class_direct_json() {
        let json = r#"{"class":"independent_batch","reason":"fan-out"}"#;
        assert_eq!(parse_class(json), TaskClass::IndependentBatch);
    }

    #[test]
    fn parse_class_fallback_substring() {
        assert_eq!(
            parse_class(" sequential_pipeline "),
            TaskClass::SequentialPipeline
        );
    }

    #[test]
    fn parse_class_unknown_for_garbage() {
        assert_eq!(parse_class("no idea"), TaskClass::Unknown);
    }

    #[test]
    fn topology_hint_sentence_hybrid_is_none() {
        assert!(TopologyHint::Hybrid.prompt_sentence().is_none());
    }

    #[test]
    fn record_outcome_updates_alpha_beta() {
        let (_dir, advisor) = isolated_advisor(MockProvider::default(), Duration::from_secs(4));
        advisor.record_outcome(TaskClass::IndependentBatch, TopologyHint::Parallel, 1.0);
        advisor.record_outcome(TaskClass::IndependentBatch, TopologyHint::Parallel, 0.0);
        let arms = advisor.arms.lock();
        let arm = arms
            .get(&(TaskClass::IndependentBatch, TopologyHint::Parallel))
            .unwrap();
        // Uniform prior (1, 1) plus one success and one failure.
        assert!((arm.alpha - 2.0).abs() < f64::EPSILON);
        assert!((arm.beta - 2.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn recommend_with_valid_json_returns_correct_class() {
        let mock = MockProvider::with_responses(vec![
            r#"{"class":"sequential_pipeline","reason":"strict ordering"}"#.into(),
        ]);
        let (_dir, advisor) = isolated_advisor(mock, Duration::from_secs(4));
        let verdict = advisor
            .recommend("Build, test, then deploy the service")
            .await;
        assert_eq!(verdict.class, TaskClass::SequentialPipeline);
        assert!(!verdict.fallback);
        assert_eq!(advisor.metrics.classify_timeouts.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn recommend_timeout_returns_unknown_and_increments_metric() {
        // The mock's configured delay (200 ms) exceeds the 50 ms classifier budget.
        let mut mock = MockProvider::default();
        mock.delay_ms = 200;
        mock.default_response = r#"{"class":"sequential_pipeline","reason":"x"}"#.into();
        let (_dir, advisor) = isolated_advisor(mock, Duration::from_millis(50));
        let verdict = advisor.recommend("any goal").await;
        assert_eq!(verdict.class, TaskClass::Unknown);
        assert_eq!(verdict.hint, TopologyHint::Hybrid);
        assert!(verdict.fallback);
        assert_eq!(advisor.metrics.classify_timeouts.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn sample_arm_favours_reinforced_hint() {
        let (_dir, advisor) = isolated_advisor(MockProvider::default(), Duration::from_secs(4));
        // Beta(21, 1) for Sequential vs. Beta(1, 1) for the rest: Sequential
        // wins each draw with probability 1 - 21/24 ≈ 0.875 → ~44/50 expected.
        for _ in 0..20 {
            advisor.record_outcome(TaskClass::SequentialPipeline, TopologyHint::Sequential, 1.0);
        }
        let mut counts = HashMap::new();
        for _ in 0..50 {
            let (hint, _) = advisor.sample_arm(TaskClass::SequentialPipeline);
            *counts.entry(hint).or_insert(0u32) += 1;
        }
        let sequential_count = counts.get(&TopologyHint::Sequential).copied().unwrap_or(0);
        assert!(
            sequential_count > 30,
            "expected Sequential to win >30/50 times after reinforcement, got {sequential_count}"
        );
    }

    #[test]
    fn persistence_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("state.json");
        {
            let advisor = TopologyAdvisor::new(
                Arc::new(AnyProvider::Mock(MockProvider::default())),
                path.clone(),
                Duration::from_secs(4),
            );
            advisor.record_outcome(TaskClass::SequentialPipeline, TopologyHint::Sequential, 1.0);
            advisor.save().unwrap();
        }
        {
            let advisor = TopologyAdvisor::new(
                Arc::new(AnyProvider::Mock(MockProvider::default())),
                path.clone(),
                Duration::from_secs(4),
            );
            let arms = advisor.arms.lock();
            let arm = arms
                .get(&(TaskClass::SequentialPipeline, TopologyHint::Sequential))
                .unwrap();
            assert!((arm.alpha - 2.0).abs() < f64::EPSILON);
        }
    }
}