swarm_engine_eval/scenario/
milestone.rs1use serde::{Deserialize, Serialize};
6
7use super::conditions::{Condition, ConditionValue};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Milestone {
14 pub name: String,
16
17 #[serde(default)]
19 pub description: Option<String>,
20
21 pub condition: Condition,
23
24 pub weight: f64,
26
27 #[serde(default)]
29 pub partial: bool,
30
31 #[serde(default)]
33 pub partial_config: Option<PartialConfig>,
34}
35
36impl Milestone {
37 pub fn new(name: impl Into<String>, condition: Condition, weight: f64) -> Self {
39 Self {
40 name: name.into(),
41 description: None,
42 condition,
43 weight,
44 partial: false,
45 partial_config: None,
46 }
47 }
48
49 pub fn with_description(mut self, description: impl Into<String>) -> Self {
51 self.description = Some(description.into());
52 self
53 }
54
55 pub fn with_partial(mut self, config: PartialConfig) -> Self {
57 self.partial = true;
58 self.partial_config = Some(config);
59 self
60 }
61
62 pub fn evaluate(&self, actual: &ConditionValue) -> f64 {
66 if self.condition.evaluate(actual) {
67 return 1.0;
68 }
69
70 if !self.partial {
71 return 0.0;
72 }
73
74 self.calculate_partial_score(actual)
76 }
77
78 fn calculate_partial_score(&self, actual: &ConditionValue) -> f64 {
80 let config = match &self.partial_config {
81 Some(c) => c,
82 None => return 0.0,
83 };
84
85 let actual_f64 = match actual {
86 ConditionValue::Integer(v) => *v as f64,
87 ConditionValue::Float(v) => *v,
88 _ => return 0.0,
89 };
90
91 let target_f64 = match &self.condition.value {
92 ConditionValue::Integer(v) => *v as f64,
93 ConditionValue::Float(v) => *v,
94 _ => return 0.0,
95 };
96
97 match config {
98 PartialConfig::Linear {
99 min,
100 max,
101 descending,
102 } => {
103 let min_val = min.unwrap_or(0.0);
104 let max_val = max.unwrap_or(target_f64);
105
106 if *descending {
107 if actual_f64 <= min_val {
110 1.0
111 } else if actual_f64 >= max_val {
112 0.0
113 } else {
114 (max_val - actual_f64) / (max_val - min_val)
115 }
116 } else {
117 if actual_f64 <= min_val {
120 0.0
121 } else if actual_f64 >= max_val {
122 1.0
123 } else {
124 (actual_f64 - min_val) / (max_val - min_val)
125 }
126 }
127 }
128 PartialConfig::Threshold { thresholds } => {
129 let mut score = 0.0;
131 for (threshold, threshold_score) in thresholds {
132 if actual_f64 >= *threshold {
133 score = *threshold_score;
134 }
135 }
136 score
137 }
138 }
139 }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144#[serde(tag = "type", rename_all = "snake_case")]
145pub enum PartialConfig {
146 Linear {
148 min: Option<f64>,
150 max: Option<f64>,
152 #[serde(default)]
156 descending: bool,
157 },
158 Threshold {
160 thresholds: Vec<(f64, f64)>,
162 },
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct MilestoneResult {
168 pub name: String,
170 pub achievement: f64,
172 pub weight: f64,
174 pub weighted_score: f64,
176 pub completed: bool,
178}
179
180impl MilestoneResult {
181 pub fn new(milestone: &Milestone, achievement: f64) -> Self {
182 Self {
183 name: milestone.name.clone(),
184 achievement,
185 weight: milestone.weight,
186 weighted_score: achievement * milestone.weight,
187 completed: achievement >= 1.0,
188 }
189 }
190}
191
192#[derive(Debug, Clone)]
194pub struct KpiCalculator {
195 milestones: Vec<Milestone>,
196}
197
198impl KpiCalculator {
199 pub fn new(milestones: Vec<Milestone>) -> Self {
201 Self { milestones }
202 }
203
204 pub fn calculate<F>(&self, metric_getter: F) -> KpiScore
209 where
210 F: Fn(&str) -> Option<ConditionValue>,
211 {
212 let mut results = Vec::new();
213 let mut total_score = 0.0;
214 let mut total_weight = 0.0;
215
216 for milestone in &self.milestones {
217 let achievement = match metric_getter(&milestone.condition.metric) {
218 Some(value) => milestone.evaluate(&value),
219 None => 0.0,
220 };
221
222 let result = MilestoneResult::new(milestone, achievement);
223 total_score += result.weighted_score;
224 total_weight += milestone.weight;
225 results.push(result);
226 }
227
228 let normalized_score = if total_weight > 0.0 {
230 total_score / total_weight
231 } else {
232 0.0
233 };
234
235 KpiScore {
236 score: normalized_score,
237 raw_score: total_score,
238 total_weight,
239 results,
240 }
241 }
242}
243
244#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct KpiScore {
247 pub score: f64,
249 pub raw_score: f64,
251 pub total_weight: f64,
253 pub results: Vec<MilestoneResult>,
255}
256
257impl KpiScore {
258 pub fn completed_count(&self) -> usize {
260 self.results.iter().filter(|r| r.completed).count()
261 }
262
263 pub fn total_count(&self) -> usize {
265 self.results.len()
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::super::conditions::CompareOp;
    use super::*;

    /// Builds a milestone whose condition shares the milestone's name.
    fn create_test_milestone(
        name: &str,
        metric: &str,
        op: CompareOp,
        value: i64,
        weight: f64,
    ) -> Milestone {
        Milestone::new(name, Condition::new(name, metric, op, value), weight)
    }

    #[test]
    fn test_milestone_evaluate_complete() {
        let milestone = create_test_milestone(
            "first_collection",
            "resources_collected",
            CompareOp::Gte,
            1,
            0.2,
        );

        assert_eq!(milestone.evaluate(&ConditionValue::Integer(1)), 1.0);
        assert_eq!(milestone.evaluate(&ConditionValue::Integer(5)), 1.0);
        assert_eq!(milestone.evaluate(&ConditionValue::Integer(0)), 0.0);
    }

    #[test]
    fn test_milestone_evaluate_partial_linear() {
        let mut milestone = create_test_milestone("efficiency", "tick", CompareOp::Lte, 300, 0.3);
        milestone = milestone.with_partial(PartialConfig::Linear {
            min: Some(300.0),
            max: Some(400.0),
            descending: true,
        });

        assert_eq!(milestone.evaluate(&ConditionValue::Integer(250)), 1.0);
        assert_eq!(milestone.evaluate(&ConditionValue::Integer(300)), 1.0);

        assert!((milestone.evaluate(&ConditionValue::Integer(350)) - 0.5).abs() < 0.01);

        assert_eq!(milestone.evaluate(&ConditionValue::Integer(400)), 0.0);
        assert_eq!(milestone.evaluate(&ConditionValue::Integer(500)), 0.0);
    }

    #[test]
    fn test_milestone_evaluate_partial_linear_ascending_defaults() {
        // With no bounds, min falls back to 0.0 and max to the condition target (10).
        let milestone = create_test_milestone("progress", "collected", CompareOp::Gte, 10, 1.0)
            .with_partial(PartialConfig::Linear {
                min: None,
                max: None,
                descending: false,
            });

        assert_eq!(milestone.evaluate(&ConditionValue::Integer(0)), 0.0);
        assert!((milestone.evaluate(&ConditionValue::Integer(5)) - 0.5).abs() < 1e-9);
        assert_eq!(milestone.evaluate(&ConditionValue::Integer(10)), 1.0);
    }

    #[test]
    fn test_milestone_evaluate_partial_threshold() {
        let milestone = create_test_milestone("tiers", "collected", CompareOp::Gte, 5, 1.0)
            .with_partial(PartialConfig::Threshold {
                thresholds: vec![(1.0, 0.25), (3.0, 0.5)],
            });

        assert_eq!(milestone.evaluate(&ConditionValue::Integer(0)), 0.0);
        assert_eq!(milestone.evaluate(&ConditionValue::Integer(2)), 0.25);
        assert_eq!(milestone.evaluate(&ConditionValue::Integer(4)), 0.5);
        assert_eq!(milestone.evaluate(&ConditionValue::Integer(5)), 1.0);
    }

    #[test]
    fn test_partial_flag_without_config_scores_zero() {
        let mut milestone = create_test_milestone("flag_only", "collected", CompareOp::Gte, 5, 1.0);
        milestone.partial = true;

        assert_eq!(milestone.evaluate(&ConditionValue::Integer(4)), 0.0);
    }

    #[test]
    fn test_kpi_calculator() {
        let milestones = vec![
            create_test_milestone("first", "collected", CompareOp::Gte, 1, 0.2),
            create_test_milestone("half", "collected", CompareOp::Gte, 3, 0.3),
            create_test_milestone("all", "collected", CompareOp::Gte, 5, 0.5),
        ];

        let calculator = KpiCalculator::new(milestones);

        let score = calculator.calculate(|_| Some(ConditionValue::Integer(5)));
        assert_eq!(score.score, 1.0);
        assert_eq!(score.completed_count(), 3);

        let score = calculator.calculate(|_| Some(ConditionValue::Integer(3)));
        assert!((score.score - 0.5).abs() < 0.01);
        assert_eq!(score.completed_count(), 2);

        let score = calculator.calculate(|_| Some(ConditionValue::Integer(1)));
        assert!((score.score - 0.2).abs() < 0.01);
        assert_eq!(score.completed_count(), 1);
    }

    #[test]
    fn test_kpi_calculator_missing_metrics_and_empty() {
        let calculator = KpiCalculator::new(vec![create_test_milestone(
            "first",
            "collected",
            CompareOp::Gte,
            1,
            1.0,
        )]);

        // An unavailable metric counts as zero achievement.
        let score = calculator.calculate(|_| None);
        assert_eq!(score.score, 0.0);
        assert_eq!(score.completed_count(), 0);
        assert_eq!(score.total_count(), 1);

        // No milestones means zero total weight, which normalizes to 0.0.
        let empty = KpiCalculator::new(Vec::new()).calculate(|_| None);
        assert_eq!(empty.score, 0.0);
        assert_eq!(empty.total_weight, 0.0);
        assert_eq!(empty.total_count(), 0);
    }

    #[test]
    fn test_milestone_deserialize() {
        let json = r#"{
            "name": "efficiency_bonus",
            "description": "Complete within 300 ticks",
            "condition": {
                "name": "efficiency",
                "metric": "tick",
                "op": "lte",
                "value": 300
            },
            "weight": 0.2,
            "partial": true,
            "partial_config": {
                "type": "linear",
                "min": 300.0,
                "max": 400.0
            }
        }"#;

        let milestone: Milestone = serde_json::from_str(json).unwrap();
        assert_eq!(milestone.name, "efficiency_bonus");
        assert_eq!(
            milestone.description.as_deref(),
            Some("Complete within 300 ticks")
        );
        assert!(milestone.partial);
        assert!(matches!(
            milestone.partial_config,
            Some(PartialConfig::Linear {
                min: Some(min),
                max: Some(max),
                descending: false,
            }) if min == 300.0 && max == 400.0
        ));
    }
}