1use std::collections::HashMap;
10use std::hash::{DefaultHasher, Hash, Hasher};
11use std::sync::LazyLock;
12
13use regex::Regex;
14
15use crate::config::UtilityScoringConfig;
16use crate::executor::ToolCall;
17
18#[must_use]
27pub fn has_explicit_tool_request(user_message: &str) -> bool {
28 static RE: LazyLock<Regex> = LazyLock::new(|| {
29 Regex::new(
30 r"(?xi)
31 using\s+a\s+tool
32 | call\s+(the\s+)?[a-z_]+\s+tool
33 | use\s+(the\s+)?[a-z_]+\s+tool
34 | run\s+(the\s+)?[a-z_]+\s+tool
35 | invoke\s+(the\s+)?[a-z_]+\s+tool
36 | execute\s+(the\s+)?[a-z_]+\s+tool
37 | show\s+me\s+the\s+result\s+of\s*:
38 | run\s*:
39 | execute\s*:
40 | what\s+(does|would|is\s+the\s+output\s+of)
41 ",
42 )
43 .expect("static regex is valid")
44 });
45 static RE_CODE: LazyLock<Regex> =
48 LazyLock::new(|| Regex::new(r"`[^`]*[|><$;&][^`]*`").expect("static regex is valid"));
49 RE.is_match(user_message) || RE_CODE.is_match(user_message)
50}
51
/// Prior estimate of how useful a call to `tool_name` is likely to be.
///
/// Prefix rules take precedence:
///
/// * any `memory*` tool scores `0.8`;
/// * any `mcp_*` tool scores `0.5`.
///
/// After that, a handful of built-in tool names have fixed priors, and every other
/// tool falls back to `0.5`.
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        name if name.starts_with("memory") => 0.8,
        name if name.starts_with("mcp_") => 0.5,
        "search_code" | "grep" | "glob" => 0.65,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        _ => 0.5,
    }
}
70
/// Component-wise utility estimate for one candidate tool call.
///
/// [`UtilityScorer::score`] fills these fields; tests may also build them by hand.
#[derive(Debug, Clone)]
pub struct UtilityScore {
    /// Expected benefit of the call; the scorer derives it from the tool name.
    pub gain: f32,
    /// Fraction of the token budget already consumed; the scorer clamps it to `[0, 1]`.
    pub cost: f32,
    /// `1.0` when the scorer has already recorded an identical call, `0.0` otherwise.
    pub redundancy: f32,
    /// Exploration term; the scorer lowers it as more calls happen in the turn.
    pub uncertainty: f32,
    /// Weighted combination of the components above.
    pub total: f32,
}

impl UtilityScore {
    /// Returns `true` only when every component, `total` included, is finite
    /// (neither NaN nor ±infinity).
    fn is_valid(&self) -> bool {
        [
            self.gain,
            self.cost,
            self.redundancy,
            self.uncertainty,
            self.total,
        ]
        .iter()
        .all(|component| component.is_finite())
    }
}
96
/// Per-turn state consulted by [`UtilityScorer`] when scoring and recommending.
#[derive(Debug, Clone)]
pub struct UtilityContext {
    /// Tool calls already made in the current turn.
    /// Feeds the uncertainty term and the "verify" recommendation.
    pub tool_calls_this_turn: usize,
    /// Tokens spent so far; divided by `token_budget` to obtain the cost term.
    pub tokens_consumed: usize,
    /// Total token budget. A value of `0` makes the cost term zero.
    pub token_budget: usize,
    /// Set when the user explicitly asked for a tool.
    /// The scorer then always recommends executing the call.
    pub user_requested: bool,
}
111
/// Next step suggested by [`UtilityScorer::recommend_action`].
///
/// Each variant's note records when the scorer returns it. What the agent loop does
/// with each variant is decided by the caller (not visible here).
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UtilityAction {
    /// Returned for redundant calls, and as the fallback when no other rule applies.
    Respond,
    /// Returned when gain is at least 0.5 and uncertainty is above 0.5.
    Retrieve,
    /// Returned when the call should go ahead. This covers three cases:
    /// - the user explicitly requested it;
    /// - scoring is disabled;
    /// - the score clears the threshold.
    ToolCall,
    /// Returned when the total is below threshold and calls were already made this turn.
    Verify,
    /// Returned when the budget is nearly exhausted or no valid score is available.
    Stop,
}
127
128fn call_hash(call: &ToolCall) -> u64 {
130 let mut h = DefaultHasher::new();
131 call.tool_id.hash(&mut h);
132 format!("{:?}", call.params).hash(&mut h);
136 h.finish()
137}
138
139#[derive(Debug)]
144pub struct UtilityScorer {
145 config: UtilityScoringConfig,
146 recent_calls: HashMap<u64, u32>,
148}
149
150impl UtilityScorer {
151 #[must_use]
153 pub fn new(config: UtilityScoringConfig) -> Self {
154 Self {
155 config,
156 recent_calls: HashMap::new(),
157 }
158 }
159
160 #[must_use]
162 pub fn is_enabled(&self) -> bool {
163 self.config.enabled
164 }
165
166 #[must_use]
172 pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
173 if !self.config.enabled {
174 return None;
175 }
176
177 let gain = default_gain(call.tool_id.as_str());
178
179 let cost = if ctx.token_budget > 0 {
180 #[allow(clippy::cast_precision_loss)]
181 (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
182 } else {
183 0.0
184 };
185
186 let hash = call_hash(call);
187 let redundancy = if self.recent_calls.contains_key(&hash) {
188 1.0_f32
189 } else {
190 0.0_f32
191 };
192
193 #[allow(clippy::cast_precision_loss)]
196 let uncertainty = (1.0_f32 - ctx.tool_calls_this_turn as f32 / 10.0).clamp(0.0, 1.0);
197
198 let total = self.config.gain_weight * gain
199 - self.config.cost_weight * cost
200 - self.config.redundancy_weight * redundancy
201 + self.config.uncertainty_bonus * uncertainty;
202
203 let score = UtilityScore {
204 gain,
205 cost,
206 redundancy,
207 uncertainty,
208 total,
209 };
210
211 if score.is_valid() { Some(score) } else { None }
212 }
213
214 #[must_use]
228 pub fn recommend_action(
229 &self,
230 score: Option<&UtilityScore>,
231 ctx: &UtilityContext,
232 ) -> UtilityAction {
233 if ctx.user_requested {
235 return UtilityAction::ToolCall;
236 }
237 if !self.config.enabled {
239 return UtilityAction::ToolCall;
240 }
241 let Some(s) = score else {
242 return UtilityAction::Stop;
244 };
245
246 if s.cost > 0.9 {
248 return UtilityAction::Stop;
249 }
250 if s.redundancy >= 1.0 {
252 return UtilityAction::Respond;
253 }
254 if s.gain >= 0.7 && s.total >= self.config.threshold {
256 return UtilityAction::ToolCall;
257 }
258 if s.gain >= 0.5 && s.uncertainty > 0.5 {
260 return UtilityAction::Retrieve;
261 }
262 if s.total < self.config.threshold && ctx.tool_calls_this_turn > 0 {
264 return UtilityAction::Verify;
265 }
266 if s.total >= self.config.threshold {
268 return UtilityAction::ToolCall;
269 }
270 UtilityAction::Respond
271 }
272
273 pub fn record_call(&mut self, call: &ToolCall) {
278 let hash = call_hash(call);
279 *self.recent_calls.entry(hash).or_insert(0) += 1;
280 }
281
282 pub fn clear(&mut self) {
284 self.recent_calls.clear();
285 }
286
287 #[must_use]
291 pub fn is_exempt(&self, tool_name: &str) -> bool {
292 let lower = tool_name.to_lowercase();
293 self.config
294 .exempt_tools
295 .iter()
296 .any(|e| e.to_lowercase() == lower)
297 }
298
299 #[must_use]
301 pub fn threshold(&self) -> f32 {
302 self.config.threshold
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;
    use serde_json::json;

    /// Builds a `ToolCall` for `name`. A non-object `params` value is replaced by an
    /// empty parameter map.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        let params = match params {
            serde_json::Value::Object(map) => map,
            _ => serde_json::Map::new(),
        };
        ToolCall {
            tool_id: ToolName::new(name),
            params,
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
            skill_name: None,
        }
    }

    /// A fresh turn: no calls made yet and nothing consumed out of a 1000-token budget.
    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    /// The default configuration with scoring switched on.
    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    // ---------------------------------------------------------------- scoring ----

    #[test]
    fn disabled_returns_none() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert!(!scorer.is_enabled());
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_none());
        // A disabled scorer must never block execution.
        assert_eq!(
            scorer.recommend_action(score.as_ref(), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        let s = scorer
            .score(&call, &default_ctx())
            .expect("an enabled scorer should score a first call");
        assert!(
            s.total >= 0.1,
            "first call should exceed threshold: {}",
            s.total
        );
        let action = scorer.recommend_action(Some(&s), &default_ctx());
        // Either executing or retrieving counts as "not blocked" for a first call.
        assert!(
            action == UtilityAction::ToolCall || action == UtilityAction::Retrieve,
            "first call should not be blocked, got {action:?}",
        );
    }

    #[test]
    fn redundant_call_penalized() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!((score.redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn clear_resets_redundancy() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        scorer.clear();
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(score.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn user_requested_always_executes() {
        let scorer = UtilityScorer::new(default_config());
        // Worst possible score on every axis: an explicit request must still win.
        let score = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn none_score_fail_closed_when_enabled() {
        let scorer = UtilityScorer::new(default_config());
        // An enabled scorer without a valid score fails closed.
        assert_eq!(
            scorer.recommend_action(None, &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn none_score_executes_when_disabled() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert_eq!(
            scorer.recommend_action(None, &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx_low = UtilityContext {
            tokens_consumed: 100,
            token_budget: 1000,
            ..default_ctx()
        };
        let ctx_high = UtilityContext {
            tokens_consumed: 900,
            token_budget: 1000,
            ..default_ctx()
        };
        let s_low = scorer.score(&call, &ctx_low).unwrap();
        let s_high = scorer.score(&call, &ctx_high).unwrap();
        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx_early = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        let ctx_late = UtilityContext {
            tool_calls_this_turn: 9,
            ..default_ctx()
        };
        let s_early = scorer.score(&call, &ctx_early).unwrap();
        let s_late = scorer.score(&call, &ctx_late).unwrap();
        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = UtilityScorer::new(default_config());
        let mem_call = make_call("memory_search", json!({}));
        let web_call = make_call("scrape", json!({}));
        let s_mem = scorer.score(&mem_call, &default_ctx()).unwrap();
        let s_web = scorer.score(&web_call, &default_ctx()).unwrap();
        assert!(s_mem.gain > s_web.gain);
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };
        let s = scorer.score(&call, &ctx).unwrap();
        assert!(s.cost.abs() < f32::EPSILON);
    }

    // ------------------------------------------------------ config validation ----

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_nan_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_default() {
        assert!(UtilityScoringConfig::default().validate().is_ok());
    }

    // ------------------------------------------------------------- thresholds ----

    #[test]
    fn threshold_zero_all_calls_pass() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total >= 0.0,
            "total should be non-negative: {}",
            score.total
        );
        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert!(
            action == UtilityAction::ToolCall || action == UtilityAction::Retrieve,
            "threshold=0 should not block calls, got {action:?}",
        );
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total < 1.0,
            "realistic score should be below 1.0: {}",
            score.total
        );
        assert_ne!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    // ------------------------------------------------------- recommend_action ----

    #[test]
    fn recommend_action_user_requested_always_tool_call() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_disabled_scorer_always_tool_call() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_none_score_enabled_stops() {
        let scorer = UtilityScorer::new(default_config());
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_budget_exhausted_stops() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.95,
            redundancy: 0.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn recommend_action_redundant_responds() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 1.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Respond
        );
    }

    #[test]
    fn recommend_action_high_gain_above_threshold_tool_call() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.4,
            total: 0.6,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_uncertain_retrieves() {
        let scorer = UtilityScorer::new(default_config());
        // Gain below the 0.7 high-gain cut-off, but uncertainty above 0.5.
        let score = UtilityScore {
            gain: 0.6,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.8,
            total: 0.4,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Retrieve
        );
    }

    #[test]
    fn recommend_action_below_threshold_with_prior_calls_verifies() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 1,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Verify
        );
    }

    #[test]
    fn recommend_action_default_responds() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Respond
        );
    }

    // ------------------------------------------------ has_explicit_tool_request ----

    #[test]
    fn explicit_request_using_a_tool() {
        assert!(has_explicit_tool_request(
            "Please list the files in the current directory using a tool"
        ));
    }

    #[test]
    fn explicit_request_call_the_tool() {
        assert!(has_explicit_tool_request("call the list_directory tool"));
    }

    #[test]
    fn explicit_request_use_the_tool() {
        assert!(has_explicit_tool_request("use the shell tool to run ls"));
    }

    #[test]
    fn explicit_request_run_the_tool() {
        assert!(has_explicit_tool_request("run the bash tool"));
    }

    #[test]
    fn explicit_request_invoke_the_tool() {
        assert!(has_explicit_tool_request("invoke the search_code tool"));
    }

    #[test]
    fn explicit_request_execute_the_tool() {
        assert!(has_explicit_tool_request("execute the grep tool for me"));
    }

    #[test]
    fn explicit_request_case_insensitive() {
        assert!(has_explicit_tool_request("USING A TOOL to find files"));
    }

    #[test]
    fn explicit_request_no_match_plain_message() {
        assert!(!has_explicit_tool_request("what is the weather today?"));
    }

    #[test]
    fn explicit_request_no_match_tool_mentioned_without_invocation() {
        assert!(!has_explicit_tool_request(
            "the shell tool is very useful in general"
        ));
    }

    #[test]
    fn explicit_request_show_me_result_of() {
        assert!(has_explicit_tool_request(
            "show me the result of: echo hello"
        ));
    }

    #[test]
    fn explicit_request_run_colon() {
        assert!(has_explicit_tool_request("run: echo hello"));
    }

    #[test]
    fn explicit_request_execute_colon() {
        assert!(has_explicit_tool_request("execute: ls -la"));
    }

    #[test]
    fn explicit_request_what_does() {
        assert!(has_explicit_tool_request("what does echo hello output?"));
    }

    #[test]
    fn explicit_request_what_would() {
        assert!(has_explicit_tool_request("what would cat /etc/hosts show?"));
    }

    #[test]
    fn explicit_request_what_is_the_output_of() {
        assert!(has_explicit_tool_request(
            "what is the output of ls | grep foo?"
        ));
    }

    #[test]
    fn explicit_request_inline_code_pipe() {
        assert!(has_explicit_tool_request("try running `ls | grep foo`"));
    }

    #[test]
    fn explicit_request_inline_code_redirect() {
        assert!(has_explicit_tool_request("run `echo hello > /tmp/out`"));
    }

    #[test]
    fn explicit_request_inline_code_dollar() {
        assert!(has_explicit_tool_request("check `$HOME/bin`"));
    }

    #[test]
    fn explicit_request_inline_code_and() {
        assert!(has_explicit_tool_request("try `git fetch && git rebase`"));
    }

    #[test]
    fn no_match_run_the_tests() {
        assert!(!has_explicit_tool_request("run the tests please"));
    }

    #[test]
    fn no_match_execute_the_plan() {
        assert!(!has_explicit_tool_request("execute the plan we discussed"));
    }

    #[test]
    fn no_match_inline_code_no_shell_syntax() {
        assert!(!has_explicit_tool_request(
            "the function `process_items` handles it"
        ));
    }

    #[test]
    fn known_fp_what_does_function_do() {
        // Known false positive: the "what does" pattern is intentionally broad, so a
        // plain code question is treated as a tool request. This pins that behaviour.
        assert!(has_explicit_tool_request("what does this function do?"));
    }

    #[test]
    fn no_match_show_me_result_without_colon() {
        // Without the trailing colon the phrase is ordinary prose.
        assert!(!has_explicit_tool_request(
            "show me the result of running it"
        ));
    }

    // -------------------------------------------------------------- is_exempt ----

    #[test]
    fn is_exempt_matches_case_insensitively() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            exempt_tools: vec!["Read".to_owned(), "file_read".to_owned()],
            ..UtilityScoringConfig::default()
        });
        assert!(scorer.is_exempt("read"));
        assert!(scorer.is_exempt("READ"));
        assert!(scorer.is_exempt("FILE_READ"));
        assert!(!scorer.is_exempt("write"));
        assert!(!scorer.is_exempt("bash"));
    }

    #[test]
    fn is_exempt_empty_list_returns_false() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert!(!scorer.is_exempt("read"));
    }
}
869}