swarm_engine_core/context/
summary.rs1use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20
21use serde_json::Value;
22
23use crate::actions::ActionsConfig;
24use crate::agent::Guidance;
25use crate::state::Escalation;
26use crate::types::WorkerId;
27
28#[derive(Debug, Clone)]
34pub struct WorkerSummary {
35 pub id: WorkerId,
37 pub consecutive_failures: u32,
39 pub last_action: Option<String>,
41 pub last_success: Option<bool>,
43 pub last_output: Option<String>,
45 pub history_len: usize,
47 pub has_escalation: bool,
49}
50
51impl WorkerSummary {
52 pub fn new(id: WorkerId) -> Self {
53 Self {
54 id,
55 consecutive_failures: 0,
56 last_action: None,
57 last_success: None,
58 last_output: None,
59 history_len: 0,
60 has_escalation: false,
61 }
62 }
63
64 pub fn with_failures(mut self, count: u32) -> Self {
66 self.consecutive_failures = count;
67 self
68 }
69
70 pub fn with_last_action(mut self, action: impl Into<String>, success: bool) -> Self {
72 self.last_action = Some(action.into());
73 self.last_success = Some(success);
74 self
75 }
76
77 pub fn with_history_len(mut self, len: usize) -> Self {
79 self.history_len = len;
80 self
81 }
82
83 pub fn with_escalation(mut self, has_escalation: bool) -> Self {
85 self.has_escalation = has_escalation;
86 self
87 }
88}
89
90#[derive(Debug, Clone)]
111pub struct TaskContext {
112 pub tick: u64,
115 pub workers: HashMap<WorkerId, WorkerSummary>,
117 pub success_rate: f64,
119 pub progress: f64,
121 pub escalations: Vec<(WorkerId, Escalation)>,
123 pub available_actions: Option<ActionsConfig>,
125
126 pub v2_guidances: Option<Vec<Guidance>>,
132 pub excluded_actions: Vec<String>,
134
135 pub previous_guidances: HashMap<WorkerId, Arc<Guidance>>,
142
143 pub done_workers: HashSet<WorkerId>,
146
147 pub metadata: HashMap<String, Value>,
150}
151
152impl TaskContext {
153 pub fn new(tick: u64) -> Self {
155 Self {
156 tick,
157 workers: HashMap::new(),
158 success_rate: 0.0,
159 progress: 0.0,
160 escalations: Vec::new(),
161 available_actions: None,
162 v2_guidances: None,
163 excluded_actions: Vec::new(),
164 previous_guidances: HashMap::new(),
165 done_workers: HashSet::new(),
166 metadata: HashMap::new(),
167 }
168 }
169
170 pub fn with_worker(mut self, summary: WorkerSummary) -> Self {
174 self.workers.insert(summary.id, summary);
175 self
176 }
177
178 pub fn with_success_rate(mut self, rate: f64) -> Self {
180 self.success_rate = rate;
181 self
182 }
183
184 pub fn with_progress(mut self, progress: f64) -> Self {
186 self.progress = progress;
187 self
188 }
189
190 pub fn with_escalation(mut self, worker_id: WorkerId, escalation: Escalation) -> Self {
192 self.escalations.push((worker_id, escalation));
193 self
194 }
195
196 pub fn with_actions(mut self, actions: ActionsConfig) -> Self {
198 self.available_actions = Some(actions);
199 self
200 }
201
202 pub fn with_previous_guidances(mut self, guidances: HashMap<WorkerId, Arc<Guidance>>) -> Self {
204 self.previous_guidances = guidances;
205 self
206 }
207
208 pub fn with_previous_guidance(mut self, worker_id: WorkerId, guidance: Arc<Guidance>) -> Self {
210 self.previous_guidances.insert(worker_id, guidance);
211 self
212 }
213
214 pub fn insert<V: Into<Value>>(mut self, key: impl Into<String>, value: V) -> Self {
218 self.metadata.insert(key.into(), value.into());
219 self
220 }
221
222 pub fn set<V: Into<Value>>(&mut self, key: impl Into<String>, value: V) {
224 self.metadata.insert(key.into(), value.into());
225 }
226
227 pub fn get(&self, key: &str) -> Option<&Value> {
229 self.metadata.get(key)
230 }
231
232 pub fn get_str(&self, key: &str) -> Option<&str> {
234 self.metadata.get(key).and_then(|v| v.as_str())
235 }
236
237 pub fn get_f64(&self, key: &str) -> Option<f64> {
239 self.metadata.get(key).and_then(|v| v.as_f64())
240 }
241
242 pub fn get_i64(&self, key: &str) -> Option<i64> {
244 self.metadata.get(key).and_then(|v| v.as_i64())
245 }
246
247 pub fn get_bool(&self, key: &str) -> Option<bool> {
249 self.metadata.get(key).and_then(|v| v.as_bool())
250 }
251
252 pub fn has_escalations(&self) -> bool {
256 !self.escalations.is_empty()
257 }
258
259 pub fn has_escalation_for(&self, worker_id: WorkerId) -> bool {
261 self.escalations.iter().any(|(id, _)| *id == worker_id)
262 }
263
264 pub fn worker(&self, id: WorkerId) -> Option<&WorkerSummary> {
266 self.workers.get(&id)
267 }
268
269 pub fn escalated_worker_count(&self) -> usize {
271 self.workers.values().filter(|w| w.has_escalation).count()
272 }
273
274 pub fn worker_ids(&self) -> Vec<WorkerId> {
276 self.workers.keys().copied().collect()
277 }
278}
279
280impl TaskContext {
281 pub fn has_exploration(&self) -> bool {
285 self.v2_guidances.is_some()
286 }
287
288 pub fn filter_for_workers(&self, worker_ids: &[WorkerId]) -> TaskContext {
293 use std::collections::HashSet;
294 let worker_set: HashSet<WorkerId> = worker_ids.iter().copied().collect();
295
296 let filtered_workers: HashMap<WorkerId, WorkerSummary> = self
298 .workers
299 .iter()
300 .filter(|(id, _)| worker_set.contains(id))
301 .map(|(id, summary)| (*id, summary.clone()))
302 .collect();
303
304 let filtered_escalations: Vec<(WorkerId, Escalation)> = self
306 .escalations
307 .iter()
308 .filter(|(id, _)| worker_set.contains(id))
309 .cloned()
310 .collect();
311
312 let filtered_guidances: HashMap<WorkerId, Arc<Guidance>> = self
314 .previous_guidances
315 .iter()
316 .filter(|(id, _)| worker_set.contains(id))
317 .map(|(id, g)| (*id, Arc::clone(g)))
318 .collect();
319
320 let filtered_done_workers: HashSet<WorkerId> = self
322 .done_workers
323 .iter()
324 .filter(|id| worker_set.contains(id))
325 .copied()
326 .collect();
327
328 TaskContext {
329 tick: self.tick,
330 workers: filtered_workers,
331 success_rate: self.success_rate,
332 progress: self.progress,
333 escalations: filtered_escalations,
334 available_actions: self.available_actions.clone(),
335 v2_guidances: self.v2_guidances.clone(),
336 excluded_actions: self.excluded_actions.clone(),
337 previous_guidances: filtered_guidances,
338 done_workers: filtered_done_workers,
339 metadata: self.metadata.clone(),
340 }
341 }
342}
343
344impl Default for TaskContext {
345 fn default() -> Self {
346 Self::new(0)
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh context keeps the tick and starts with empty/zeroed state.
    #[test]
    fn test_task_context_new() {
        let ctx = TaskContext::new(10);

        assert_eq!(ctx.tick, 10);
        assert!(ctx.workers.is_empty());
        assert_eq!(ctx.success_rate, 0.0);
        assert_eq!(ctx.progress, 0.0);
    }

    /// Chained builders populate workers, aggregates and metadata.
    #[test]
    fn test_task_context_builder() {
        let plain = WorkerSummary::new(WorkerId(0));
        let escalated = WorkerSummary::new(WorkerId(1)).with_escalation(true);

        let ctx = TaskContext::new(5)
            .with_worker(plain)
            .with_worker(escalated)
            .with_success_rate(0.8)
            .with_progress(0.5)
            .insert("key1", "value1")
            .insert("count", 42);

        assert_eq!(ctx.tick, 5);
        assert_eq!(ctx.workers.len(), 2);
        assert_eq!(ctx.success_rate, 0.8);
        assert_eq!(ctx.progress, 0.5);
        assert_eq!(ctx.get_str("key1"), Some("value1"));
        assert_eq!(ctx.get_i64("count"), Some(42));
    }

    /// Every `WorkerSummary` builder writes the field it is named after.
    #[test]
    fn test_worker_summary() {
        let summary = WorkerSummary::new(WorkerId(0))
            .with_failures(2)
            .with_last_action("read:/path", true)
            .with_history_len(10)
            .with_escalation(true);

        assert_eq!(summary.id, WorkerId(0));
        assert_eq!(summary.consecutive_failures, 2);
        assert_eq!(summary.last_action, Some("read:/path".to_string()));
        assert_eq!(summary.last_success, Some(true));
        assert_eq!(summary.history_len, 10);
        assert!(summary.has_escalation);
    }

    /// Query helpers count flagged workers and list all ids.
    #[test]
    fn test_query_methods() {
        let ctx = TaskContext::new(0)
            .with_worker(WorkerSummary::new(WorkerId(0)))
            .with_worker(WorkerSummary::new(WorkerId(1)).with_escalation(true))
            .with_worker(WorkerSummary::new(WorkerId(2)));

        let ids = ctx.worker_ids();

        assert_eq!(ctx.escalated_worker_count(), 1);
        assert_eq!(ids.len(), 3);
    }

    /// Filtering keeps only the requested workers and their escalations while
    /// carrying shared data over unchanged.
    #[test]
    fn test_filter_for_workers() {
        let ctx = TaskContext::new(10)
            .with_worker(WorkerSummary::new(WorkerId(0)).with_failures(1))
            .with_worker(WorkerSummary::new(WorkerId(1)).with_escalation(true))
            .with_worker(WorkerSummary::new(WorkerId(2)).with_history_len(5))
            .with_worker(WorkerSummary::new(WorkerId(3)).with_last_action("read", true))
            .with_escalation(WorkerId(1), Escalation::consecutive_failures(3, 5))
            .with_success_rate(0.75)
            .with_progress(0.5)
            .insert("meta_key", "meta_value");

        let kept = [WorkerId(0), WorkerId(2)];
        let filtered = ctx.filter_for_workers(&kept);

        // Tick is shared data.
        assert_eq!(filtered.tick, 10);

        // Only workers 0 and 2 survive.
        assert_eq!(filtered.workers.len(), 2);
        assert!(filtered.workers.contains_key(&WorkerId(0)));
        assert!(filtered.workers.contains_key(&WorkerId(2)));
        assert!(!filtered.workers.contains_key(&WorkerId(1)));
        assert!(!filtered.workers.contains_key(&WorkerId(3)));

        // Surviving summaries keep their contents.
        assert_eq!(filtered.workers[&WorkerId(0)].consecutive_failures, 1);
        assert_eq!(filtered.workers[&WorkerId(2)].history_len, 5);

        // The only escalation belonged to worker 1, which was filtered out.
        assert!(filtered.escalations.is_empty());

        // Aggregates and metadata are carried over as-is.
        assert_eq!(filtered.success_rate, 0.75);
        assert_eq!(filtered.progress, 0.5);
        assert_eq!(filtered.get_str("meta_key"), Some("meta_value"));
    }

    /// Filtering with no ids drops every worker but keeps the tick.
    #[test]
    fn test_filter_for_workers_empty() {
        let ctx = TaskContext::new(5)
            .with_worker(WorkerSummary::new(WorkerId(0)))
            .with_worker(WorkerSummary::new(WorkerId(1)));

        let no_ids: [WorkerId; 0] = [];
        let filtered = ctx.filter_for_workers(&no_ids);

        assert_eq!(filtered.tick, 5);
        assert!(filtered.workers.is_empty());
    }
}