1use std::collections::HashMap;
10use std::hash::{DefaultHasher, Hash, Hasher};
11
12use crate::config::UtilityScoringConfig;
13use crate::executor::ToolCall;
14
/// Prior estimate of how informative a call to `tool_name` is expected to be.
///
/// Prefix families are checked before exact names, in this order:
///
/// * names starting with `memory` → `0.8`
/// * names starting with `mcp_` → `0.5`
/// * `bash` / `shell` → `0.6`
/// * `read` / `write` → `0.55`
/// * `search_code` / `grep` / `glob` → `0.65`
/// * anything else → `0.5`
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        name if name.starts_with("memory") => 0.8,
        // MCP tools currently map to the same value as the fallback; the explicit
        // arm keeps this family separate from unknown tools.
        name if name.starts_with("mcp_") => 0.5,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        "search_code" | "grep" | "glob" => 0.65,
        _ => 0.5,
    }
}
33
/// Breakdown of the utility estimate for one candidate tool call.
///
/// `UtilityScorer::score` only returns instances for which every component is
/// finite (see `is_valid`).
#[derive(Clone, Debug)]
pub struct UtilityScore {
    /// Per-tool prior informativeness, taken from `default_gain`.
    pub gain: f32,
    /// Fraction of the token budget already consumed, clamped to `[0.0, 1.0]`.
    pub cost: f32,
    /// `1.0` when an identical call was recorded earlier, otherwise `0.0`.
    pub redundancy: f32,
    /// Decays from `1.0` toward `0.0` as tool calls accumulate in the turn.
    pub uncertainty: f32,
    /// Weighted combination of the four components above.
    pub total: f32,
}

impl UtilityScore {
    /// Returns `true` when none of the components is NaN or infinite.
    fn is_valid(&self) -> bool {
        let components = [
            self.gain,
            self.cost,
            self.redundancy,
            self.uncertainty,
            self.total,
        ];
        components.iter().all(|value| value.is_finite())
    }
}
59
/// Caller-supplied state the scorer needs but does not track itself.
#[derive(Clone, Debug)]
pub struct UtilityContext {
    /// Tool calls already made in the current turn; drives `uncertainty` and
    /// the `Verify` recommendation.
    pub tool_calls_this_turn: usize,
    /// Tokens consumed so far; divided by `token_budget` to form `cost`.
    pub tokens_consumed: usize,
    /// Token allowance; a value of `0` turns the cost component off.
    pub token_budget: usize,
    /// When `true`, `recommend_action` always returns `ToolCall`.
    pub user_requested: bool,
}
73
/// Next step suggested by `UtilityScorer::recommend_action`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UtilityAction {
    /// Returned for exact repeat calls, and as the fallback when the score is
    /// below threshold and no tool call has been made yet this turn.
    Respond,
    /// Returned when gain is moderate and uncertainty is high.
    Retrieve,
    /// Returned for user-requested calls, when scoring is disabled, or when the
    /// score clears the configured threshold.
    ToolCall,
    /// Returned when the score is below threshold after at least one tool call
    /// this turn.
    Verify,
    /// Returned when the token budget is nearly exhausted or no valid score is
    /// available while scoring is enabled.
    Stop,
}
88
89fn call_hash(call: &ToolCall) -> u64 {
91 let mut h = DefaultHasher::new();
92 call.tool_id.hash(&mut h);
93 format!("{:?}", call.params).hash(&mut h);
97 h.finish()
98}
99
100#[derive(Debug)]
105pub struct UtilityScorer {
106 config: UtilityScoringConfig,
107 recent_calls: HashMap<u64, u32>,
109}
110
111impl UtilityScorer {
112 #[must_use]
114 pub fn new(config: UtilityScoringConfig) -> Self {
115 Self {
116 config,
117 recent_calls: HashMap::new(),
118 }
119 }
120
121 #[must_use]
123 pub fn is_enabled(&self) -> bool {
124 self.config.enabled
125 }
126
127 #[must_use]
133 pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
134 if !self.config.enabled {
135 return None;
136 }
137
138 let gain = default_gain(&call.tool_id);
139
140 let cost = if ctx.token_budget > 0 {
141 #[allow(clippy::cast_precision_loss)]
142 (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
143 } else {
144 0.0
145 };
146
147 let hash = call_hash(call);
148 let redundancy = if self.recent_calls.contains_key(&hash) {
149 1.0_f32
150 } else {
151 0.0_f32
152 };
153
154 #[allow(clippy::cast_precision_loss)]
157 let uncertainty = (1.0_f32 - ctx.tool_calls_this_turn as f32 / 10.0).clamp(0.0, 1.0);
158
159 let total = self.config.gain_weight * gain
160 - self.config.cost_weight * cost
161 - self.config.redundancy_weight * redundancy
162 + self.config.uncertainty_bonus * uncertainty;
163
164 let score = UtilityScore {
165 gain,
166 cost,
167 redundancy,
168 uncertainty,
169 total,
170 };
171
172 if score.is_valid() { Some(score) } else { None }
173 }
174
175 #[must_use]
189 pub fn recommend_action(
190 &self,
191 score: Option<&UtilityScore>,
192 ctx: &UtilityContext,
193 ) -> UtilityAction {
194 if ctx.user_requested {
196 return UtilityAction::ToolCall;
197 }
198 if !self.config.enabled {
200 return UtilityAction::ToolCall;
201 }
202 let Some(s) = score else {
203 return UtilityAction::Stop;
205 };
206
207 if s.cost > 0.9 {
209 return UtilityAction::Stop;
210 }
211 if s.redundancy >= 1.0 {
213 return UtilityAction::Respond;
214 }
215 if s.gain >= 0.7 && s.total >= self.config.threshold {
217 return UtilityAction::ToolCall;
218 }
219 if s.gain >= 0.5 && s.uncertainty > 0.5 {
221 return UtilityAction::Retrieve;
222 }
223 if s.total < self.config.threshold && ctx.tool_calls_this_turn > 0 {
225 return UtilityAction::Verify;
226 }
227 if s.total >= self.config.threshold {
229 return UtilityAction::ToolCall;
230 }
231 UtilityAction::Respond
232 }
233
234 pub fn record_call(&mut self, call: &ToolCall) {
239 let hash = call_hash(call);
240 *self.recent_calls.entry(hash).or_insert(0) += 1;
241 }
242
243 pub fn clear(&mut self) {
245 self.recent_calls.clear();
246 }
247
248 #[must_use]
250 pub fn threshold(&self) -> f32 {
251 self.config.threshold
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `ToolCall`; a non-object `params` value becomes an empty map.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        ToolCall {
            tool_id: name.to_owned(),
            params: if let serde_json::Value::Object(m) = params {
                m
            } else {
                serde_json::Map::new()
            },
        }
    }

    /// Fresh turn: no prior calls, nothing consumed, budget of 1000 tokens.
    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    /// Default configuration with scoring switched on.
    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    #[test]
    fn disabled_returns_none() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert!(!scorer.is_enabled());
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_none());
        assert_eq!(
            scorer.recommend_action(score.as_ref(), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        let score = scorer.score(&call, &default_ctx());
        assert!(score.is_some());
        let s = score.unwrap();
        assert!(
            s.total >= scorer.threshold(),
            "first call should exceed threshold {}: {}",
            scorer.threshold(),
            s.total
        );
        let action = scorer.recommend_action(Some(&s), &default_ctx());
        assert!(
            action == UtilityAction::ToolCall || action == UtilityAction::Retrieve,
            "first call should not be blocked, got {action:?}",
        );
    }

    #[test]
    fn redundant_call_penalized() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!((score.redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn repeated_recording_stays_redundant() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        scorer.record_call(&call);
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!((score.redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn different_params_are_not_redundant() {
        let mut scorer = UtilityScorer::new(default_config());
        scorer.record_call(&make_call("bash", json!({"cmd": "ls"})));
        let score = scorer
            .score(&make_call("bash", json!({"cmd": "pwd"})), &default_ctx())
            .unwrap();
        assert!(score.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn recorded_call_recommends_respond() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Respond
        );
    }

    #[test]
    fn clear_resets_redundancy() {
        let mut scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));
        scorer.record_call(&call);
        scorer.clear();
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(score.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn user_requested_always_executes() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn none_score_fail_closed_when_enabled() {
        let scorer = UtilityScorer::new(default_config());
        assert_eq!(
            scorer.recommend_action(None, &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn none_score_executes_when_disabled() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert_eq!(
            scorer.recommend_action(None, &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx_low = UtilityContext {
            tokens_consumed: 100,
            token_budget: 1000,
            ..default_ctx()
        };
        let ctx_high = UtilityContext {
            tokens_consumed: 900,
            token_budget: 1000,
            ..default_ctx()
        };
        let s_low = scorer.score(&call, &ctx_low).unwrap();
        let s_high = scorer.score(&call, &ctx_high).unwrap();
        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn over_budget_clamps_cost_and_stops() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx = UtilityContext {
            tokens_consumed: 5000,
            token_budget: 1000,
            ..default_ctx()
        };
        let s = scorer.score(&call, &ctx).unwrap();
        assert!((s.cost - 1.0).abs() < f32::EPSILON);
        assert_eq!(
            scorer.recommend_action(Some(&s), &ctx),
            UtilityAction::Stop
        );
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx_early = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        let ctx_late = UtilityContext {
            tool_calls_this_turn: 9,
            ..default_ctx()
        };
        let s_early = scorer.score(&call, &ctx_early).unwrap();
        let s_late = scorer.score(&call, &ctx_late).unwrap();
        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn uncertainty_saturates_at_zero() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx = UtilityContext {
            tool_calls_this_turn: 25,
            ..default_ctx()
        };
        let s = scorer.score(&call, &ctx).unwrap();
        assert!(s.uncertainty.abs() < f32::EPSILON);
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = UtilityScorer::new(default_config());
        let mem_call = make_call("memory_search", json!({}));
        let web_call = make_call("scrape", json!({}));
        let s_mem = scorer.score(&mem_call, &default_ctx()).unwrap();
        let s_web = scorer.score(&web_call, &default_ctx()).unwrap();
        assert!(s_mem.gain > s_web.gain);
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };
        let s = scorer.score(&call, &ctx).unwrap();
        assert!(s.cost.abs() < f32::EPSILON);
    }

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_nan_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_default() {
        assert!(UtilityScoringConfig::default().validate().is_ok());
    }

    #[test]
    fn threshold_zero_all_calls_pass() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total >= 0.0,
            "total should be non-negative: {}",
            score.total
        );
        let action = scorer.recommend_action(Some(&score), &default_ctx());
        assert!(
            action == UtilityAction::ToolCall || action == UtilityAction::Retrieve,
            "threshold=0 should not block calls, got {action:?}",
        );
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        let scorer = UtilityScorer::new(UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        });
        let call = make_call("bash", json!({}));
        let score = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            score.total < 1.0,
            "realistic score should be below 1.0: {}",
            score.total
        );
        assert_ne!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_user_requested_always_tool_call() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        let ctx = UtilityContext {
            user_requested: true,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_disabled_scorer_always_tool_call() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::ToolCall);
    }

    #[test]
    fn recommend_action_none_score_enabled_stops() {
        let scorer = UtilityScorer::new(default_config());
        let ctx = default_ctx();
        assert_eq!(scorer.recommend_action(None, &ctx), UtilityAction::Stop);
    }

    #[test]
    fn recommend_action_budget_exhausted_stops() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.95,
            redundancy: 0.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Stop
        );
    }

    #[test]
    fn recommend_action_redundant_responds() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 1.0,
            uncertainty: 0.5,
            total: 0.5,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Respond
        );
    }

    #[test]
    fn recommend_action_high_gain_above_threshold_tool_call() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.8,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.4,
            total: 0.6,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::ToolCall
        );
    }

    #[test]
    fn recommend_action_uncertain_retrieves() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.6,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.8,
            total: 0.4,
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &default_ctx()),
            UtilityAction::Retrieve
        );
    }

    #[test]
    fn recommend_action_below_threshold_with_prior_calls_verifies() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 1,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Verify
        );
    }

    #[test]
    fn recommend_action_default_responds() {
        let scorer = UtilityScorer::new(default_config());
        let score = UtilityScore {
            gain: 0.3,
            cost: 0.1,
            redundancy: 0.0,
            uncertainty: 0.2,
            total: 0.05,
        };
        let ctx = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        assert_eq!(
            scorer.recommend_action(Some(&score), &ctx),
            UtilityAction::Respond
        );
    }
}