1use std::collections::HashMap;
10use std::hash::{DefaultHasher, Hash, Hasher};
11use std::sync::LazyLock;
12
13use regex::Regex;
14
15use crate::config::UtilityScoringConfig;
16use crate::executor::ToolCall;
17
18#[must_use]
27pub fn has_explicit_tool_request(user_message: &str) -> bool {
28 static RE: LazyLock<Regex> = LazyLock::new(|| {
29 Regex::new(
30 r"(?xi)
31 using\s+a\s+tool
32 | call\s+(the\s+)?[a-z_]+\s+tool
33 | use\s+(the\s+)?[a-z_]+\s+tool
34 | run\s+(the\s+)?[a-z_]+\s+tool
35 | invoke\s+(the\s+)?[a-z_]+\s+tool
36 | execute\s+(the\s+)?[a-z_]+\s+tool
37 | show\s+me\s+the\s+result\s+of\s*:
38 | run\s*:
39 | execute\s*:
40 | what\s+(does|would|is\s+the\s+output\s+of)
41 ",
42 )
43 .expect("static regex is valid")
44 });
45 static RE_CODE: LazyLock<Regex> =
48 LazyLock::new(|| Regex::new(r"`[^`]*[|><$;&][^`]*`").expect("static regex is valid"));
49 RE.is_match(user_message) || RE_CODE.is_match(user_message)
50}
51
/// Returns the baseline gain estimate for a tool, derived from its name alone.
///
/// Prefix rules take precedence over exact names:
///
/// * names starting with `memory` → `0.8`
/// * names starting with `mcp_` → `0.5`
/// * `search_code`, `grep`, `glob` → `0.65`
/// * `bash`, `shell` → `0.6`
/// * `read`, `write` → `0.55`
/// * anything else → `0.5`
///
/// Matching is case-sensitive.
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        // Guards are evaluated top to bottom, so the prefixes win over exact names.
        name if name.starts_with("memory") => 0.8,
        name if name.starts_with("mcp_") => 0.5,
        "search_code" | "grep" | "glob" => 0.65,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        _ => 0.5,
    }
}
70
/// Component-wise utility estimate for one prospective tool call.
///
/// `UtilityScorer::score` fills every field; `UtilityScorer::recommend_action`
/// reads them to pick an action.
#[derive(Debug, Clone)]
pub struct UtilityScore {
    /// Estimated benefit of the call (`score` takes it from the tool name).
    pub gain: f32,
    /// Share of the token budget already consumed; `score` keeps it in `[0, 1]`.
    pub cost: f32,
    /// `1.0` when `score` found an identical call already recorded, else `0.0`.
    pub redundancy: f32,
    /// Starts at `1.0` and shrinks as more tool calls are made in the turn;
    /// `score` keeps it in `[0, 1]`.
    pub uncertainty: f32,
    /// Weighted combination of the four components above.
    pub total: f32,
}

impl UtilityScore {
    /// Returns `true` only when every component is finite (no NaN, no ±infinity).
    fn is_valid(&self) -> bool {
        let components = [
            self.gain,
            self.cost,
            self.redundancy,
            self.uncertainty,
            self.total,
        ];
        components.iter().all(|component| component.is_finite())
    }
}
96
/// Per-turn inputs that the scorer combines with the tool call itself.
#[derive(Debug, Clone)]
pub struct UtilityContext {
    /// Tool calls counted so far in the current turn. Each one lowers the
    /// uncertainty component by a tenth, bottoming out at zero.
    pub tool_calls_this_turn: usize,
    /// Tokens consumed so far; the numerator of the cost ratio.
    pub tokens_consumed: usize,
    /// Token budget; the denominator of the cost ratio. A budget of `0` makes
    /// the cost component `0.0`.
    pub token_budget: usize,
    /// When `true`, `UtilityScorer::recommend_action` returns
    /// `UtilityAction::ToolCall` without looking at the score.
    pub user_requested: bool,
}
111
/// Outcome of `UtilityScorer::recommend_action`.
///
/// The notes on each variant describe when `recommend_action` produces it; how a
/// caller reacts to each variant is decided outside this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UtilityAction {
    /// Produced when the call duplicates one already recorded, and as the final
    /// fallback when no other rule applies.
    Respond,
    /// Produced for a call with moderate gain while uncertainty is still high.
    Retrieve,
    /// Produced when the user asked for the call, when scoring is disabled, or
    /// when the total score reaches the configured threshold.
    ToolCall,
    /// Produced when the total is below the threshold and at least one tool call
    /// was already made this turn.
    Verify,
    /// Produced when no score is available while scoring is enabled, or when the
    /// cost component shows the budget is nearly exhausted.
    Stop,
}
126
127fn call_hash(call: &ToolCall) -> u64 {
129 let mut h = DefaultHasher::new();
130 call.tool_id.hash(&mut h);
131 format!("{:?}", call.params).hash(&mut h);
135 h.finish()
136}
137
138#[derive(Debug)]
143pub struct UtilityScorer {
144 config: UtilityScoringConfig,
145 recent_calls: HashMap<u64, u32>,
147}
148
149impl UtilityScorer {
150 #[must_use]
152 pub fn new(config: UtilityScoringConfig) -> Self {
153 Self {
154 config,
155 recent_calls: HashMap::new(),
156 }
157 }
158
159 #[must_use]
161 pub fn is_enabled(&self) -> bool {
162 self.config.enabled
163 }
164
165 #[must_use]
171 pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
172 if !self.config.enabled {
173 return None;
174 }
175
176 let gain = default_gain(call.tool_id.as_str());
177
178 let cost = if ctx.token_budget > 0 {
179 #[allow(clippy::cast_precision_loss)]
180 (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
181 } else {
182 0.0
183 };
184
185 let hash = call_hash(call);
186 let redundancy = if self.recent_calls.contains_key(&hash) {
187 1.0_f32
188 } else {
189 0.0_f32
190 };
191
192 #[allow(clippy::cast_precision_loss)]
195 let uncertainty = (1.0_f32 - ctx.tool_calls_this_turn as f32 / 10.0).clamp(0.0, 1.0);
196
197 let total = self.config.gain_weight * gain
198 - self.config.cost_weight * cost
199 - self.config.redundancy_weight * redundancy
200 + self.config.uncertainty_bonus * uncertainty;
201
202 let score = UtilityScore {
203 gain,
204 cost,
205 redundancy,
206 uncertainty,
207 total,
208 };
209
210 if score.is_valid() { Some(score) } else { None }
211 }
212
213 #[must_use]
227 pub fn recommend_action(
228 &self,
229 score: Option<&UtilityScore>,
230 ctx: &UtilityContext,
231 ) -> UtilityAction {
232 if ctx.user_requested {
234 return UtilityAction::ToolCall;
235 }
236 if !self.config.enabled {
238 return UtilityAction::ToolCall;
239 }
240 let Some(s) = score else {
241 return UtilityAction::Stop;
243 };
244
245 if s.cost > 0.9 {
247 return UtilityAction::Stop;
248 }
249 if s.redundancy >= 1.0 {
251 return UtilityAction::Respond;
252 }
253 if s.gain >= 0.7 && s.total >= self.config.threshold {
255 return UtilityAction::ToolCall;
256 }
257 if s.gain >= 0.5 && s.uncertainty > 0.5 {
259 return UtilityAction::Retrieve;
260 }
261 if s.total < self.config.threshold && ctx.tool_calls_this_turn > 0 {
263 return UtilityAction::Verify;
264 }
265 if s.total >= self.config.threshold {
267 return UtilityAction::ToolCall;
268 }
269 UtilityAction::Respond
270 }
271
272 pub fn record_call(&mut self, call: &ToolCall) {
277 let hash = call_hash(call);
278 *self.recent_calls.entry(hash).or_insert(0) += 1;
279 }
280
281 pub fn clear(&mut self) {
283 self.recent_calls.clear();
284 }
285
286 #[must_use]
290 pub fn is_exempt(&self, tool_name: &str) -> bool {
291 let lower = tool_name.to_lowercase();
292 self.config
293 .exempt_tools
294 .iter()
295 .any(|e| e.to_lowercase() == lower)
296 }
297
298 #[must_use]
300 pub fn threshold(&self) -> f32 {
301 self.config.threshold
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;
    use serde_json::json;

    /// Builds a `ToolCall` for `name`; a non-object `params` becomes an empty map.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        let params = match params {
            serde_json::Value::Object(map) => map,
            _ => serde_json::Map::new(),
        };
        ToolCall {
            tool_id: ToolName::new(name),
            params,
            caller_id: None,
            context: None,
        }
    }

    /// Context with no prior calls, nothing consumed and a budget of 1000 tokens.
    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    /// Default configuration with scoring switched on.
    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    /// Scorer built from `default_config()` (scoring enabled).
    fn enabled_scorer() -> UtilityScorer {
        UtilityScorer::new(default_config())
    }

    /// Scorer built from the untouched default configuration.
    fn default_scorer() -> UtilityScorer {
        UtilityScorer::new(UtilityScoringConfig::default())
    }

    // ── scoring ────────────────────────────────────────────────────────────

    #[test]
    fn disabled_returns_none() {
        let scorer = default_scorer();
        assert!(!scorer.is_enabled());

        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_none());

        // With scoring off, the missing score must not block the call.
        let action = scorer.recommend_action(score.as_ref(), &default_ctx());
        assert_eq!(action, UtilityAction::ToolCall);
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = enabled_scorer();
        let call = make_call("bash", json!({"cmd": "ls"}));

        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_some());
        let s = score.unwrap();
        assert!(
            s.total >= 0.1,
            "first call should exceed threshold: {}",
            s.total
        );

        let action = scorer.recommend_action(Some(&s), &default_ctx());
        assert!(
            matches!(action, UtilityAction::ToolCall | UtilityAction::Retrieve),
            "first call should not be blocked, got {action:?}",
        );
    }

    #[test]
    fn redundant_call_penalized() {
        let call = make_call("bash", json!({"cmd": "ls"}));
        let mut scorer = enabled_scorer();
        scorer.record_call(&call);

        let score = scorer.score(&call, &default_ctx()).unwrap();
        let distance_from_one = (score.redundancy - 1.0).abs();
        assert!(distance_from_one < f32::EPSILON);
    }

    #[test]
    fn clear_resets_redundancy() {
        let call = make_call("bash", json!({"cmd": "ls"}));
        let mut scorer = enabled_scorer();
        scorer.record_call(&call);
        scorer.clear();

        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(score.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn user_requested_always_executes() {
        let scorer = enabled_scorer();
        // Worst possible score: every rule would reject it without the override.
        let score = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };

        let action = scorer.recommend_action(Some(&score), &ctx);
        assert_eq!(action, UtilityAction::ToolCall);
    }

    #[test]
    fn none_score_fail_closed_when_enabled() {
        let scorer = enabled_scorer();
        let action = scorer.recommend_action(None, &default_ctx());
        assert_eq!(action, UtilityAction::Stop);
    }

    #[test]
    fn none_score_executes_when_disabled() {
        let scorer = default_scorer();
        let action = scorer.recommend_action(None, &default_ctx());
        assert_eq!(action, UtilityAction::ToolCall);
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = enabled_scorer();
        let call = make_call("bash", json!({}));
        let ctx_with_consumed = |tokens_consumed: usize| UtilityContext {
            tokens_consumed,
            token_budget: 1000,
            ..default_ctx()
        };

        let s_low = scorer.score(&call, &ctx_with_consumed(100)).unwrap();
        let s_high = scorer.score(&call, &ctx_with_consumed(900)).unwrap();

        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = enabled_scorer();
        let call = make_call("bash", json!({}));
        let ctx_after_calls = |tool_calls_this_turn: usize| UtilityContext {
            tool_calls_this_turn,
            ..default_ctx()
        };

        let s_early = scorer.score(&call, &ctx_after_calls(0)).unwrap();
        let s_late = scorer.score(&call, &ctx_after_calls(9)).unwrap();

        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();

        let s_mem = scorer
            .score(&make_call("memory_search", json!({})), &ctx)
            .unwrap();
        let s_web = scorer.score(&make_call("scrape", json!({})), &ctx).unwrap();

        assert!(s_mem.gain > s_web.gain);
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let scorer = enabled_scorer();
        let call = make_call("bash", json!({}));
        let ctx = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };

        let s = scorer.score(&call, &ctx).unwrap();
        assert!(s.cost.abs() < f32::EPSILON);
    }

    // ── configuration validation ───────────────────────────────────────────

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        let outcome = cfg.validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_rejects_nan_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        let outcome = cfg.validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_accepts_default() {
        let cfg = UtilityScoringConfig::default();
        assert!(cfg.validate().is_ok());
    }

    // ── threshold extremes ─────────────────────────────────────────────────

    #[test]
    fn threshold_zero_all_calls_pass() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));

        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total >= 0.0,
            "total should be non-negative: {}",
            score.total
        );

        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert!(
            matches!(action, UtilityAction::ToolCall | UtilityAction::Retrieve),
            "threshold=0 should not block calls, got {action:?}",
        );
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));

        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total < 1.0,
            "realistic score should be below 1.0: {}",
            score.total
        );

        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert_ne!(action, UtilityAction::ToolCall);
    }

    // ── recommend_action, one rule per test ────────────────────────────────

    #[test]
    fn recommend_action_user_requested_always_tool_call() {
        let scorer = enabled_scorer();
        let score = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };

        let action = scorer.recommend_action(Some(&score), &ctx);
        assert_eq!(action, UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_disabled_scorer_always_tool_call() {
        let scorer = default_scorer();
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_none_score_enabled_stops() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_budget_exhausted_stops() {
        let scorer = enabled_scorer();
        // cost 0.95 is above the 0.9 cut-off, which outranks the high gain.
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.95,
            redundancy: 0.0,
            uncertainty: 0.5,
            total: 0.5,
        };

        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert_eq!(action, UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_redundant_responds() {
        let scorer = enabled_scorer();
        // redundancy 1.0 outranks the otherwise attractive gain and total.
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 1.0,
            uncertainty: 0.5,
            total: 0.5,
        };

        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert_eq!(action, UtilityAction::Respond);
    }

    #[test]
    fn recommend_action_high_gain_above_threshold_tool_call() {
        let scorer = enabled_scorer();
        // Assumes the default threshold is at most 0.6.
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.4,
            total: 0.6,
        };

        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert_eq!(action, UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_uncertain_retrieves() {
        let scorer = enabled_scorer();
        // gain 0.6 misses the 0.7 high-gain rule; uncertainty 0.8 exceeds 0.5.
        let score = UtilityScore {
            gain: 0.6,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.8,
            total: 0.4,
        };

        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert_eq!(action, UtilityAction::Retrieve);
    }

    #[test]
    fn recommend_action_below_threshold_with_prior_calls_verifies() {
        let scorer = enabled_scorer();
        // Assumes the default threshold is above 0.05.
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 1,
            ..default_ctx()
        };

        let action = scorer.recommend_action(Some(&score), &ctx);
        assert_eq!(action, UtilityAction::Verify);
    }

    #[test]
    fn recommend_action_default_responds() {
        let scorer = enabled_scorer();
        // Same low score as the verify case, but with no prior calls this turn.
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };

        let action = scorer.recommend_action(Some(&score), &ctx);
        assert_eq!(action, UtilityAction::Respond);
    }

    // ── has_explicit_tool_request: phrase matches ──────────────────────────

    #[test]
    fn explicit_request_using_a_tool() {
        let message = "Please list the files in the current directory using a tool";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_call_the_tool() {
        let message = "call the list_directory tool";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_use_the_tool() {
        let message = "use the shell tool to run ls";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_run_the_tool() {
        let message = "run the bash tool";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_invoke_the_tool() {
        let message = "invoke the search_code tool";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_execute_the_tool() {
        let message = "execute the grep tool for me";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_case_insensitive() {
        let message = "USING A TOOL to find files";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_no_match_plain_message() {
        let message = "what is the weather today?";
        assert!(!has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_no_match_tool_mentioned_without_invocation() {
        let message = "the shell tool is very useful in general";
        assert!(!has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_show_me_result_of() {
        let message = "show me the result of: echo hello";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_run_colon() {
        let message = "run: echo hello";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_execute_colon() {
        let message = "execute: ls -la";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_what_does() {
        let message = "what does echo hello output?";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_what_would() {
        let message = "what would cat /etc/hosts show?";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_what_is_the_output_of() {
        let message = "what is the output of ls | grep foo?";
        assert!(has_explicit_tool_request(message));
    }

    // ── has_explicit_tool_request: inline code spans ───────────────────────

    #[test]
    fn explicit_request_inline_code_pipe() {
        let message = "try running `ls | grep foo`";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_inline_code_redirect() {
        let message = "run `echo hello > /tmp/out`";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_inline_code_dollar() {
        let message = "check `$HOME/bin`";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_inline_code_and() {
        let message = "try `git fetch && git rebase`";
        assert!(has_explicit_tool_request(message));
    }

    // ── has_explicit_tool_request: expected misses and a known false positive ─

    #[test]
    fn no_match_run_the_tests() {
        let message = "run the tests please";
        assert!(!has_explicit_tool_request(message));
    }

    #[test]
    fn no_match_execute_the_plan() {
        let message = "execute the plan we discussed";
        assert!(!has_explicit_tool_request(message));
    }

    #[test]
    fn no_match_inline_code_no_shell_syntax() {
        let message = "the function `process_items` handles it";
        assert!(!has_explicit_tool_request(message));
    }

    #[test]
    fn known_fp_what_does_function_do() {
        // Pins current behaviour: the "what does" opener matches even though
        // this question is not about running anything.
        let message = "what does this function do?";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn no_match_show_me_result_without_colon() {
        let message = "show me the result of running it";
        assert!(!has_explicit_tool_request(message));
    }

    // ── is_exempt ──────────────────────────────────────────────────────────

    #[test]
    fn is_exempt_matches_case_insensitively() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            exempt_tools: vec!["Read".to_owned(), "file_read".to_owned()],
            ..UtilityScoringConfig::default()
        });

        assert!(scorer.is_exempt("read"));
        assert!(scorer.is_exempt("READ"));
        assert!(scorer.is_exempt("FILE_READ"));
        assert!(!scorer.is_exempt("write"));
        assert!(!scorer.is_exempt("bash"));
    }

    #[test]
    fn is_exempt_empty_list_returns_false() {
        let scorer = default_scorer();
        assert!(!scorer.is_exempt("read"));
    }
}