1use std::any::Any;
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::time::Duration;
9
10use rayon::prelude::*;
11
12use crate::async_task::TaskStatus;
13use crate::extensions::Extensions;
14use crate::online_stats::SwarmStats;
15use crate::types::{AgentId, TaskId, WorkerId};
16
/// Running counters for LLM invocations.
///
/// All fields are public, so callers may also set them directly; the methods
/// below do not assume `errors <= invocations`.
#[derive(Debug, Clone, Default)]
pub struct LlmStats {
    /// Number of calls recorded through [`LlmStats::record`].
    pub invocations: u64,
    /// Number of recorded calls that were reported as unsuccessful.
    pub errors: u64,
    /// Sum of every duration passed to [`LlmStats::record`].
    pub total_duration: Duration,
}

impl LlmStats {
    /// Returns the share of invocations that did not fail, in `0.0..=1.0`.
    ///
    /// Yields `1.0` while nothing has been recorded. Because the counters are
    /// public, `errors` can exceed `invocations`; the subtraction saturates so
    /// that case reports `0.0` instead of panicking on overflow (debug builds)
    /// or producing a wrapped, meaningless ratio (release builds).
    pub fn success_rate(&self) -> f64 {
        if self.invocations == 0 {
            return 1.0;
        }
        let succeeded = self.invocations.saturating_sub(self.errors);
        succeeded as f64 / self.invocations as f64
    }

    /// Records one invocation that took `duration`.
    ///
    /// `success == false` additionally increments the error counter.
    pub fn record(&mut self, success: bool, duration: Duration) {
        self.invocations += 1;
        self.total_duration += duration;
        if !success {
            self.errors += 1;
        }
    }
}
51
52pub struct SwarmState {
54 pub shared: SharedState,
56 pub workers: WorkerStates,
58}
59
60impl SwarmState {
61 pub fn new(worker_count: usize) -> Self {
62 Self {
63 shared: SharedState::default(),
64 workers: WorkerStates::new(worker_count),
65 }
66 }
67
68 pub fn advance_tick(&mut self) {
70 self.shared.tick += 1;
71 }
72}
73
74#[derive(Default)]
76pub struct SharedState {
77 pub environment: Environment,
79 pub stats: SwarmStats,
81 pub tick: u64,
83 pub shared_data: SharedData,
85 pub extensions: Extensions,
87 pub avg_tick_duration_ns: u64,
89 pub done_workers: HashSet<WorkerId>,
91 pub environment_done: bool,
93 pub llm_stats: LlmStats,
95}
96
97
98impl SharedState {
99 pub fn mark_worker_done(&mut self, worker_id: WorkerId) {
101 self.done_workers.insert(worker_id);
102 self.environment_done = true;
103 }
104
105 pub fn is_worker_done(&self, worker_id: WorkerId) -> bool {
107 self.done_workers.contains(&worker_id)
108 }
109
110 pub fn is_environment_done(&self) -> bool {
112 self.environment_done
113 }
114
115 pub fn llm_invocations(&self) -> u64 {
117 self.llm_stats.invocations
118 }
119
120 pub fn llm_errors(&self) -> u64 {
122 self.llm_stats.errors
123 }
124}
125
/// Named settings of the environment; both maps start empty under `Default`.
#[derive(Default)]
pub struct Environment {
    /// String-valued variables keyed by name.
    pub variables: HashMap<String, String>,
    /// Boolean flags keyed by name.
    pub flags: HashMap<String, bool>,
}
134
135#[derive(Debug, Clone)]
143pub struct TickSnapshot {
144 pub tick: u64,
146 pub duration: std::time::Duration,
148 pub manager_phase: Option<ManagerPhaseSnapshot>,
150 pub worker_results: Vec<WorkerResultSnapshot>,
152}
153
154#[derive(Debug, Clone)]
156pub struct ManagerPhaseSnapshot {
157 pub batch_request: crate::agent::BatchDecisionRequest,
159 pub responses: Vec<(crate::types::WorkerId, crate::agent::DecisionResponse)>,
161 pub guidances: std::collections::HashMap<crate::types::WorkerId, crate::agent::Guidance>,
163 pub llm_errors: u64,
165}
166
167#[derive(Debug, Clone)]
169pub struct WorkerResultSnapshot {
170 pub worker_id: crate::types::WorkerId,
171 pub guidance_received: Option<crate::agent::Guidance>,
173 pub result: WorkResultSnapshot,
175}
176
177#[derive(Debug, Clone)]
179pub enum WorkResultSnapshot {
180 Acted {
182 action_result: ActionResultSnapshot,
183 state_delta: Option<crate::agent::WorkerStateDelta>,
184 },
185 Continuing { progress: f32 },
187 NeedsGuidance {
189 reason: String,
190 context: crate::agent::GuidanceContext,
191 },
192 Escalate {
194 reason: crate::agent::EscalationReason,
195 context: Option<String>,
196 },
197 Idle,
199 Done {
201 success: bool,
202 message: Option<String>,
203 },
204}
205
206#[derive(Debug, Clone)]
208pub struct ActionResultSnapshot {
209 pub success: bool,
210 pub output_debug: Option<String>,
212 pub duration: std::time::Duration,
213 pub error: Option<String>,
214}
215
216impl ActionResultSnapshot {
217 pub fn from_action_result(result: &crate::types::ActionResult) -> Self {
219 Self {
220 success: result.success,
221 output_debug: result.output.as_ref().map(|o| o.as_text()),
222 duration: result.duration,
223 error: result.error.clone(),
224 }
225 }
226}
227
228const DEFAULT_MAX_ENV_ENTRIES: usize = 500;
230
231pub struct SharedData {
233 pub kv: HashMap<String, Vec<u8>>,
235 pub completed_async_tasks: Vec<CompletedAsyncTask>,
238 max_env_entries: usize,
240}
241
242impl Default for SharedData {
243 fn default() -> Self {
244 Self {
245 kv: HashMap::new(),
246 completed_async_tasks: Vec::new(),
247 max_env_entries: DEFAULT_MAX_ENV_ENTRIES,
248 }
249 }
250}
251
252impl SharedData {
253 pub fn cleanup_env_entries(&mut self) {
258 let mut env_entries: Vec<(String, u64)> = self
260 .kv
261 .keys()
262 .filter(|k| k.starts_with("env:"))
263 .filter_map(|k| {
264 k.rsplit(':')
266 .next()?
267 .parse::<u64>()
268 .ok()
269 .map(|tick| (k.clone(), tick))
270 })
271 .collect();
272
273 if env_entries.len() <= self.max_env_entries {
274 return;
275 }
276
277 env_entries.sort_by_key(|(_, tick)| *tick);
279
280 let remove_count = env_entries.len() - self.max_env_entries;
282 for (key, _) in env_entries.into_iter().take(remove_count) {
283 self.kv.remove(&key);
284 }
285 }
286
287 pub fn set_max_env_entries(&mut self, max: usize) {
289 self.max_env_entries = max;
290 }
291}
292
293#[derive(Debug, Clone)]
297pub struct CompletedAsyncTask {
298 pub task_id: TaskId,
300 pub worker_id: Option<WorkerId>,
302 pub task_type: String,
304 pub completed_at_tick: u64,
306 pub status: TaskStatus,
308 pub error: Option<String>,
310}
311
312pub struct WorkerStates {
314 states: Vec<WorkerState>,
316}
317
318impl WorkerStates {
319 pub fn new(count: usize) -> Self {
320 let states = (0..count).map(|i| WorkerState::new(AgentId(i))).collect();
321 Self { states }
322 }
323
324 pub fn get_mut(&mut self, id: AgentId) -> Option<&mut WorkerState> {
326 self.states.get_mut(id.0)
327 }
328
329 pub fn get(&self, id: AgentId) -> Option<&WorkerState> {
331 self.states.get(id.0)
332 }
333
334 pub fn len(&self) -> usize {
336 self.states.len()
337 }
338
339 pub fn is_empty(&self) -> bool {
341 self.states.is_empty()
342 }
343
344 pub fn iter(&self) -> impl Iterator<Item = &WorkerState> {
346 self.states.iter()
347 }
348
349 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut WorkerState> {
351 self.states.iter_mut()
352 }
353
354 pub fn par_iter_mut(&mut self) -> impl ParallelIterator<Item = &mut WorkerState> {
356 self.states.par_iter_mut()
357 }
358}
359
/// Why an escalation was raised.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EscalationReason {
    /// A run of failures; the payload is the failure count.
    ConsecutiveFailures(u32),
    /// Resources were exhausted.
    ResourceExhausted,
    /// A timeout occurred.
    Timeout,
    /// Requested by the agent; the payload is free-form text.
    AgentRequested(String),
    /// Any other reason; the payload is free-form text.
    Unknown(String),
}

/// An escalation together with the tick it was raised at.
#[derive(Debug, Clone)]
pub struct Escalation {
    /// Why the escalation was raised.
    pub reason: EscalationReason,
    /// Tick at which it was raised.
    pub raised_at_tick: u64,
    /// Optional free-form detail.
    pub context: Option<String>,
}

impl Escalation {
    /// Escalation for `count` consecutive failures raised at `tick`, without context.
    pub fn consecutive_failures(count: u32, tick: u64) -> Self {
        let reason = EscalationReason::ConsecutiveFailures(count);
        Self {
            reason,
            raised_at_tick: tick,
            context: None,
        }
    }

    /// Returns the escalation with its context replaced by `ctx`.
    pub fn with_context(self, ctx: impl Into<String>) -> Self {
        Self {
            context: Some(ctx.into()),
            ..self
        }
    }
}
400
401pub struct WorkerState {
405 pub id: AgentId,
407 internal_state: Option<Box<dyn Any + Send + Sync>>,
409 pub history: ActionHistory,
411 pub cache: LocalCache,
413 pub pending_tasks: HashSet<TaskId>,
415 pub escalation: Option<Escalation>,
417 pub consecutive_failures: u32,
419 pub last_output: Option<String>,
421}
422
423impl WorkerState {
424 pub fn new(id: AgentId) -> Self {
425 Self {
426 id,
427 internal_state: None,
428 history: ActionHistory::default(),
429 cache: LocalCache::default(),
430 pending_tasks: HashSet::new(),
431 escalation: None,
432 consecutive_failures: 0,
433 last_output: None,
434 }
435 }
436
437 pub fn raise_escalation(&mut self, escalation: Escalation) {
439 self.escalation = Some(escalation);
440 }
441
442 pub fn clear_escalation(&mut self) {
444 self.escalation = None;
445 self.consecutive_failures = 0;
446 }
447
448 pub fn record_failure(&mut self, tick: u64, threshold: u32) -> bool {
450 self.consecutive_failures += 1;
451 if self.consecutive_failures >= threshold && self.escalation.is_none() {
452 self.raise_escalation(Escalation::consecutive_failures(
453 self.consecutive_failures,
454 tick,
455 ));
456 true
457 } else {
458 false
459 }
460 }
461
462 pub fn record_success(&mut self) {
464 self.consecutive_failures = 0;
465 }
466
467 pub fn set_state<T: Any + Send + Sync + 'static>(&mut self, state: T) {
469 self.internal_state = Some(Box::new(state));
470 }
471
472 pub fn get_state<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
474 self.internal_state.as_ref()?.downcast_ref()
475 }
476
477 pub fn get_state_mut<T: Any + Send + Sync + 'static>(&mut self) -> Option<&mut T> {
479 self.internal_state.as_mut()?.downcast_mut()
480 }
481
482 pub fn add_pending_task(&mut self, task_id: TaskId) {
484 self.pending_tasks.insert(task_id);
485 }
486
487 pub fn complete_task(&mut self, task_id: TaskId) {
489 self.pending_tasks.remove(&task_id);
490 }
491}
492
/// Bounded FIFO log of a worker's actions.
///
/// Once `max_entries` is reached, each push evicts the oldest entry first.
/// A `max_entries` of `0` disables eviction, so the log grows without bound.
pub struct ActionHistory {
    /// Entries from oldest (front) to newest (back).
    entries: VecDeque<HistoryEntry>,
    /// Eviction threshold; `0` means no eviction.
    max_entries: usize,
}

impl Default for ActionHistory {
    /// History bounded to 100 entries.
    fn default() -> Self {
        ActionHistory::new(100)
    }
}

impl ActionHistory {
    /// Creates an empty history, pre-allocating room for `max_entries` entries.
    pub fn new(max_entries: usize) -> Self {
        let entries = VecDeque::with_capacity(max_entries);
        Self {
            entries,
            max_entries,
        }
    }

    /// Appends `entry`, first dropping the oldest entry when the bound is reached.
    pub fn push(&mut self, entry: HistoryEntry) {
        let bounded = self.max_entries != 0;
        if bounded && self.entries.len() >= self.max_entries {
            // Evict from the front so the back always holds the newest entry.
            let _evicted = self.entries.pop_front();
        }
        self.entries.push_back(entry);
    }

    /// Most recently pushed entry, or `None` when the history is empty.
    pub fn latest(&self) -> Option<&HistoryEntry> {
        self.entries.back()
    }

    /// Number of entries currently held.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no entries are held.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Iterates from the oldest entry to the newest.
    pub fn iter(&self) -> impl Iterator<Item = &HistoryEntry> {
        self.entries.iter()
    }
}

/// One recorded action.
#[derive(Debug, Clone)]
pub struct HistoryEntry {
    /// Tick the entry is associated with.
    pub tick: u64,
    /// Name of the action.
    pub action_name: String,
    /// Whether the action succeeded.
    pub success: bool,
}
554
/// Per-worker byte cache whose entries expire at a given tick.
#[derive(Default)]
pub struct LocalCache {
    /// Cached entries keyed by string.
    data: HashMap<String, CacheEntry>,
}

impl LocalCache {
    /// Stores `value` under `key`, replacing any previous entry.
    ///
    /// `ttl_ticks` is stored verbatim as the expiry tick, so despite its name
    /// it is an absolute tick: a caller wanting a relative lifetime has to
    /// pass `current_tick + ttl`.
    pub fn set(&mut self, key: impl Into<String>, value: Vec<u8>, ttl_ticks: u64) {
        let owned_key = key.into();
        let entry = CacheEntry {
            value,
            expires_at_tick: ttl_ticks,
        };
        self.data.insert(owned_key, entry);
    }

    /// Returns the bytes stored under `key` if the entry is still live.
    ///
    /// An entry is live while `current_tick` is strictly below its expiry
    /// tick. Expired entries are not removed here; see [`LocalCache::cleanup`].
    pub fn get(&self, key: &str, current_tick: u64) -> Option<&[u8]> {
        self.data
            .get(key)
            .filter(|entry| current_tick < entry.expires_at_tick)
            .map(|entry| entry.value.as_slice())
    }

    /// Removes every entry whose expiry tick is at or before `current_tick`.
    pub fn cleanup(&mut self, current_tick: u64) {
        self.data
            .retain(|_, entry| current_tick < entry.expires_at_tick);
    }
}

/// Cached payload plus the tick at which it stops being returned.
struct CacheEntry {
    /// Stored bytes.
    value: Vec<u8>,
    /// First tick at which the entry counts as expired.
    expires_at_tick: u64,
}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a history entry.
    fn entry(tick: u64, action_name: &str, success: bool) -> HistoryEntry {
        HistoryEntry {
            tick,
            action_name: action_name.to_string(),
            success,
        }
    }

    /// A new swarm state has the requested number of workers and starts at tick 0.
    #[test]
    fn test_swarm_state_creation() {
        let swarm = SwarmState::new(3);
        assert_eq!(swarm.workers.len(), 3);
        assert_eq!(swarm.shared.tick, 0);
    }

    /// Each `advance_tick` call increases the shared tick by exactly one.
    #[test]
    fn test_swarm_state_advance_tick() {
        let mut swarm = SwarmState::new(1);
        assert_eq!(swarm.shared.tick, 0);

        for expected_tick in 1..=2u64 {
            swarm.advance_tick();
            assert_eq!(swarm.shared.tick, expected_tick);
        }
    }

    /// Worker states are addressable by id; unknown ids yield `None`.
    #[test]
    fn test_worker_states_access() {
        let mut workers = WorkerStates::new(3);
        assert_eq!(workers.len(), 3);
        assert!(!workers.is_empty());

        let second = workers.get_mut(AgentId(1)).unwrap();
        assert_eq!(second.id.0, 1);

        // Id beyond the number of workers.
        assert!(workers.get(AgentId(10)).is_none());
    }

    /// The type-erased internal state round-trips and rejects the wrong type.
    #[test]
    fn test_worker_state_internal() {
        let mut worker = WorkerState::new(AgentId(0));

        // Nothing stored yet.
        assert!(worker.get_state::<i32>().is_none());

        worker.set_state(42i32);
        assert_eq!(worker.get_state::<i32>(), Some(&42));

        // Mutate through the typed accessor.
        if let Some(value) = worker.get_state_mut::<i32>() {
            *value = 100;
        }
        assert_eq!(worker.get_state::<i32>(), Some(&100));

        // Asking for a different type yields nothing.
        assert!(worker.get_state::<String>().is_none());
    }

    /// Pending tasks are added and removed individually.
    #[test]
    fn test_worker_state_pending_tasks() {
        let mut worker = WorkerState::new(AgentId(0));
        assert!(worker.pending_tasks.is_empty());

        worker.add_pending_task(TaskId(1));
        worker.add_pending_task(TaskId(2));
        assert_eq!(worker.pending_tasks.len(), 2);
        assert!(worker.pending_tasks.contains(&TaskId(1)));
        assert!(worker.pending_tasks.contains(&TaskId(2)));

        worker.complete_task(TaskId(1));
        assert_eq!(worker.pending_tasks.len(), 1);
        assert!(!worker.pending_tasks.contains(&TaskId(1)));
        assert!(worker.pending_tasks.contains(&TaskId(2)));
    }

    /// A history bounded to three entries evicts the oldest one on overflow.
    #[test]
    fn test_action_history() {
        let mut history = ActionHistory::new(3);

        history.push(entry(0, "action1", true));
        history.push(entry(1, "action2", false));

        assert_eq!(history.len(), 2);
        assert_eq!(history.latest().unwrap().action_name, "action2");

        // The fourth push exceeds the bound.
        history.push(entry(2, "action3", true));
        history.push(entry(3, "action4", true));

        assert_eq!(history.len(), 3);
        let remaining: Vec<_> = history.iter().collect();
        // "action1" was evicted, so "action2" is now the oldest.
        assert_eq!(remaining[0].action_name, "action2");
    }

    /// Cache entries are visible strictly before their expiry tick and are
    /// dropped by `cleanup` once that tick has been reached.
    #[test]
    fn test_local_cache() {
        let mut cache = LocalCache::default();

        cache.set("key1", vec![1, 2, 3], 10);
        cache.set("key2", vec![4, 5, 6], 5);

        // Live entries.
        assert_eq!(cache.get("key1", 0), Some(&[1u8, 2, 3][..]));
        assert_eq!(cache.get("key2", 4), Some(&[4u8, 5, 6][..]));

        // Expired from the expiry tick onwards.
        assert!(cache.get("key2", 5).is_none());
        assert!(cache.get("key2", 10).is_none());

        // Still live one tick before expiry.
        assert_eq!(cache.get("key1", 9), Some(&[1u8, 2, 3][..]));

        // Cleanup at tick 6 keeps key1 (expires at 10).
        cache.cleanup(6);
        assert!(cache.get("key1", 0).is_some());

        // Cleanup at tick 11 removes key1 as well.
        cache.cleanup(11);
        assert!(cache.get("key1", 0).is_none());
    }

    /// The environment maps store and return what was inserted.
    #[test]
    fn test_environment() {
        let mut env = Environment::default();
        env.variables
            .insert("PATH".to_string(), "/usr/bin".to_string());
        env.flags.insert("debug".to_string(), true);

        assert_eq!(env.variables.get("PATH"), Some(&"/usr/bin".to_string()));
        assert_eq!(env.flags.get("debug"), Some(&true));
    }
}