1use std::collections::VecDeque;
37
/// Tunable parameters for [`CapacityController`].
///
/// The defaults keep the controller switched off; callers opt in by setting
/// `enabled` to `true`.
#[derive(Debug, Clone, PartialEq)]
pub struct CapacityControllerConfig {
    /// Master switch. While `false`, observations yield no snapshot and
    /// `decide` always answers with `NoIntervention`.
    pub enabled: bool,
    /// Largest failure probability (inclusive) still classified as `RiskBand::Low`.
    pub low_risk_max: f64,
    /// Largest failure probability (inclusive) still classified as
    /// `RiskBand::Medium`; anything above is `RiskBand::High`.
    pub medium_risk_max: f64,
    /// A snapshot is flagged severe when the window's minimum slack is at or
    /// below this value.
    pub severe_min_slack: f64,
    /// A snapshot is flagged severe when the window's violation ratio is at or
    /// above this value.
    pub severe_violation_ratio: f64,
    /// Number of turns after a context refresh during which another refresh is
    /// suppressed.
    pub refresh_cooldown_turns: u64,
    /// Number of turns after a replan during which another replan is suppressed.
    pub replan_cooldown_turns: u64,
    /// Upper limit on tool-replay verifications recorded within a single turn.
    pub max_replay_per_turn: usize,
    /// `decide` never intervenes while the turn index is below this value.
    pub min_turns_before_guardrail: u64,
    /// Maximum number of observations retained in the rolling windows.
    pub profile_window: usize,
    /// Capacity prior used for models resolving to the `deepseek_v3_2_chat` key.
    pub deepseek_v3_2_chat_prior: f64,
    /// Capacity prior used for models resolving to the `deepseek_v3_2_reasoner` key.
    pub deepseek_v3_2_reasoner_prior: f64,
    /// Capacity prior used for models resolving to the `deepseek_v4_pro` key.
    pub deepseek_v4_pro_prior: f64,
    /// Capacity prior used for models resolving to the `deepseek_v4_flash` key.
    pub deepseek_v4_flash_prior: f64,
    /// Capacity prior used for every model that matches none of the known keys.
    pub fallback_default_prior: f64,
}

impl Default for CapacityControllerConfig {
    /// Returns the disabled-by-default configuration.
    fn default() -> Self {
        CapacityControllerConfig {
            // Opt-in only.
            enabled: false,

            // Risk classification thresholds.
            low_risk_max: 0.50,
            medium_risk_max: 0.62,
            severe_min_slack: -0.25,
            severe_violation_ratio: 0.40,

            // Intervention pacing.
            refresh_cooldown_turns: 6,
            replan_cooldown_turns: 5,
            max_replay_per_turn: 1,
            min_turns_before_guardrail: 4,
            profile_window: 8,

            // Per-model capacity priors.
            deepseek_v3_2_chat_prior: 3.9,
            deepseek_v3_2_reasoner_prior: 4.1,
            deepseek_v4_pro_prior: 3.5,
            deepseek_v4_flash_prior: 4.2,
            fallback_default_prior: 3.8,
        }
    }
}
79
/// Intervention the controller can recommend for a turn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GuardrailAction {
    /// Take no action.
    NoIntervention,
    /// Policy choice for the medium risk band.
    TargetedContextRefresh,
    /// Policy choice for the high risk band when the snapshot is not severe.
    VerifyWithToolReplay,
    /// Policy choice for the high risk band when the snapshot is severe.
    VerifyAndReplan,
}

impl GuardrailAction {
    /// Returns the stable snake_case identifier of this action.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::NoIntervention => "no_intervention",
            Self::TargetedContextRefresh => "targeted_context_refresh",
            Self::VerifyWithToolReplay => "verify_with_tool_replay",
            Self::VerifyAndReplan => "verify_and_replan",
        }
    }
}
100
/// Coarse classification of the estimated failure probability.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiskBand {
    /// Failure probability at or below the configured low-risk ceiling.
    Low,
    /// Failure probability above the low-risk ceiling but at or below the
    /// medium-risk ceiling.
    Medium,
    /// Failure probability above the medium-risk ceiling.
    High,
}

impl RiskBand {
    /// Returns the stable lowercase identifier of this band.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Low => "low",
            Self::Medium => "medium",
            Self::High => "high",
        }
    }
}
119
/// Raw measurements handed to the controller for one observation.
#[derive(Debug, Clone)]
pub struct CapacityObservationInput {
    /// Turn the measurements belong to; copied verbatim into the snapshot.
    pub turn_index: u64,
    /// Model name used to pick the capacity prior.
    pub model: String,
    /// Action count for the current turn; contributes `log2(1 + n)` to the
    /// load estimate.
    pub action_count_this_turn: usize,
    /// Tool-call count over the recent window; contributes `log2(1 + n)` to the
    /// load estimate.
    pub tool_calls_recent_window: usize,
    /// Distinct reference-id count over the recent window; contributes
    /// `log2(1 + n)` to the load estimate.
    pub unique_reference_ids_recent_window: usize,
    /// Context usage ratio; clamped to `[0.0, 2.0]` before it is used.
    pub context_used_ratio: f64,
}
130
/// Summary statistics over the rolling window of slack values.
///
/// The `Default` value (all zeros) is what an empty window produces.
#[derive(Debug, Clone, Copy, Default)]
pub struct DynamicSlackProfile {
    /// Most recent slack value in the window.
    pub final_slack: f64,
    /// Smallest slack value in the window.
    pub min_slack: f64,
    /// Fraction of window entries whose slack is zero or negative.
    pub violation_ratio: f64,
    /// Population standard deviation of the differences between consecutive
    /// slack values.
    pub slack_volatility: f64,
    /// Decrease from the second-to-last to the last slack value, floored at zero.
    pub slack_drop: f64,
}
140
141#[derive(Debug, Clone)]
143pub struct CapacitySnapshot {
144 pub turn_index: u64,
145 pub h_hat: f64,
146 pub c_hat: f64,
147 pub slack: f64,
148 pub profile: DynamicSlackProfile,
149 pub p_fail: f64,
150 pub risk_band: RiskBand,
151 pub severe: bool,
152}
153
154#[derive(Debug, Clone)]
156pub struct CapacityDecision {
157 pub action: GuardrailAction,
158 pub reason: String,
159 pub cooldown_blocked: bool,
160}
161
/// Bookkeeping about interventions that have already been applied.
#[derive(Debug, Clone, Default)]
struct GuardrailRuntimeState {
    /// Turn of the most recent context refresh, if any.
    last_refresh_turn: Option<u64>,
    /// Turn of the most recent replan, if any.
    last_replan_turn: Option<u64>,
    /// Number of tool replays recorded since the per-turn state was last reset.
    replay_count_this_turn: usize,
    /// Turn for which replay was reported as failed and is therefore disabled.
    replay_disabled_turn: Option<u64>,
    /// Turn on which the most recent intervention of any kind was applied.
    intervention_applied_turn: Option<u64>,
}
170
171#[derive(Debug, Clone)]
173pub struct CapacityController {
174 config: CapacityControllerConfig,
175 slack_window: VecDeque<f64>,
176 recent_tool_counts: VecDeque<usize>,
177 recent_ref_counts: VecDeque<usize>,
178 state: GuardrailRuntimeState,
179 last_snapshot: Option<CapacitySnapshot>,
180}
181
182impl CapacityController {
183 #[must_use]
184 pub fn new(config: CapacityControllerConfig) -> Self {
185 Self {
186 config,
187 slack_window: VecDeque::new(),
188 recent_tool_counts: VecDeque::new(),
189 recent_ref_counts: VecDeque::new(),
190 state: GuardrailRuntimeState::default(),
191 last_snapshot: None,
192 }
193 }
194
195 pub fn observe_pre_turn(
196 &mut self,
197 input: CapacityObservationInput,
198 ) -> Option<CapacitySnapshot> {
199 self.observe(input)
200 }
201
202 pub fn observe_post_tool(
203 &mut self,
204 input: CapacityObservationInput,
205 ) -> Option<CapacitySnapshot> {
206 self.observe(input)
207 }
208
209 #[must_use]
211 pub fn decide(
212 &mut self,
213 turn_index: u64,
214 snapshot: Option<&CapacitySnapshot>,
215 ) -> CapacityDecision {
216 if !self.config.enabled {
217 return CapacityDecision {
218 action: GuardrailAction::NoIntervention,
219 reason: "capacity_controller_disabled".to_string(),
220 cooldown_blocked: false,
221 };
222 }
223
224 let Some(snapshot) = snapshot else {
225 return CapacityDecision {
226 action: GuardrailAction::NoIntervention,
227 reason: "missing_capacity_data_fail_open".to_string(),
228 cooldown_blocked: false,
229 };
230 };
231
232 if turn_index < self.config.min_turns_before_guardrail {
233 return CapacityDecision {
234 action: GuardrailAction::NoIntervention,
235 reason: "min_turns_before_guardrail_not_reached".to_string(),
236 cooldown_blocked: false,
237 };
238 }
239
240 let proposed = decide_policy(&self.config, snapshot);
241 if proposed == GuardrailAction::NoIntervention {
242 return CapacityDecision {
243 action: proposed,
244 reason: "low_risk_no_intervention".to_string(),
245 cooldown_blocked: false,
246 };
247 }
248
249 if self
250 .state
251 .intervention_applied_turn
252 .is_some_and(|t| t == turn_index)
253 {
254 return CapacityDecision {
255 action: GuardrailAction::NoIntervention,
256 reason: "intervention_already_applied_this_turn".to_string(),
257 cooldown_blocked: true,
258 };
259 }
260
261 match proposed {
262 GuardrailAction::TargetedContextRefresh => {
263 if self
264 .state
265 .last_refresh_turn
266 .is_some_and(|last| turn_index <= last + self.config.refresh_cooldown_turns)
267 {
268 return CapacityDecision {
269 action: GuardrailAction::NoIntervention,
270 reason: "refresh_cooldown_active".to_string(),
271 cooldown_blocked: true,
272 };
273 }
274 }
275 GuardrailAction::VerifyWithToolReplay => {
276 if self
277 .state
278 .replay_disabled_turn
279 .is_some_and(|t| t == turn_index)
280 {
281 return CapacityDecision {
282 action: GuardrailAction::NoIntervention,
283 reason: "replay_disabled_for_turn".to_string(),
284 cooldown_blocked: true,
285 };
286 }
287 if self.state.replay_count_this_turn >= self.config.max_replay_per_turn {
288 return CapacityDecision {
289 action: GuardrailAction::NoIntervention,
290 reason: "max_replay_per_turn_reached".to_string(),
291 cooldown_blocked: true,
292 };
293 }
294 }
295 GuardrailAction::VerifyAndReplan => {
296 if self
297 .state
298 .last_replan_turn
299 .is_some_and(|last| turn_index <= last + self.config.replan_cooldown_turns)
300 {
301 return CapacityDecision {
302 action: GuardrailAction::NoIntervention,
303 reason: "replan_cooldown_active".to_string(),
304 cooldown_blocked: true,
305 };
306 }
307 }
308 GuardrailAction::NoIntervention => {}
309 }
310
311 CapacityDecision {
312 action: proposed,
313 reason: "policy_selected_action".to_string(),
314 cooldown_blocked: false,
315 }
316 }
317
318 pub fn mark_turn_start(&mut self, turn_index: u64) {
319 let new_turn = match self.last_snapshot.as_ref() {
320 None => true,
321 Some(snapshot) => snapshot.turn_index != turn_index,
322 };
323 if new_turn {
324 self.state.replay_count_this_turn = 0;
325 self.state.replay_disabled_turn = None;
326 self.state.intervention_applied_turn = None;
327 }
328 }
329
330 pub fn mark_intervention_applied(&mut self, turn_index: u64, action: GuardrailAction) {
331 self.state.intervention_applied_turn = Some(turn_index);
332 match action {
333 GuardrailAction::TargetedContextRefresh => {
334 self.state.last_refresh_turn = Some(turn_index);
335 }
336 GuardrailAction::VerifyWithToolReplay => {
337 self.state.replay_count_this_turn =
338 self.state.replay_count_this_turn.saturating_add(1);
339 }
340 GuardrailAction::VerifyAndReplan => {
341 self.state.last_replan_turn = Some(turn_index);
342 }
343 GuardrailAction::NoIntervention => {}
344 }
345 }
346
347 pub fn mark_replay_failed(&mut self, turn_index: u64) {
348 self.state.replay_disabled_turn = Some(turn_index);
349 }
350
351 #[must_use]
352 pub fn last_snapshot(&self) -> Option<&CapacitySnapshot> {
353 self.last_snapshot.as_ref()
354 }
355
356 fn observe(&mut self, input: CapacityObservationInput) -> Option<CapacitySnapshot> {
357 if !self.config.enabled {
358 return None;
359 }
360
361 let context_used_ratio = input.context_used_ratio.clamp(0.0, 2.0);
362 let action_complexity_bits = log2_1p(input.action_count_this_turn);
363 let tool_complexity_bits = log2_1p(input.tool_calls_recent_window);
364 let ref_complexity_bits = log2_1p(input.unique_reference_ids_recent_window);
365 let context_pressure_bits = 6.0 * context_used_ratio;
366
367 let h_hat = (0.35 * action_complexity_bits)
368 + (0.30 * tool_complexity_bits)
369 + (0.20 * ref_complexity_bits)
370 + (0.15 * context_pressure_bits);
371 let c_hat = self.model_prior(&input.model);
372 let slack = c_hat - h_hat;
373
374 push_window(&mut self.slack_window, slack, self.config.profile_window);
375 push_window(
376 &mut self.recent_tool_counts,
377 input.tool_calls_recent_window,
378 self.config.profile_window,
379 );
380 push_window(
381 &mut self.recent_ref_counts,
382 input.unique_reference_ids_recent_window,
383 self.config.profile_window,
384 );
385
386 let profile = compute_profile(&self.slack_window);
387 let z = (-1.65 * profile.final_slack)
388 + (-0.85 * profile.min_slack)
389 + (1.35 * profile.violation_ratio)
390 + (0.70 * profile.slack_volatility)
391 + (0.28 * profile.slack_drop)
392 - 0.12;
393 let p_fail = sigmoid(z).clamp(0.0, 1.0);
394 let risk_band = if p_fail <= self.config.low_risk_max {
395 RiskBand::Low
396 } else if p_fail <= self.config.medium_risk_max {
397 RiskBand::Medium
398 } else {
399 RiskBand::High
400 };
401 let severe = profile.min_slack <= self.config.severe_min_slack
402 || profile.violation_ratio >= self.config.severe_violation_ratio;
403
404 let snapshot = CapacitySnapshot {
405 turn_index: input.turn_index,
406 h_hat,
407 c_hat,
408 slack,
409 profile,
410 p_fail,
411 risk_band,
412 severe,
413 };
414 self.last_snapshot = Some(snapshot.clone());
415 Some(snapshot)
416 }
417
418 fn model_prior(&self, model: &str) -> f64 {
419 let normalized = normalize_model_prior_key(model);
420 match normalized {
421 "deepseek_v4_pro" => self.config.deepseek_v4_pro_prior,
422 "deepseek_v4_flash" => self.config.deepseek_v4_flash_prior,
423 "deepseek_v3_2_reasoner" => self.config.deepseek_v3_2_reasoner_prior,
424 "deepseek_v3_2_chat" => self.config.deepseek_v3_2_chat_prior,
425 _ => self.config.fallback_default_prior,
426 }
427 }
428}
429
430#[must_use]
432pub fn decide_policy(
433 _config: &CapacityControllerConfig,
434 snapshot: &CapacitySnapshot,
435) -> GuardrailAction {
436 match snapshot.risk_band {
437 RiskBand::Low => GuardrailAction::NoIntervention,
438 RiskBand::Medium => GuardrailAction::TargetedContextRefresh,
439 RiskBand::High if snapshot.severe => GuardrailAction::VerifyAndReplan,
440 RiskBand::High => GuardrailAction::VerifyWithToolReplay,
441 }
442}
443
/// Resolves a model name to the key of the capacity prior it should use.
///
/// A leading `deepseek-ai/` is dropped and the rest is matched
/// case-insensitively (ASCII) against substring rules. The rules are tried in
/// order, so a name matching several families resolves to the earliest one.
/// Names matching no rule yield `"fallback_default"`.
fn normalize_model_prior_key(model: &str) -> &str {
    // (substrings that identify the family, resulting key), in priority order.
    const RULES: [(&[&str], &str); 4] = [
        (&["v4-pro", "v4_pro"], "deepseek_v4_pro"),
        (&["v4-flash", "v4_flash"], "deepseek_v4_flash"),
        (&["reasoner", "r1"], "deepseek_v3_2_reasoner"),
        (&["chat", "v3"], "deepseek_v3_2_chat"),
    ];

    let lowered = model
        .strip_prefix("deepseek-ai/")
        .unwrap_or(model)
        .to_ascii_lowercase();

    for &(needles, key) in RULES.iter() {
        if needles.iter().any(|&needle| lowered.contains(needle)) {
            return key;
        }
    }
    "fallback_default"
}
463
/// Returns `log2(1 + v)`, so a count of zero contributes zero bits.
fn log2_1p(v: usize) -> f64 {
    let shifted = (v as f64) + 1.0;
    shifted.log2()
}
467
/// Appends `value` to the back of `window`, then discards entries from the
/// front until at most `max_len` remain.
///
/// With `max_len == 0` the window ends up empty.
fn push_window<T>(window: &mut VecDeque<T>, value: T, max_len: usize) {
    window.push_back(value);
    let overflow = window.len().saturating_sub(max_len);
    // Dropping the drain removes the oldest `overflow` entries.
    drop(window.drain(..overflow));
}
474
475fn compute_profile(window: &VecDeque<f64>) -> DynamicSlackProfile {
476 if window.is_empty() {
477 return DynamicSlackProfile::default();
478 }
479
480 let values: Vec<f64> = window.iter().copied().collect();
481 let final_slack = *values.last().unwrap_or(&0.0);
482 let min_slack = values.iter().copied().fold(f64::INFINITY, f64::min);
483 let violations = values.iter().filter(|v| **v <= 0.0).count() as f64;
484 let violation_ratio = violations / (values.len() as f64);
485
486 let deltas: Vec<f64> = values.windows(2).map(|w| w[1] - w[0]).collect();
487 let slack_drop = if values.len() >= 2 {
488 (values[values.len() - 2] - values[values.len() - 1]).max(0.0)
489 } else {
490 0.0
491 };
492
493 let slack_volatility = if deltas.is_empty() {
494 0.0
495 } else {
496 let mean = deltas.iter().sum::<f64>() / (deltas.len() as f64);
497 let var = deltas
498 .iter()
499 .map(|delta| {
500 let centered = *delta - mean;
501 centered * centered
502 })
503 .sum::<f64>()
504 / (deltas.len() as f64);
505 var.sqrt()
506 };
507
508 DynamicSlackProfile {
509 final_slack,
510 min_slack,
511 violation_ratio,
512 slack_volatility,
513 slack_drop,
514 }
515}
516
/// Logistic function `1 / (1 + e^-z)`, evaluated without overflow.
///
/// Only `e^(-|z|)` is ever computed, so the exponential stays in `(0, 1]`
/// regardless of how large `|z|` is.
fn sigmoid(z: f64) -> f64 {
    let decay = (-z.abs()).exp();
    let denominator = 1.0 + decay;
    if z >= 0.0 {
        1.0 / denominator
    } else {
        decay / denominator
    }
}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a snapshot with the requested classification. The profile's
    /// `min_slack` and `violation_ratio` are chosen to agree with `severe`
    /// under the default thresholds.
    fn make_snapshot(p_fail: f64, severe: bool, risk_band: RiskBand) -> CapacitySnapshot {
        CapacitySnapshot {
            turn_index: 3,
            h_hat: 1.0,
            c_hat: 3.8,
            slack: 2.8,
            profile: DynamicSlackProfile {
                final_slack: 2.8,
                min_slack: if severe { -0.5 } else { 0.2 },
                violation_ratio: if severe { 0.6 } else { 0.1 },
                slack_volatility: 0.2,
                slack_drop: 0.1,
            },
            p_fail,
            risk_band,
            severe,
        }
    }

    /// Observation describing a busy turn on a nearly full context.
    fn busy_observation(turn_index: u64) -> CapacityObservationInput {
        CapacityObservationInput {
            turn_index,
            model: "deepseek-v4-pro".to_string(),
            action_count_this_turn: 10,
            tool_calls_recent_window: 10,
            unique_reference_ids_recent_window: 10,
            context_used_ratio: 0.95,
        }
    }

    #[test]
    fn low_risk_maps_to_no_intervention() {
        let cfg = CapacityControllerConfig::default();
        let snap = make_snapshot(0.2, false, RiskBand::Low);
        assert_eq!(decide_policy(&cfg, &snap), GuardrailAction::NoIntervention);
    }

    #[test]
    fn medium_risk_maps_to_refresh() {
        let cfg = CapacityControllerConfig::default();
        // 0.55 lies strictly between the default low (0.50) and medium (0.62)
        // ceilings, so the fixture is consistent with its Medium band.
        let snap = make_snapshot(0.55, false, RiskBand::Medium);
        assert_eq!(
            decide_policy(&cfg, &snap),
            GuardrailAction::TargetedContextRefresh
        );
    }

    #[test]
    fn high_non_severe_maps_to_replay() {
        let cfg = CapacityControllerConfig::default();
        let snap = make_snapshot(0.8, false, RiskBand::High);
        assert_eq!(
            decide_policy(&cfg, &snap),
            GuardrailAction::VerifyWithToolReplay
        );
    }

    #[test]
    fn high_severe_maps_to_replan() {
        let cfg = CapacityControllerConfig::default();
        let snap = make_snapshot(0.9, true, RiskBand::High);
        assert_eq!(decide_policy(&cfg, &snap), GuardrailAction::VerifyAndReplan);
    }

    #[test]
    fn default_controller_is_disabled_and_skips_observations() {
        let cfg = CapacityControllerConfig::default();
        assert!(!cfg.enabled);

        let mut controller = CapacityController::new(cfg);
        let snapshot = controller.observe_pre_turn(busy_observation(1));

        assert!(snapshot.is_none());
    }

    #[test]
    fn opt_in_controller_observes_and_decides() {
        let cfg = CapacityControllerConfig {
            enabled: true,
            ..Default::default()
        };

        let mut controller = CapacityController::new(cfg);
        let snapshot = controller.observe_pre_turn(busy_observation(1));

        assert!(snapshot.is_some());
        let snap = snapshot.unwrap();
        assert_eq!(snap.turn_index, 1);
        assert!(snap.p_fail > 0.0);

        // Turn 1 is below the default `min_turns_before_guardrail` (4), so the
        // controller must stand down whatever the snapshot's risk is.
        let decision = controller.decide(1, Some(&snap));
        assert_eq!(decision.action, GuardrailAction::NoIntervention);
        assert_eq!(decision.reason, "min_turns_before_guardrail_not_reached");
        assert!(!decision.cooldown_blocked);
    }

    #[test]
    fn normalize_v4_pro_variants() {
        for name in &[
            "deepseek-v4-pro",
            "deepseek-v4_pro",
            "deepseek-ai/deepseek-v4-pro",
            "deepseek-ai/deepseek-v4_pro",
        ] {
            assert_eq!(
                normalize_model_prior_key(name),
                "deepseek_v4_pro",
                "model name: {name}"
            );
        }
    }

    #[test]
    fn normalize_v4_flash_variants() {
        for name in &[
            "deepseek-v4-flash",
            "deepseek-v4_flash",
            "deepseek-ai/deepseek-v4-flash",
            "deepseek-ai/deepseek-v4_flash",
        ] {
            assert_eq!(
                normalize_model_prior_key(name),
                "deepseek_v4_flash",
                "model name: {name}"
            );
        }
    }

    #[test]
    fn normalize_v4_and_fallback_prior_keys() {
        assert_eq!(
            normalize_model_prior_key("deepseek-v4-pro"),
            "deepseek_v4_pro"
        );
        assert_eq!(
            normalize_model_prior_key("deepseek-v4-flash"),
            "deepseek_v4_flash"
        );
        assert_eq!(
            normalize_model_prior_key("unknown-model"),
            "fallback_default"
        );
    }

    #[test]
    fn v4_priors_loaded_into_default_config() {
        let cfg = CapacityControllerConfig::default();
        assert_eq!(cfg.deepseek_v4_pro_prior, 3.5);
        assert_eq!(cfg.deepseek_v4_flash_prior, 4.2);
    }

    #[test]
    fn cooldown_blocks_repeated_action() {
        let config = CapacityControllerConfig {
            enabled: true,
            ..CapacityControllerConfig::default()
        };
        let mut controller = CapacityController::new(config);
        let turn_index = 5;
        controller.mark_turn_start(turn_index);
        controller.mark_intervention_applied(turn_index, GuardrailAction::TargetedContextRefresh);

        // One turn later the refresh cooldown (6 turns by default) is still running.
        let snapshot = make_snapshot(0.55, false, RiskBand::Medium);
        let decision = controller.decide(turn_index + 1, Some(&snapshot));
        assert_eq!(decision.action, GuardrailAction::NoIntervention);
        assert!(decision.cooldown_blocked);
    }

    #[test]
    fn cooldown_expires_after_configured_turns() {
        let config = CapacityControllerConfig {
            enabled: true,
            ..CapacityControllerConfig::default()
        };
        let cooldown = config.refresh_cooldown_turns;
        let mut controller = CapacityController::new(config);
        let applied_turn = 5;
        controller.mark_turn_start(applied_turn);
        controller.mark_intervention_applied(applied_turn, GuardrailAction::TargetedContextRefresh);

        let snapshot = make_snapshot(0.55, false, RiskBand::Medium);

        // The last turn inside the cooldown is still blocked ...
        let blocked = controller.decide(applied_turn + cooldown, Some(&snapshot));
        assert_eq!(blocked.action, GuardrailAction::NoIntervention);
        assert_eq!(blocked.reason, "refresh_cooldown_active");
        assert!(blocked.cooldown_blocked);

        // ... and the first turn past it is allowed again.
        let allowed = controller.decide(applied_turn + cooldown + 1, Some(&snapshot));
        assert_eq!(allowed.action, GuardrailAction::TargetedContextRefresh);
        assert_eq!(allowed.reason, "policy_selected_action");
        assert!(!allowed.cooldown_blocked);
    }

    /// Manual micro-benchmark for `compute_profile`; ignored by default.
    /// Run it with `cargo test -- --ignored --nocapture`.
    #[test]
    #[ignore]
    fn bench_compute_profile() {
        use std::time::Instant;

        for &window_len in &[16usize, 64, 256, 1024] {
            let mut window: VecDeque<f64> = VecDeque::with_capacity(window_len);
            for i in 0..window_len {
                // Indices stay far below 2^53, so the cast is exact.
                #[allow(clippy::cast_precision_loss)]
                window.push_back((i as f64).sin() * 0.5);
            }

            let iters = 100_000usize;
            let start = Instant::now();
            for _ in 0..iters {
                let profile = compute_profile(&window);
                std::hint::black_box(profile);
            }
            let elapsed = start.elapsed();
            let per_call_ns = elapsed.as_nanos() as f64 / iters as f64;
            println!(
                "compute_profile window={window_len:>4} total={:?} per-call={per_call_ns:>8.0}ns",
                elapsed
            );
        }
    }
}