1use std::collections::HashMap;
10use std::hash::{DefaultHasher, Hash, Hasher};
11use std::sync::LazyLock;
12
13use regex::Regex;
14
15use crate::config::UtilityScoringConfig;
16use crate::executor::ToolCall;
17
18#[must_use]
27pub fn has_explicit_tool_request(user_message: &str) -> bool {
28 static RE: LazyLock<Regex> = LazyLock::new(|| {
29 Regex::new(
30 r"(?xi)
31 using\s+a\s+tool
32 | call\s+(the\s+)?[a-z_]+\s+tool
33 | use\s+(the\s+)?[a-z_]+\s+tool
34 | run\s+(the\s+)?[a-z_]+\s+tool
35 | invoke\s+(the\s+)?[a-z_]+\s+tool
36 | execute\s+(the\s+)?[a-z_]+\s+tool
37 ",
38 )
39 .expect("static regex is valid")
40 });
41 RE.is_match(user_message)
42}
43
/// Baseline gain estimate for a tool, chosen purely from its name.
///
/// Prefix rules are tried first (`memory*`, then `mcp_*`), followed by exact
/// names; any other tool gets `0.5`.
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        // Prefix families take precedence over the exact-name table below.
        name if name.starts_with("memory") => 0.8,
        name if name.starts_with("mcp_") => 0.5,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        "search_code" | "grep" | "glob" => 0.65,
        _ => 0.5,
    }
}
62
/// Component breakdown of the utility estimate for one candidate tool call.
///
/// Values are normally produced by `UtilityScorer::score`; the per-field notes
/// below describe what that method stores in each field.
#[derive(Debug, Clone)]
pub struct UtilityScore {
    /// Estimated benefit of the call; `score` derives it from the tool name.
    pub gain: f32,
    /// Fraction of the token budget already consumed, clamped to `[0, 1]` by `score`
    /// (and `0.0` when the budget is zero).
    pub cost: f32,
    /// `1.0` when an identical call has already been recorded, otherwise `0.0`.
    pub redundancy: f32,
    /// Starts at `1.0` and shrinks with the number of tool calls made this turn,
    /// clamped to `[0, 1]` by `score`.
    pub uncertainty: f32,
    /// Weighted combination of the components above; compared against the
    /// configured threshold when recommending an action.
    pub total: f32,
}

impl UtilityScore {
    /// `true` only when every component, `total` included, is finite
    /// (neither NaN nor an infinity).
    fn is_valid(&self) -> bool {
        let Self {
            gain,
            cost,
            redundancy,
            uncertainty,
            total,
        } = self;
        [gain, cost, redundancy, uncertainty, total]
            .into_iter()
            .all(|component| component.is_finite())
    }
}
88
/// Per-turn state consulted when scoring a tool call and recommending an action.
#[derive(Debug, Clone)]
pub struct UtilityContext {
    /// Number of tool calls already made in the current turn; more calls lower the
    /// uncertainty component of the score.
    pub tool_calls_this_turn: usize,
    /// Tokens consumed so far; the numerator of the cost component.
    pub tokens_consumed: usize,
    /// Token budget; the denominator of the cost component. A budget of `0`
    /// makes the scorer report a cost of `0.0`.
    pub token_budget: usize,
    /// When `true`, `UtilityScorer::recommend_action` returns
    /// `UtilityAction::ToolCall` without looking at the score.
    pub user_requested: bool,
}
103
/// Outcome of `UtilityScorer::recommend_action`.
///
/// The notes on each variant say when `recommend_action` returns it. How a
/// variant is acted upon is up to the caller and is not defined in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UtilityAction {
    /// Returned for a call that was already recorded (redundancy of `1.0`), and as
    /// the fallback when no other rule applies.
    Respond,
    /// Returned when gain is at least `0.5` and uncertainty is above `0.5`, unless
    /// an earlier rule already matched.
    Retrieve,
    /// Returned when the user explicitly asked for a tool, when scoring is
    /// disabled, or when the total score meets the configured threshold.
    ToolCall,
    /// Returned when the total is below the threshold and at least one tool call
    /// was already made this turn.
    Verify,
    /// Returned when scoring is enabled but no score is available, or when cost
    /// exceeds `0.9`.
    Stop,
}
118
119fn call_hash(call: &ToolCall) -> u64 {
121 let mut h = DefaultHasher::new();
122 call.tool_id.hash(&mut h);
123 format!("{:?}", call.params).hash(&mut h);
127 h.finish()
128}
129
130#[derive(Debug)]
135pub struct UtilityScorer {
136 config: UtilityScoringConfig,
137 recent_calls: HashMap<u64, u32>,
139}
140
141impl UtilityScorer {
142 #[must_use]
144 pub fn new(config: UtilityScoringConfig) -> Self {
145 Self {
146 config,
147 recent_calls: HashMap::new(),
148 }
149 }
150
151 #[must_use]
153 pub fn is_enabled(&self) -> bool {
154 self.config.enabled
155 }
156
157 #[must_use]
163 pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
164 if !self.config.enabled {
165 return None;
166 }
167
168 let gain = default_gain(&call.tool_id);
169
170 let cost = if ctx.token_budget > 0 {
171 #[allow(clippy::cast_precision_loss)]
172 (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
173 } else {
174 0.0
175 };
176
177 let hash = call_hash(call);
178 let redundancy = if self.recent_calls.contains_key(&hash) {
179 1.0_f32
180 } else {
181 0.0_f32
182 };
183
184 #[allow(clippy::cast_precision_loss)]
187 let uncertainty = (1.0_f32 - ctx.tool_calls_this_turn as f32 / 10.0).clamp(0.0, 1.0);
188
189 let total = self.config.gain_weight * gain
190 - self.config.cost_weight * cost
191 - self.config.redundancy_weight * redundancy
192 + self.config.uncertainty_bonus * uncertainty;
193
194 let score = UtilityScore {
195 gain,
196 cost,
197 redundancy,
198 uncertainty,
199 total,
200 };
201
202 if score.is_valid() { Some(score) } else { None }
203 }
204
205 #[must_use]
219 pub fn recommend_action(
220 &self,
221 score: Option<&UtilityScore>,
222 ctx: &UtilityContext,
223 ) -> UtilityAction {
224 if ctx.user_requested {
226 return UtilityAction::ToolCall;
227 }
228 if !self.config.enabled {
230 return UtilityAction::ToolCall;
231 }
232 let Some(s) = score else {
233 return UtilityAction::Stop;
235 };
236
237 if s.cost > 0.9 {
239 return UtilityAction::Stop;
240 }
241 if s.redundancy >= 1.0 {
243 return UtilityAction::Respond;
244 }
245 if s.gain >= 0.7 && s.total >= self.config.threshold {
247 return UtilityAction::ToolCall;
248 }
249 if s.gain >= 0.5 && s.uncertainty > 0.5 {
251 return UtilityAction::Retrieve;
252 }
253 if s.total < self.config.threshold && ctx.tool_calls_this_turn > 0 {
255 return UtilityAction::Verify;
256 }
257 if s.total >= self.config.threshold {
259 return UtilityAction::ToolCall;
260 }
261 UtilityAction::Respond
262 }
263
264 pub fn record_call(&mut self, call: &ToolCall) {
269 let hash = call_hash(call);
270 *self.recent_calls.entry(hash).or_insert(0) += 1;
271 }
272
273 pub fn clear(&mut self) {
275 self.recent_calls.clear();
276 }
277
278 #[must_use]
280 pub fn threshold(&self) -> f32 {
281 self.config.threshold
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- fixtures -------------------------------------------------------

    /// Builds a `ToolCall`; any JSON value other than an object yields empty params.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        let params = match params {
            serde_json::Value::Object(map) => map,
            _ => serde_json::Map::new(),
        };
        ToolCall {
            tool_id: name.to_owned(),
            params,
        }
    }

    /// Start of a turn: no calls, no tokens used, budget of 1000, no explicit request.
    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    /// Default configuration with scoring switched on.
    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    /// Scorer using `default_config()`.
    fn enabled_scorer() -> UtilityScorer {
        UtilityScorer::new(default_config())
    }

    /// Scorer using the untouched default configuration, which these tests
    /// expect to have scoring disabled.
    fn disabled_scorer() -> UtilityScorer {
        UtilityScorer::new(UtilityScoringConfig::default())
    }

    /// A score that would be rejected on every count if it were consulted.
    fn hopeless_score() -> UtilityScore {
        UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        }
    }

    /// Runs `recommend_action` on an enabled default scorer.
    fn recommend(score: &UtilityScore, ctx: &UtilityContext) -> UtilityAction {
        enabled_scorer().recommend_action(Some(score), ctx)
    }

    // ---- scoring --------------------------------------------------------

    #[test]
    fn disabled_returns_none() {
        let scorer = disabled_scorer();
        assert!(!scorer.is_enabled());

        let bash = make_call("bash", json!({}));
        let score = scorer.score(&bash, &default_ctx());
        assert!(score.is_none());

        let action = scorer.recommend_action(score.as_ref(), &default_ctx());
        assert_eq!(action, UtilityAction::ToolCall);
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = enabled_scorer();
        let ls = make_call("bash", json!({"cmd": "ls"}));

        let score = scorer.score(&ls, &default_ctx());
        assert!(score.is_some());
        let s = score.unwrap();
        assert!(
            s.total >= 0.1,
            "first call should exceed threshold: {}",
            s.total
        );

        let action = scorer.recommend_action(Some(&s), &default_ctx());
        assert!(
            action == UtilityAction::ToolCall || action == UtilityAction::Retrieve,
            "first call should not be blocked, got {action:?}",
        );
    }

    #[test]
    fn redundant_call_penalized() {
        let mut scorer = enabled_scorer();
        let ls = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&ls);

        let redundancy = scorer.score(&ls, &default_ctx()).unwrap().redundancy;
        assert!((redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn clear_resets_redundancy() {
        let mut scorer = enabled_scorer();
        let ls = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&ls);
        scorer.clear();

        let redundancy = scorer.score(&ls, &default_ctx()).unwrap().redundancy;
        assert!(redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn user_requested_always_executes() {
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(recommend(&hopeless_score(), &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn none_score_fail_closed_when_enabled() {
        let action = enabled_scorer().recommend_action(None, &default_ctx());
        assert_eq!(action, UtilityAction::Stop);
    }

    #[test]
    fn none_score_executes_when_disabled() {
        let action = disabled_scorer().recommend_action(None, &default_ctx());
        assert_eq!(action, UtilityAction::ToolCall);
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = enabled_scorer();
        let bash = make_call("bash", json!({}));
        let with_consumed = |tokens_consumed| UtilityContext {
            tokens_consumed,
            token_budget: 1000,
            ..default_ctx()
        };

        let s_low = scorer.score(&bash, &with_consumed(100)).unwrap();
        let s_high = scorer.score(&bash, &with_consumed(900)).unwrap();

        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = enabled_scorer();
        let bash = make_call("bash", json!({}));
        let after_calls = |tool_calls_this_turn| UtilityContext {
            tool_calls_this_turn,
            ..default_ctx()
        };

        let s_early = scorer.score(&bash, &after_calls(0)).unwrap();
        let s_late = scorer.score(&bash, &after_calls(9)).unwrap();

        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();

        let s_mem = scorer
            .score(&make_call("memory_search", json!({})), &ctx)
            .unwrap();
        let s_web = scorer.score(&make_call("scrape", json!({})), &ctx).unwrap();

        assert!(s_mem.gain > s_web.gain);
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let ctx = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };
        let s = enabled_scorer()
            .score(&make_call("bash", json!({})), &ctx)
            .unwrap();
        assert!(s.cost.abs() < f32::EPSILON);
    }

    // ---- configuration validation ---------------------------------------

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_nan_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_default() {
        assert!(UtilityScoringConfig::default().validate().is_ok());
    }

    // ---- threshold extremes ---------------------------------------------

    #[test]
    fn threshold_zero_all_calls_pass() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        });
        let bash = make_call("bash", json!({}));

        let score = scorer.score(&bash, &default_ctx()).unwrap();
        assert!(
            score.total >= 0.0,
            "total should be non-negative: {}",
            score.total
        );

        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert!(
            action == UtilityAction::ToolCall || action == UtilityAction::Retrieve,
            "threshold=0 should not block calls, got {action:?}",
        );
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        });
        let bash = make_call("bash", json!({}));

        let score = scorer.score(&bash, &default_ctx()).unwrap();
        assert!(
            score.total < 1.0,
            "realistic score should be below 1.0: {}",
            score.total
        );

        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert_ne!(action, UtilityAction::ToolCall);
    }

    // ---- recommend_action, one test per rule ----------------------------

    #[test]
    fn recommend_action_user_requested_always_tool_call() {
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(recommend(&hopeless_score(), &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_disabled_scorer_always_tool_call() {
        let scorer = disabled_scorer();
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_none_score_enabled_stops() {
        let scorer = enabled_scorer();
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_budget_exhausted_stops() {
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.95, // above the 0.9 cut-off
            redundancy: 0.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(recommend(&score, &default_ctx()), UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_redundant_responds() {
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 1.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(recommend(&score, &default_ctx()), UtilityAction::Respond);
    }

    #[test]
    fn recommend_action_high_gain_above_threshold_tool_call() {
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.4,
            total: 0.6,
        };
        assert_eq!(recommend(&score, &default_ctx()), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_uncertain_retrieves() {
        let score = UtilityScore {
            gain: 0.6, // below the high-gain rule, so the uncertainty rule decides
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.8,
            total: 0.4,
        };
        assert_eq!(recommend(&score, &default_ctx()), UtilityAction::Retrieve);
    }

    #[test]
    fn recommend_action_below_threshold_with_prior_calls_verifies() {
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05, // expected to sit below the default threshold
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 1,
            ..default_ctx()
        };
        assert_eq!(recommend(&score, &ctx), UtilityAction::Verify);
    }

    #[test]
    fn recommend_action_default_responds() {
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05, // expected to sit below the default threshold
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        assert_eq!(recommend(&score, &ctx), UtilityAction::Respond);
    }

    // ---- has_explicit_tool_request --------------------------------------

    #[test]
    fn explicit_request_using_a_tool() {
        let message = "Please list the files in the current directory using a tool";
        assert!(has_explicit_tool_request(message));
    }

    #[test]
    fn explicit_request_call_the_tool() {
        assert!(has_explicit_tool_request("call the list_directory tool"));
    }

    #[test]
    fn explicit_request_use_the_tool() {
        assert!(has_explicit_tool_request("use the shell tool to run ls"));
    }

    #[test]
    fn explicit_request_run_the_tool() {
        assert!(has_explicit_tool_request("run the bash tool"));
    }

    #[test]
    fn explicit_request_invoke_the_tool() {
        assert!(has_explicit_tool_request("invoke the search_code tool"));
    }

    #[test]
    fn explicit_request_execute_the_tool() {
        assert!(has_explicit_tool_request("execute the grep tool for me"));
    }

    #[test]
    fn explicit_request_case_insensitive() {
        assert!(has_explicit_tool_request("USING A TOOL to find files"));
    }

    #[test]
    fn explicit_request_no_match_plain_message() {
        assert!(!has_explicit_tool_request("what is the weather today?"));
    }

    #[test]
    fn explicit_request_no_match_tool_mentioned_without_invocation() {
        let message = "the shell tool is very useful in general";
        assert!(!has_explicit_tool_request(message));
    }
}