1use std::collections::HashMap;
10use std::hash::{DefaultHasher, Hash, Hasher};
11
12use crate::config::UtilityScoringConfig;
13use crate::executor::ToolCall;
14
/// Returns the baseline expected gain for a tool, decided purely by its name.
///
/// Prefix rules take precedence: names beginning with `memory` yield `0.8`
/// and names beginning with `mcp_` yield `0.5`. Otherwise a small table of
/// exact names applies, and any unrecognised tool falls back to `0.5`.
fn default_gain(tool_name: &str) -> f32 {
    match tool_name {
        // Guards come first so the prefix rules win over the exact-name table.
        name if name.starts_with("memory") => 0.8,
        name if name.starts_with("mcp_") => 0.5,
        "search_code" | "grep" | "glob" => 0.65,
        "bash" | "shell" => 0.6,
        "read" | "write" => 0.55,
        _ => 0.5,
    }
}
33
/// Breakdown of one utility evaluation for a tool call.
///
/// The individual components are kept next to the weighted `total` so that
/// callers can see how the final number was reached.
#[derive(Debug, Clone)]
pub struct UtilityScore {
    /// Baseline benefit estimate for the tool.
    pub gain: f32,
    /// Cost component (subtracted, after weighting, from the total).
    pub cost: f32,
    /// Redundancy component (subtracted, after weighting, from the total).
    pub redundancy: f32,
    /// Uncertainty component (added, after weighting, to the total).
    pub uncertainty: f32,
    /// Weighted combination of the four components above.
    pub total: f32,
}

impl UtilityScore {
    /// Reports whether every component, `total` included, is a finite number
    /// (neither NaN nor infinite).
    fn is_valid(&self) -> bool {
        [
            self.gain,
            self.cost,
            self.redundancy,
            self.uncertainty,
            self.total,
        ]
        .iter()
        .all(|component| component.is_finite())
    }
}
59
/// Per-turn inputs that `UtilityScorer::score` reads when evaluating a call.
#[derive(Debug, Clone)]
pub struct UtilityContext {
    /// Number of tool calls already made this turn; the uncertainty component
    /// shrinks as this grows.
    pub tool_calls_this_turn: usize,
    /// Tokens used so far; divided by `token_budget` to obtain the cost.
    pub tokens_consumed: usize,
    /// Total token budget. A budget of `0` makes the cost component `0.0`.
    pub token_budget: usize,
    /// Whether the user explicitly requested the call. `score` does not read
    /// this field; `should_execute` receives its own flag.
    pub user_requested: bool,
}
73
74fn call_hash(call: &ToolCall) -> u64 {
76 let mut h = DefaultHasher::new();
77 call.tool_id.hash(&mut h);
78 format!("{:?}", call.params).hash(&mut h);
82 h.finish()
83}
84
85#[derive(Debug)]
90pub struct UtilityScorer {
91 config: UtilityScoringConfig,
92 recent_calls: HashMap<u64, u32>,
94}
95
96impl UtilityScorer {
97 #[must_use]
99 pub fn new(config: UtilityScoringConfig) -> Self {
100 Self {
101 config,
102 recent_calls: HashMap::new(),
103 }
104 }
105
106 #[must_use]
108 pub fn is_enabled(&self) -> bool {
109 self.config.enabled
110 }
111
112 #[must_use]
118 pub fn score(&self, call: &ToolCall, ctx: &UtilityContext) -> Option<UtilityScore> {
119 if !self.config.enabled {
120 return None;
121 }
122
123 let gain = default_gain(&call.tool_id);
124
125 let cost = if ctx.token_budget > 0 {
126 #[allow(clippy::cast_precision_loss)]
127 (ctx.tokens_consumed as f32 / ctx.token_budget as f32).clamp(0.0, 1.0)
128 } else {
129 0.0
130 };
131
132 let hash = call_hash(call);
133 let redundancy = if self.recent_calls.contains_key(&hash) {
134 1.0_f32
135 } else {
136 0.0_f32
137 };
138
139 #[allow(clippy::cast_precision_loss)]
142 let uncertainty = (1.0_f32 - ctx.tool_calls_this_turn as f32 / 10.0).clamp(0.0, 1.0);
143
144 let total = self.config.gain_weight * gain
145 - self.config.cost_weight * cost
146 - self.config.redundancy_weight * redundancy
147 + self.config.uncertainty_bonus * uncertainty;
148
149 let score = UtilityScore {
150 gain,
151 cost,
152 redundancy,
153 uncertainty,
154 total,
155 };
156
157 if score.is_valid() { Some(score) } else { None }
158 }
159
160 #[must_use]
166 pub fn should_execute(&self, score: Option<&UtilityScore>, user_requested: bool) -> bool {
167 if user_requested {
168 return true;
169 }
170 match score {
171 Some(s) => s.total >= self.config.threshold,
172 None if !self.config.enabled => true,
175 None => false,
176 }
177 }
178
179 pub fn record_call(&mut self, call: &ToolCall) {
184 let hash = call_hash(call);
185 *self.recent_calls.entry(hash).or_insert(0) += 1;
186 }
187
188 pub fn clear(&mut self) {
190 self.recent_calls.clear();
191 }
192
193 #[must_use]
195 pub fn threshold(&self) -> f32 {
196 self.config.threshold
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `ToolCall` for `name`; a non-object `params` value becomes an
    /// empty parameter map.
    fn make_call(name: &str, params: serde_json::Value) -> ToolCall {
        let params = match params {
            serde_json::Value::Object(map) => map,
            _ => serde_json::Map::new(),
        };
        ToolCall {
            tool_id: name.to_owned(),
            params,
        }
    }

    /// Context at the start of a turn: nothing consumed out of a 1000-token budget.
    fn default_ctx() -> UtilityContext {
        UtilityContext {
            tool_calls_this_turn: 0,
            tokens_consumed: 0,
            token_budget: 1000,
            user_requested: false,
        }
    }

    /// Default configuration with scoring switched on.
    fn default_config() -> UtilityScoringConfig {
        UtilityScoringConfig {
            enabled: true,
            ..UtilityScoringConfig::default()
        }
    }

    #[test]
    fn disabled_returns_none() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert!(!scorer.is_enabled());

        let call = make_call("bash", json!({}));
        let outcome = scorer.score(&call, &default_ctx());
        assert!(outcome.is_none());
        // With scoring off, a missing score must not block execution.
        assert!(scorer.should_execute(outcome.as_ref(), false));
    }

    #[test]
    fn first_call_passes_default_threshold() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({"cmd": "ls"}));

        let outcome = scorer.score(&call, &default_ctx());
        assert!(outcome.is_some());
        let s = outcome.unwrap();
        assert!(
            s.total >= 0.1,
            "first call should exceed threshold: {}",
            s.total
        );
        assert!(scorer.should_execute(Some(&s), false));
    }

    #[test]
    fn redundant_call_penalized() {
        let call = make_call("bash", json!({"cmd": "ls"}));
        let mut scorer = UtilityScorer::new(default_config());
        scorer.record_call(&call);

        let scored = scorer.score(&call, &default_ctx()).unwrap();
        assert!((scored.redundancy - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn clear_resets_redundancy() {
        let call = make_call("bash", json!({"cmd": "ls"}));
        let mut scorer = UtilityScorer::new(default_config());
        scorer.record_call(&call);
        scorer.clear();

        let scored = scorer.score(&call, &default_ctx()).unwrap();
        assert!(scored.redundancy.abs() < f32::EPSILON);
    }

    #[test]
    fn user_requested_always_executes() {
        let scorer = UtilityScorer::new(default_config());
        // Deliberately terrible score: only the user-request flag can let it through.
        let hopeless = UtilityScore {
            gain: 0.0,
            cost: 1.0,
            redundancy: 1.0,
            uncertainty: 0.0,
            total: -100.0,
        };
        assert!(scorer.should_execute(Some(&hopeless), true));
    }

    #[test]
    fn none_score_fail_closed_when_enabled() {
        let scorer = UtilityScorer::new(default_config());
        assert!(!scorer.should_execute(None, false));
    }

    #[test]
    fn none_score_executes_when_disabled() {
        let scorer = UtilityScorer::new(UtilityScoringConfig::default());
        assert!(scorer.should_execute(None, false));
    }

    #[test]
    fn cost_increases_with_token_consumption() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));

        let ctx_low = UtilityContext {
            tokens_consumed: 100,
            token_budget: 1000,
            ..default_ctx()
        };
        let ctx_high = UtilityContext {
            tokens_consumed: 900,
            token_budget: 1000,
            ..default_ctx()
        };

        let s_low = scorer.score(&call, &ctx_low).unwrap();
        let s_high = scorer.score(&call, &ctx_high).unwrap();
        assert!(s_low.cost < s_high.cost);
        assert!(s_low.total > s_high.total);
    }

    #[test]
    fn uncertainty_decreases_with_call_count() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));

        let ctx_early = UtilityContext {
            tool_calls_this_turn: 0,
            ..default_ctx()
        };
        let ctx_late = UtilityContext {
            tool_calls_this_turn: 9,
            ..default_ctx()
        };

        let s_early = scorer.score(&call, &ctx_early).unwrap();
        let s_late = scorer.score(&call, &ctx_late).unwrap();
        assert!(s_early.uncertainty > s_late.uncertainty);
    }

    #[test]
    fn memory_tool_has_higher_gain_than_scrape() {
        let scorer = UtilityScorer::new(default_config());
        let ctx = default_ctx();

        let s_mem = scorer
            .score(&make_call("memory_search", json!({})), &ctx)
            .unwrap();
        let s_web = scorer.score(&make_call("scrape", json!({})), &ctx).unwrap();
        assert!(s_mem.gain > s_web.gain);
    }

    #[test]
    fn zero_token_budget_zeroes_cost() {
        let scorer = UtilityScorer::new(default_config());
        let call = make_call("bash", json!({}));
        let ctx = UtilityContext {
            tokens_consumed: 500,
            token_budget: 0,
            ..default_ctx()
        };

        let scored = scorer.score(&call, &ctx).unwrap();
        assert!(scored.cost.abs() < f32::EPSILON);
    }

    #[test]
    fn validate_rejects_negative_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            gain_weight: -1.0,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_nan_weights() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: f32::NAN,
            ..UtilityScoringConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_default() {
        let cfg = UtilityScoringConfig::default();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn threshold_zero_all_calls_pass() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: 0.0,
            ..UtilityScoringConfig::default()
        };
        let scorer = UtilityScorer::new(cfg);
        let call = make_call("bash", json!({}));

        let scored = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            scored.total >= 0.0,
            "total should be non-negative: {}",
            scored.total
        );
        assert!(scorer.should_execute(Some(&scored), false));
    }

    #[test]
    fn threshold_one_blocks_all_calls() {
        let cfg = UtilityScoringConfig {
            enabled: true,
            threshold: 1.0,
            ..UtilityScoringConfig::default()
        };
        let scorer = UtilityScorer::new(cfg);
        let call = make_call("bash", json!({}));

        let scored = scorer.score(&call, &default_ctx()).unwrap();
        assert!(
            scored.total < 1.0,
            "realistic score should be below 1.0: {}",
            scored.total
        );
        assert!(!scorer.should_execute(Some(&scored), false));
    }
}