1use std::sync::Arc;
10
11use super::{DispatchingRule, RuleScore, SchedulingContext};
12use crate::models::Task;
13
/// How the scores of several dispatching rules are combined into one ordering.
///
/// The enum is a plain configuration value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EvaluationMode {
    /// Rules are consulted one after another in the order they were added;
    /// a later rule is only consulted when the earlier ones score two tasks
    /// within the engine's tolerance of each other. This is the default.
    #[default]
    Sequential,
    /// Every rule's score is multiplied by its weight and the products are
    /// summed; tasks are ordered by that sum, lowest first.
    Weighted,
}
23
/// How two tasks are ordered when the configured rules cannot separate them.
///
/// The enum is a plain configuration value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TieBreaker {
    /// No extra criterion: the tasks compare as equal once the rule chain is
    /// exhausted. This is the default.
    #[default]
    NextRule,
    /// Order the tied tasks by their `id`, ascending.
    ById,
}
33
34#[derive(Clone)]
35struct WeightedRule {
36 rule: Arc<dyn DispatchingRule>,
37 weight: f64,
38}
39
40#[derive(Clone)]
55pub struct RuleEngine {
56 rules: Vec<WeightedRule>,
57 mode: EvaluationMode,
58 tie_breaker: TieBreaker,
59 epsilon: f64,
60}
61
62impl RuleEngine {
63 pub fn new() -> Self {
65 Self {
66 rules: Vec::new(),
67 mode: EvaluationMode::Sequential,
68 tie_breaker: TieBreaker::NextRule,
69 epsilon: 1e-9,
70 }
71 }
72
73 pub fn with_rule<R: DispatchingRule + 'static>(mut self, rule: R) -> Self {
75 self.rules.push(WeightedRule {
76 rule: Arc::new(rule),
77 weight: 1.0,
78 });
79 self
80 }
81
82 pub fn with_weighted_rule<R: DispatchingRule + 'static>(
84 mut self,
85 rule: R,
86 weight: f64,
87 ) -> Self {
88 self.rules.push(WeightedRule {
89 rule: Arc::new(rule),
90 weight,
91 });
92 self
93 }
94
95 pub fn with_tie_breaker<R: DispatchingRule + 'static>(mut self, rule: R) -> Self {
97 self.rules.push(WeightedRule {
98 rule: Arc::new(rule),
99 weight: 0.0,
100 });
101 self
102 }
103
104 pub fn with_mode(mut self, mode: EvaluationMode) -> Self {
106 self.mode = mode;
107 self
108 }
109
110 pub fn with_final_tie_breaker(mut self, tie_breaker: TieBreaker) -> Self {
112 self.tie_breaker = tie_breaker;
113 self
114 }
115
116 pub fn sort_indices(&self, tasks: &[Task], context: &SchedulingContext) -> Vec<usize> {
120 if tasks.is_empty() {
121 return Vec::new();
122 }
123
124 let mut indices: Vec<usize> = (0..tasks.len()).collect();
125
126 match &self.mode {
127 EvaluationMode::Sequential => {
128 indices.sort_by(|&a, &b| self.compare_sequential(&tasks[a], &tasks[b], context));
129 }
130 EvaluationMode::Weighted => {
131 let scores: Vec<f64> = tasks
132 .iter()
133 .map(|t| self.weighted_score(t, context))
134 .collect();
135 indices.sort_by(|&a, &b| {
136 scores[a]
137 .partial_cmp(&scores[b])
138 .unwrap_or(std::cmp::Ordering::Equal)
139 });
140 }
141 }
142
143 indices
144 }
145
146 pub fn select_best(&self, tasks: &[Task], context: &SchedulingContext) -> Option<usize> {
148 self.sort_indices(tasks, context).first().copied()
149 }
150
151 pub fn evaluate(&self, task: &Task, context: &SchedulingContext) -> Vec<RuleScore> {
153 self.rules
154 .iter()
155 .map(|wr| wr.rule.evaluate(task, context) * wr.weight)
156 .collect()
157 }
158
159 fn compare_sequential(
160 &self,
161 a: &Task,
162 b: &Task,
163 context: &SchedulingContext,
164 ) -> std::cmp::Ordering {
165 for wr in &self.rules {
166 let score_a = wr.rule.evaluate(a, context);
167 let score_b = wr.rule.evaluate(b, context);
168
169 if (score_a - score_b).abs() > self.epsilon {
170 return score_a
171 .partial_cmp(&score_b)
172 .unwrap_or(std::cmp::Ordering::Equal);
173 }
174 }
175
176 match &self.tie_breaker {
178 TieBreaker::NextRule => std::cmp::Ordering::Equal,
179 TieBreaker::ById => a.id.cmp(&b.id),
180 }
181 }
182
183 fn weighted_score(&self, task: &Task, context: &SchedulingContext) -> f64 {
184 self.rules
185 .iter()
186 .map(|wr| wr.rule.evaluate(task, context) * wr.weight)
187 .sum()
188 }
189}
190
191impl Default for RuleEngine {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
197impl std::fmt::Debug for RuleEngine {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 f.debug_struct("RuleEngine")
200 .field(
201 "rules",
202 &self
203 .rules
204 .iter()
205 .map(|r| format!("{}(w={})", r.rule.name(), r.weight))
206 .collect::<Vec<_>>(),
207 )
208 .field("mode", &self.mode)
209 .finish()
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dispatching::rules;
    use crate::models::{Activity, ActivityDuration, Task};

    /// Test-only builder step that sets an optional deadline inside a call chain.
    trait TaskExt {
        fn with_deadline_opt(self, deadline: Option<i64>) -> Self;
    }

    impl TaskExt for Task {
        fn with_deadline_opt(mut self, deadline: Option<i64>) -> Self {
            self.deadline = deadline;
            self
        }
    }

    /// Builds a task with one fixed-duration activity named `<id>_O1`.
    fn make_task(id: &str, duration_ms: i64, deadline: Option<i64>, priority: i32) -> Task {
        let activity = Activity::new(format!("{id}_O1"), id, 0)
            .with_duration(ActivityDuration::fixed(duration_ms));

        Task::new(id)
            .with_priority(priority)
            .with_activity(activity)
            .with_deadline_opt(deadline)
    }

    /// Asserts that the leading entries of `order` point at tasks whose ids
    /// match `expected`, position by position.
    fn assert_order(tasks: &[Task], order: &[usize], expected: &[&str]) {
        for (position, want) in expected.iter().enumerate() {
            assert_eq!(tasks[order[position]].id, *want);
        }
    }

    #[test]
    fn test_spt_ordering() {
        let ctx = SchedulingContext::at_time(0);
        let engine = RuleEngine::new().with_rule(rules::Spt);
        let tasks = vec![
            make_task("long", 5000, None, 0),
            make_task("short", 1000, None, 0),
            make_task("medium", 3000, None, 0),
        ];

        let order = engine.sort_indices(&tasks, &ctx);
        assert_order(&tasks, &order, &["short", "medium", "long"]);
    }

    #[test]
    fn test_edd_ordering() {
        let ctx = SchedulingContext::at_time(0);
        let engine = RuleEngine::new().with_rule(rules::Edd);
        let tasks = vec![
            make_task("late", 1000, Some(50_000), 0),
            make_task("early", 1000, Some(10_000), 0),
            make_task("no_deadline", 1000, None, 0),
        ];

        let order = engine.sort_indices(&tasks, &ctx);
        assert_order(&tasks, &order, &["early", "late", "no_deadline"]);
    }

    #[test]
    fn test_sequential_with_tie_breaker() {
        let ctx = SchedulingContext::at_time(0);
        let engine = RuleEngine::new()
            .with_rule(rules::Edd)
            .with_tie_breaker(rules::Spt);
        // Both tasks share a deadline, so the shorter one is expected first.
        let tasks = vec![
            make_task("A", 1000, Some(10_000), 0),
            make_task("B", 2000, Some(10_000), 0),
        ];

        let order = engine.sort_indices(&tasks, &ctx);
        assert_order(&tasks, &order, &["A"]);
    }

    #[test]
    fn test_weighted_mode() {
        let ctx = SchedulingContext::at_time(0);
        let engine = RuleEngine::new()
            .with_mode(EvaluationMode::Weighted)
            .with_weighted_rule(rules::Edd, 0.5)
            .with_weighted_rule(rules::Spt, 0.5);
        let tasks = vec![
            make_task("A", 1000, Some(50_000), 0),
            make_task("B", 5000, Some(10_000), 0),
        ];

        let order = engine.sort_indices(&tasks, &ctx);
        assert_order(&tasks, &order, &["B"]);
    }

    #[test]
    fn test_by_id_tie_breaker() {
        let ctx = SchedulingContext::at_time(0);
        let engine = RuleEngine::new()
            .with_rule(rules::Spt)
            .with_final_tie_breaker(TieBreaker::ById);
        // Identical durations: only the id can decide the order.
        let tasks = vec![make_task("B", 1000, None, 0), make_task("A", 1000, None, 0)];

        let order = engine.sort_indices(&tasks, &ctx);
        assert_order(&tasks, &order, &["A"]);
    }

    #[test]
    fn test_empty_tasks() {
        let ctx = SchedulingContext::at_time(0);
        let engine = RuleEngine::new().with_rule(rules::Spt);

        assert!(engine.sort_indices(&[], &ctx).is_empty());
        assert!(engine.select_best(&[], &ctx).is_none());
    }

    #[test]
    fn test_select_best() {
        let ctx = SchedulingContext::at_time(0);
        let engine = RuleEngine::new().with_rule(rules::Spt);
        let tasks = vec![
            make_task("long", 5000, None, 0),
            make_task("short", 1000, None, 0),
        ];

        assert_eq!(engine.select_best(&tasks, &ctx), Some(1));
    }

    #[test]
    fn test_evaluate_scores() {
        let ctx = SchedulingContext::at_time(0);
        let engine = RuleEngine::new()
            .with_rule(rules::Spt)
            .with_rule(rules::Edd);
        let task = make_task("T1", 3000, Some(20_000), 0);

        let scores = engine.evaluate(&task, &ctx);
        assert_eq!(scores.len(), 2);

        let spt_error = (scores[0] - 3000.0).abs();
        let edd_error = (scores[1] - 20_000.0).abs();
        assert!(spt_error < 1e-10);
        assert!(edd_error < 1e-10);
    }
}