1use std::collections::HashMap;
10use std::hash::{DefaultHasher, Hash, Hasher};
11use std::sync::LazyLock;
12
13use regex::Regex;
14
15use crate::config::UtilityScoringConfig;
16use crate::executor::ToolCall;
17
18#[must_use]
27pub fn has_explicit_tool_request(user_message: &str) -> bool {
28 static RE: LazyLock<Regex> = LazyLock::new(|| {
29 Regex::new(
30 r"(?xi)
31 using\s+a\s+tool
32 | call\s+(the\s+)?[a-z_]+\s+tool
33 | use\s+(the\s+)?[a-z_]+\s+tool
34 | run\s+(the\s+)?[a-z_]+\s+tool
35 | invoke\s+(the\s+)?[a-z_]+\s+tool
36 | execute\s+(the\s+)?[a-z_]+\s+tool
37 | show\s+me\s+the\s+result\s+of\s*:
38 | run\s*:
39 | execute\s*:
40 | what\s+(does|would|is\s+the\s+output\s+of)
41 ",
42 )
43 .expect("static regex is valid")
44 });
45 static RE_CODE: LazyLock<Regex> =
48 LazyLock::new(|| Regex::new(r"`[^`]*[|><$;&][^`]*`").expect("static regex is valid"));
49 RE.is_match(user_message) || RE_CODE.is_match(user_message)
50}
51
/// Returns the built-in gain estimate for the tool called `tool_name`.
///
/// Prefix rules take precedence over exact names: anything starting with `memory`
/// scores `0.8` and anything starting with `mcp_` scores `0.5`. After that a few
/// well-known tool names have fixed values, and every other name falls back to `0.5`.
/// Matching is case-sensitive.
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        // Guarded arms come first so the prefixes win over the exact names below.
        name if name.starts_with("memory") => 0.8,
        name if name.starts_with("mcp_") => 0.5,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        "search_code" | "grep" | "glob" => 0.65,
        _ => 0.5,
    }
}
70
/// Utility estimate for a single prospective tool call.
///
/// The per-field notes describe the values as `UtilityScorer::score` computes them;
/// the fields are public, so other code may construct arbitrary values.
#[derive(Debug, Clone)]
pub struct UtilityScore {
    /// Estimated benefit of the call; `score` takes it from the per-tool gain table.
    pub gain: f32,
    /// Consumed tokens divided by the token budget, clamped to `[0, 1]`
    /// (`0.0` when the budget is zero).
    pub cost: f32,
    /// `1.0` when an identical call has already been recorded, otherwise `0.0`.
    pub redundancy: f32,
    /// `1 - tool_calls_this_turn / 10`, clamped to `[0, 1]`.
    pub uncertainty: f32,
    /// Weighted sum: gain and uncertainty are added, cost and redundancy subtracted,
    /// each scaled by its weight from the scorer configuration.
    pub total: f32,
}
85
86impl UtilityScore {
87 fn is_valid(&self) -> bool {
89 self.gain.is_finite()
90 && self.cost.is_finite()
91 && self.redundancy.is_finite()
92 && self.uncertainty.is_finite()
93 && self.total.is_finite()
94 }
95}
96
/// Per-turn state that `UtilityScorer` reads when scoring a call and choosing an action.
#[derive(Debug, Clone)]
pub struct UtilityContext {
    /// Tool-call count for the current turn. `score` lowers the uncertainty term by
    /// `0.1` per call (reaching `0.0` at ten), and `recommend_action` only returns
    /// `Verify` when this is greater than zero.
    pub tool_calls_this_turn: usize,
    /// Tokens consumed so far; numerator of the cost ratio.
    pub tokens_consumed: usize,
    /// Token budget; denominator of the cost ratio. A value of `0` makes `score`
    /// report a cost of `0.0`.
    pub token_budget: usize,
    /// When `true`, `recommend_action` returns `ToolCall` before looking at any score.
    pub user_requested: bool,
}
111
/// Recommendation produced by `UtilityScorer::recommend_action`.
///
/// The variant notes describe when `recommend_action` returns each value.
/// NOTE(review): how callers act on each variant is not visible in this file — confirm
/// against the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UtilityAction {
    /// Returned for a call that was already recorded (redundancy of `1.0` or more), and
    /// as the fallback when no other rule applies.
    Respond,
    /// Returned when gain is at least `0.5` and uncertainty is above `0.5`, provided
    /// no earlier rule fired.
    Retrieve,
    /// Returned when the user asked for the call, when scoring is disabled, or when the
    /// total reaches the configured threshold.
    ToolCall,
    /// Returned when the total is below the threshold and at least one tool call has
    /// already been made this turn.
    Verify,
    /// Returned when scoring is enabled but no score is available, or when cost
    /// exceeds `0.9`.
    Stop,
}
126
127fn call_hash(call: &ToolCall) -> u64 {
129 let mut h = DefaultHasher::new();
130 call.tool_id.hash(&mut h);
131 format!("{:?}", call.params).hash(&mut h);
135 h.finish()
136}
137
/// Scores prospective tool calls and keeps a record of the calls passed to `record_call`.
#[derive(Debug)]
pub struct UtilityScorer {
    /// Enable flag, weights, threshold and exempt-tool list used by every method.
    config: UtilityScoringConfig,
    /// How many times each call fingerprint (as produced by `call_hash`) has been
    /// recorded since construction or the last `clear`. `score` only checks whether
    /// a fingerprint is present.
    recent_calls: HashMap<u64, u32>,
}
148
149impl UtilityScorer {
150 #[must_use]
152 pub fn new(config: UtilityScoringConfig) -> Self {
153 Self {
154 config,
155 recent_calls: HashMap::new(),
156 }
157 }
158
159 #[must_use]
161 pub fn is_enabled(&self) -> bool {
162 self.config.enabled
163 }
164
165 #[must_use]
171 pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
172 if !self.config.enabled {
173 return None;
174 }
175
176 let gain = default_gain(call.tool_id.as_str());
177
178 let cost = if ctx.token_budget > 0 {
179 #[allow(clippy::cast_precision_loss)]
180 (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
181 } else {
182 0.0
183 };
184
185 let hash = call_hash(call);
186 let redundancy = if self.recent_calls.contains_key(&hash) {
187 1.0_f32
188 } else {
189 0.0_f32
190 };
191
192 #[allow(clippy::cast_precision_loss)]
195 let uncertainty = (1.0_f32 - ctx.tool_calls_this_turn as f32 / 10.0).clamp(0.0, 1.0);
196
197 let total = self.config.gain_weight * gain
198 - self.config.cost_weight * cost
199 - self.config.redundancy_weight * redundancy
200 + self.config.uncertainty_bonus * uncertainty;
201
202 let score = UtilityScore {
203 gain,
204 cost,
205 redundancy,
206 uncertainty,
207 total,
208 };
209
210 if score.is_valid() { Some(score) } else { None }
211 }
212
213 #[must_use]
227 pub fn recommend_action(
228 &self,
229 score: Option<&UtilityScore>,
230 ctx: &UtilityContext,
231 ) -> UtilityAction {
232 if ctx.user_requested {
234 return UtilityAction::ToolCall;
235 }
236 if !self.config.enabled {
238 return UtilityAction::ToolCall;
239 }
240 let Some(s) = score else {
241 return UtilityAction::Stop;
243 };
244
245 if s.cost > 0.9 {
247 return UtilityAction::Stop;
248 }
249 if s.redundancy >= 1.0 {
251 return UtilityAction::Respond;
252 }
253 if s.gain >= 0.7 && s.total >= self.config.threshold {
255 return UtilityAction::ToolCall;
256 }
257 if s.gain >= 0.5 && s.uncertainty > 0.5 {
259 return UtilityAction::Retrieve;
260 }
261 if s.total < self.config.threshold && ctx.tool_calls_this_turn > 0 {
263 return UtilityAction::Verify;
264 }
265 if s.total >= self.config.threshold {
267 return UtilityAction::ToolCall;
268 }
269 UtilityAction::Respond
270 }
271
272 pub fn record_call(&mut self, call: &ToolCall) {
277 let hash = call_hash(call);
278 *self.recent_calls.entry(hash).or_insert(0) += 1;
279 }
280
281 pub fn clear(&mut self) {
283 self.recent_calls.clear();
284 }
285
286 #[must_use]
290 pub fn is_exempt(&self, tool_name: &str) -> bool {
291 let lower = tool_name.to_lowercase();
292 self.config
293 .exempt_tools
294 .iter()
295 .any(|e| e.to_lowercase() == lower)
296 }
297
298 #[must_use]
300 pub fn threshold(&self) -> f32 {
301 self.config.threshold
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;
    use serde_json::json;

    // ── Fixtures ────────────────────────────────────────────────────────────────

    /// Builds a tool call for `name`; anything other than a JSON object becomes an
    /// empty parameter map.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        let params = if let serde_json::Value::Object(map) = params {
            map
        } else {
            serde_json::Map::new()
        };
        ToolCall {
            tool_id: ToolName::new(name),
            params,
            caller_id: None,
            context: None,

            tool_call_id: String::new(),
        }
    }

    /// Fresh turn: nothing consumed out of a 1000-token budget, no user request.
    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    /// Default configuration with scoring switched on.
    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    fn enabled_scorer() -> UtilityScorer {
        UtilityScorer::new(default_config())
    }

    /// Scorer built from the untouched default configuration (scoring off).
    fn disabled_scorer() -> UtilityScorer {
        UtilityScorer::new(UtilityScoringConfig::default())
    }

    /// Hand-built score for exercising `recommend_action` in isolation.
    fn fixed_score(
        gain: f32,
        cost: f32,
        redundancy: f32,
        uncertainty: f32,
        total: f32,
    ) -> UtilityScore {
        UtilityScore {
            gain,
            cost,
            redundancy,
            uncertainty,
            total,
        }
    }

    #[track_caller]
    fn assert_explicit(message: &str) {
        assert!(
            has_explicit_tool_request(message),
            "expected an explicit tool request: {message:?}"
        );
    }

    #[track_caller]
    fn assert_not_explicit(message: &str) {
        assert!(
            !has_explicit_tool_request(message),
            "expected no explicit tool request: {message:?}"
        );
    }

    // ── Scoring ─────────────────────────────────────────────────────────────────

    #[test]
    fn disabled_returns_none() {
        let scorer = disabled_scorer();
        assert!(!scorer.is_enabled());

        let ctx = default_ctx();
        let score = scorer.score(&make_call("bash", json!({})), &ctx);
        assert!(score.is_none());
        // With scoring off the recommendation is to execute despite the missing score.
        assert_eq!(
            scorer.recommend_action(score.as_ref(), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        let s = scorer
            .score(&make_call("bash", json!({"cmd": "ls"})), &ctx)
            .expect("enabled scorer should produce a score");
        assert!(
            s.total >= 0.1,
            "first call should exceed threshold: {}",
            s.total
        );

        let action = scorer.recommend_action(Some(&s), &ctx);
        assert!(
            matches!(action, UtilityAction::ToolCall | UtilityAction::Retrieve),
            "first call should not be blocked, got {action:?}",
        );
    }

    #[test]
    fn redundant_call_penalized() {
        let mut scorer = enabled_scorer();
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);

        let redundancy = scorer.score(&call, &default_ctx()).unwrap().redundancy;
        assert!((redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn clear_resets_redundancy() {
        let mut scorer = enabled_scorer();
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        scorer.clear();

        let redundancy = scorer.score(&call, &default_ctx()).unwrap().redundancy;
        assert!(redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn user_requested_always_executes() {
        let scorer = enabled_scorer();
        // Worst possible score: an explicit user request must still win.
        let score = fixed_score(0.0, 1.0, 1.0, 0.0, -100.0);
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn none_score_fail_closed_when_enabled() {
        let action = enabled_scorer().recommend_action(None, &default_ctx());
        assert_eq!(action, UtilityAction::Stop);
    }

    #[test]
    fn none_score_executes_when_disabled() {
        let action = disabled_scorer().recommend_action(None, &default_ctx());
        assert_eq!(action, UtilityAction::ToolCall);
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = enabled_scorer();
        let call = make_call("bash", json!({}));
        let score_at = |tokens_consumed| {
            let ctx = UtilityContext {
                tokens_consumed,
                token_budget: 1000,
                ..default_ctx()
            };
            scorer.score(&call, &ctx).unwrap()
        };

        let s_low = score_at(100);
        let s_high = score_at(900);
        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = enabled_scorer();
        let call = make_call("bash", json!({}));
        let score_after = |tool_calls_this_turn| {
            let ctx = UtilityContext {
                tool_calls_this_turn,
                ..default_ctx()
            };
            scorer.score(&call, &ctx).unwrap()
        };

        let s_early = score_after(0);
        let s_late = score_after(9);
        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        let gain_of = |tool: &str| scorer.score(&make_call(tool, json!({})), &ctx).unwrap().gain;

        assert!(gain_of("memory_search") > gain_of("scrape"));
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let scorer = enabled_scorer();
        let ctx = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };
        let s = scorer.score(&make_call("bash", json!({})), &ctx).unwrap();
        assert!(s.cost.abs() < f32::EPSILON);
    }

    // ── Configuration validation ────────────────────────────────────────────────

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_nan_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_default() {
        let cfg = UtilityScoringConfig::default();
        assert!(cfg.validate().is_ok());
    }

    // ── Threshold extremes ──────────────────────────────────────────────────────

    #[test]
    fn threshold_zero_all_calls_pass() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        });
        let ctx = default_ctx();
        let score = scorer.score(&make_call("bash", json!({})), &ctx).unwrap();
        assert!(
            score.total >= 0.0,
            "total should be non-negative: {}",
            score.total
        );

        let action = scorer.recommend_action(Some(&score), &ctx);
        assert!(
            matches!(action, UtilityAction::ToolCall | UtilityAction::Retrieve),
            "threshold=0 should not block calls, got {action:?}",
        );
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        });
        let ctx = default_ctx();
        let score = scorer.score(&make_call("bash", json!({})), &ctx).unwrap();
        assert!(
            score.total < 1.0,
            "realistic score should be below 1.0: {}",
            score.total
        );
        assert_ne!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    // ── recommend_action, one test per rule ─────────────────────────────────────

    #[test]
    fn recommend_action_user_requested_always_tool_call() {
        let scorer = enabled_scorer();
        let score = fixed_score(0.0, 1.0, 1.0, 0.0, -100.0);
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_disabled_scorer_always_tool_call() {
        let scorer = disabled_scorer();
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_none_score_enabled_stops() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_budget_exhausted_stops() {
        let scorer = enabled_scorer();
        // Cost above 0.9 stops even though gain and total look healthy.
        let score = fixed_score(0.8, 0.95, 0.0, 0.5, 0.5);
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn recommend_action_redundant_responds() {
        let scorer = enabled_scorer();
        let score = fixed_score(0.8, 0.1, 1.0, 0.5, 0.5);
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Respond
        );
    }

    #[test]
    fn recommend_action_high_gain_above_threshold_tool_call() {
        let scorer = enabled_scorer();
        let score = fixed_score(0.8, 0.1, 0.0, 0.4, 0.6);
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_uncertain_retrieves() {
        let scorer = enabled_scorer();
        // Gain is below the 0.7 fast path, but uncertainty is high.
        let score = fixed_score(0.6, 0.1, 0.0, 0.8, 0.4);
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Retrieve
        );
    }

    #[test]
    fn recommend_action_below_threshold_with_prior_calls_verifies() {
        let scorer = enabled_scorer();
        let score = fixed_score(0.3, 0.1, 0.0, 0.2, 0.05);
        let ctx = UtilityContext {
            tool_calls_this_turn: 1,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Verify
        );
    }

    #[test]
    fn recommend_action_default_responds() {
        let scorer = enabled_scorer();
        let score = fixed_score(0.3, 0.1, 0.0, 0.2, 0.05);
        let ctx = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Respond
        );
    }

    // ── has_explicit_tool_request: phrases ──────────────────────────────────────

    #[test]
    fn explicit_request_using_a_tool() {
        assert_explicit("Please list the files in the current directory using a tool");
    }

    #[test]
    fn explicit_request_call_the_tool() {
        assert_explicit("call the list_directory tool");
    }

    #[test]
    fn explicit_request_use_the_tool() {
        assert_explicit("use the shell tool to run ls");
    }

    #[test]
    fn explicit_request_run_the_tool() {
        assert_explicit("run the bash tool");
    }

    #[test]
    fn explicit_request_invoke_the_tool() {
        assert_explicit("invoke the search_code tool");
    }

    #[test]
    fn explicit_request_execute_the_tool() {
        assert_explicit("execute the grep tool for me");
    }

    #[test]
    fn explicit_request_case_insensitive() {
        assert_explicit("USING A TOOL to find files");
    }

    #[test]
    fn explicit_request_no_match_plain_message() {
        assert_not_explicit("what is the weather today?");
    }

    #[test]
    fn explicit_request_no_match_tool_mentioned_without_invocation() {
        assert_not_explicit("the shell tool is very useful in general");
    }

    #[test]
    fn explicit_request_show_me_result_of() {
        assert_explicit("show me the result of: echo hello");
    }

    #[test]
    fn explicit_request_run_colon() {
        assert_explicit("run: echo hello");
    }

    #[test]
    fn explicit_request_execute_colon() {
        assert_explicit("execute: ls -la");
    }

    #[test]
    fn explicit_request_what_does() {
        assert_explicit("what does echo hello output?");
    }

    #[test]
    fn explicit_request_what_would() {
        assert_explicit("what would cat /etc/hosts show?");
    }

    #[test]
    fn explicit_request_what_is_the_output_of() {
        assert_explicit("what is the output of ls | grep foo?");
    }

    // ── has_explicit_tool_request: inline code spans ────────────────────────────

    #[test]
    fn explicit_request_inline_code_pipe() {
        assert_explicit("try running `ls | grep foo`");
    }

    #[test]
    fn explicit_request_inline_code_redirect() {
        assert_explicit("run `echo hello > /tmp/out`");
    }

    #[test]
    fn explicit_request_inline_code_dollar() {
        assert_explicit("check `$HOME/bin`");
    }

    #[test]
    fn explicit_request_inline_code_and() {
        assert_explicit("try `git fetch && git rebase`");
    }

    // ── has_explicit_tool_request: negatives and known quirks ───────────────────

    #[test]
    fn no_match_run_the_tests() {
        assert_not_explicit("run the tests please");
    }

    #[test]
    fn no_match_execute_the_plan() {
        assert_not_explicit("execute the plan we discussed");
    }

    #[test]
    fn no_match_inline_code_no_shell_syntax() {
        assert_not_explicit("the function `process_items` handles it");
    }

    #[test]
    fn known_fp_what_does_function_do() {
        // Known false positive: "what does" matches any such question. The assertion
        // pins the current behaviour so a change to the pattern is noticed.
        assert_explicit("what does this function do?");
    }

    #[test]
    fn no_match_show_me_result_without_colon() {
        assert_not_explicit("show me the result of running it");
    }

    // ── Exempt tools ────────────────────────────────────────────────────────────

    #[test]
    fn is_exempt_matches_case_insensitively() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            exempt_tools: vec!["Read".to_owned(), "file_read".to_owned()],
            ..UtilityScoringConfig::default()
        });

        for exempt in ["read", "READ", "FILE_READ"] {
            assert!(scorer.is_exempt(exempt), "{exempt} should be exempt");
        }
        for gated in ["write", "bash"] {
            assert!(!scorer.is_exempt(gated), "{gated} should not be exempt");
        }
    }

    #[test]
    fn is_exempt_empty_list_returns_false() {
        let scorer = disabled_scorer();
        assert!(!scorer.is_exempt("read"));
    }
}