1use std::any::Any;
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::time::Duration;
9
10use rayon::prelude::*;
11
12use crate::async_task::TaskStatus;
13use crate::extensions::Extensions;
14use crate::online_stats::SwarmStats;
15use crate::types::{AgentId, TaskId, WorkerId};
16
/// Running counters describing LLM invocations.
#[derive(Debug, Clone, Default)]
pub struct LlmStats {
    /// Number of invocations recorded so far (successful or not).
    pub invocations: u64,
    /// Number of recorded invocations that were reported as failures.
    pub errors: u64,
    /// Sum of the durations of every recorded invocation.
    pub total_duration: Duration,
}

impl LlmStats {
    /// Returns the fraction of recorded invocations that succeeded.
    ///
    /// Returns `1.0` when nothing has been recorded yet. The result is always
    /// within `0.0..=1.0`, even if the public counters were set inconsistently.
    pub fn success_rate(&self) -> f64 {
        if self.invocations == 0 {
            return 1.0;
        }
        // Both counters are public, so `errors` can exceed `invocations`.
        // Saturate instead of underflowing (a panic in debug builds, a huge
        // bogus ratio in release builds).
        let successes = self.invocations.saturating_sub(self.errors);
        successes as f64 / self.invocations as f64
    }

    /// Records one invocation.
    ///
    /// `duration` is always added to `total_duration`; `errors` is only
    /// incremented when `success` is `false`.
    pub fn record(&mut self, success: bool, duration: Duration) {
        self.invocations += 1;
        self.total_duration += duration;
        if !success {
            self.errors += 1;
        }
    }
}
51
52pub struct SwarmState {
54 pub shared: SharedState,
56 pub workers: WorkerStates,
58}
59
60impl SwarmState {
61 pub fn new(worker_count: usize) -> Self {
62 Self {
63 shared: SharedState::default(),
64 workers: WorkerStates::new(worker_count),
65 }
66 }
67
68 pub fn advance_tick(&mut self) {
70 self.shared.tick += 1;
71 }
72}
73
74#[derive(Default)]
76pub struct SharedState {
77 pub environment: Environment,
79 pub stats: SwarmStats,
81 pub tick: u64,
83 pub shared_data: SharedData,
85 pub extensions: Extensions,
87 pub avg_tick_duration_ns: u64,
89 pub done_workers: HashSet<WorkerId>,
91 pub environment_done: bool,
93 pub llm_stats: LlmStats,
95}
96
97impl SharedState {
98 pub fn mark_worker_done(&mut self, worker_id: WorkerId) {
100 self.done_workers.insert(worker_id);
101 self.environment_done = true;
102 }
103
104 pub fn is_worker_done(&self, worker_id: WorkerId) -> bool {
106 self.done_workers.contains(&worker_id)
107 }
108
109 pub fn is_environment_done(&self) -> bool {
111 self.environment_done
112 }
113
114 pub fn llm_invocations(&self) -> u64 {
116 self.llm_stats.invocations
117 }
118
119 pub fn llm_errors(&self) -> u64 {
121 self.llm_stats.errors
122 }
123}
124
/// String variables and boolean flags describing the environment.
#[derive(Default)]
pub struct Environment {
    /// Named string values.
    pub variables: HashMap<String, String>,
    /// Named boolean switches.
    pub flags: HashMap<String, bool>,
}
133
134#[derive(Debug, Clone)]
142pub struct TickSnapshot {
143 pub tick: u64,
145 pub duration: std::time::Duration,
147 pub manager_phase: Option<ManagerPhaseSnapshot>,
149 pub worker_results: Vec<WorkerResultSnapshot>,
151}
152
153#[derive(Debug, Clone)]
155pub struct ManagerPhaseSnapshot {
156 pub batch_request: crate::agent::BatchDecisionRequest,
158 pub responses: Vec<(crate::types::WorkerId, crate::agent::DecisionResponse)>,
160 pub guidances: std::collections::HashMap<crate::types::WorkerId, crate::agent::Guidance>,
162 pub llm_errors: u64,
164}
165
166#[derive(Debug, Clone)]
168pub struct WorkerResultSnapshot {
169 pub worker_id: crate::types::WorkerId,
170 pub guidance_received: Option<crate::agent::Guidance>,
172 pub result: WorkResultSnapshot,
174}
175
176#[derive(Debug, Clone)]
178pub enum WorkResultSnapshot {
179 Acted {
181 action_result: ActionResultSnapshot,
182 state_delta: Option<crate::agent::WorkerStateDelta>,
183 },
184 Continuing { progress: f32 },
186 NeedsGuidance {
188 reason: String,
189 context: crate::agent::GuidanceContext,
190 },
191 Escalate {
193 reason: crate::agent::EscalationReason,
194 context: Option<String>,
195 },
196 Idle,
198 Done {
200 success: bool,
201 message: Option<String>,
202 },
203}
204
205#[derive(Debug, Clone)]
207pub struct ActionResultSnapshot {
208 pub success: bool,
209 pub output_debug: Option<String>,
211 pub duration: std::time::Duration,
212 pub error: Option<String>,
213}
214
215impl ActionResultSnapshot {
216 pub fn from_action_result(result: &crate::types::ActionResult) -> Self {
218 Self {
219 success: result.success,
220 output_debug: result.output.as_ref().map(|o| o.as_text()),
221 duration: result.duration,
222 error: result.error.clone(),
223 }
224 }
225}
226
227const DEFAULT_MAX_ENV_ENTRIES: usize = 500;
229
230pub struct SharedData {
232 pub kv: HashMap<String, Vec<u8>>,
234 pub completed_async_tasks: Vec<CompletedAsyncTask>,
237 max_env_entries: usize,
239}
240
241impl Default for SharedData {
242 fn default() -> Self {
243 Self {
244 kv: HashMap::new(),
245 completed_async_tasks: Vec::new(),
246 max_env_entries: DEFAULT_MAX_ENV_ENTRIES,
247 }
248 }
249}
250
251impl SharedData {
252 pub fn cleanup_env_entries(&mut self) {
257 let mut env_entries: Vec<(String, u64)> = self
259 .kv
260 .keys()
261 .filter(|k| k.starts_with("env:"))
262 .filter_map(|k| {
263 k.rsplit(':')
265 .next()?
266 .parse::<u64>()
267 .ok()
268 .map(|tick| (k.clone(), tick))
269 })
270 .collect();
271
272 if env_entries.len() <= self.max_env_entries {
273 return;
274 }
275
276 env_entries.sort_by_key(|(_, tick)| *tick);
278
279 let remove_count = env_entries.len() - self.max_env_entries;
281 for (key, _) in env_entries.into_iter().take(remove_count) {
282 self.kv.remove(&key);
283 }
284 }
285
286 pub fn set_max_env_entries(&mut self, max: usize) {
288 self.max_env_entries = max;
289 }
290}
291
292#[derive(Debug, Clone)]
296pub struct CompletedAsyncTask {
297 pub task_id: TaskId,
299 pub worker_id: Option<WorkerId>,
301 pub task_type: String,
303 pub completed_at_tick: u64,
305 pub status: TaskStatus,
307 pub error: Option<String>,
309}
310
311pub struct WorkerStates {
313 states: Vec<WorkerState>,
315}
316
317impl WorkerStates {
318 pub fn new(count: usize) -> Self {
319 let states = (0..count).map(|i| WorkerState::new(AgentId(i))).collect();
320 Self { states }
321 }
322
323 pub fn get_mut(&mut self, id: AgentId) -> Option<&mut WorkerState> {
325 self.states.get_mut(id.0)
326 }
327
328 pub fn get(&self, id: AgentId) -> Option<&WorkerState> {
330 self.states.get(id.0)
331 }
332
333 pub fn len(&self) -> usize {
335 self.states.len()
336 }
337
338 pub fn is_empty(&self) -> bool {
340 self.states.is_empty()
341 }
342
343 pub fn iter(&self) -> impl Iterator<Item = &WorkerState> {
345 self.states.iter()
346 }
347
348 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut WorkerState> {
350 self.states.iter_mut()
351 }
352
353 pub fn par_iter_mut(&mut self) -> impl ParallelIterator<Item = &mut WorkerState> {
355 self.states.par_iter_mut()
356 }
357}
358
/// Why a worker escalated.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EscalationReason {
    /// Too many failures in a row; carries the failure count.
    ConsecutiveFailures(u32),
    /// A resource was exhausted.
    ResourceExhausted,
    /// A timeout occurred.
    Timeout,
    /// The agent itself asked to escalate; carries its message.
    AgentRequested(String),
    /// Any other reason, described by the contained string.
    Unknown(String),
}
373
374#[derive(Debug, Clone)]
376pub struct Escalation {
377 pub reason: EscalationReason,
379 pub raised_at_tick: u64,
381 pub context: Option<String>,
383}
384
385impl Escalation {
386 pub fn consecutive_failures(count: u32, tick: u64) -> Self {
387 Self {
388 reason: EscalationReason::ConsecutiveFailures(count),
389 raised_at_tick: tick,
390 context: None,
391 }
392 }
393
394 pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
395 self.context = Some(ctx.into());
396 self
397 }
398}
399
400pub struct WorkerState {
404 pub id: AgentId,
406 internal_state: Option<Box<dyn Any + Send + Sync>>,
408 pub history: ActionHistory,
410 pub cache: LocalCache,
412 pub pending_tasks: HashSet<TaskId>,
414 pub escalation: Option<Escalation>,
416 pub consecutive_failures: u32,
418 pub last_output: Option<String>,
420}
421
422impl WorkerState {
423 pub fn new(id: AgentId) -> Self {
424 Self {
425 id,
426 internal_state: None,
427 history: ActionHistory::default(),
428 cache: LocalCache::default(),
429 pending_tasks: HashSet::new(),
430 escalation: None,
431 consecutive_failures: 0,
432 last_output: None,
433 }
434 }
435
436 pub fn raise_escalation(&mut self, escalation: Escalation) {
438 self.escalation = Some(escalation);
439 }
440
441 pub fn clear_escalation(&mut self) {
443 self.escalation = None;
444 self.consecutive_failures = 0;
445 }
446
447 pub fn record_failure(&mut self, tick: u64, threshold: u32) -> bool {
449 self.consecutive_failures += 1;
450 if self.consecutive_failures >= threshold && self.escalation.is_none() {
451 self.raise_escalation(Escalation::consecutive_failures(
452 self.consecutive_failures,
453 tick,
454 ));
455 true
456 } else {
457 false
458 }
459 }
460
461 pub fn record_success(&mut self) {
463 self.consecutive_failures = 0;
464 }
465
466 pub fn set_state<T: Any + Send + Sync + 'static>(&mut self, state: T) {
468 self.internal_state = Some(Box::new(state));
469 }
470
471 pub fn get_state<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
473 self.internal_state.as_ref()?.downcast_ref()
474 }
475
476 pub fn get_state_mut<T: Any + Send + Sync + 'static>(&mut self) -> Option<&mut T> {
478 self.internal_state.as_mut()?.downcast_mut()
479 }
480
481 pub fn add_pending_task(&mut self, task_id: TaskId) {
483 self.pending_tasks.insert(task_id);
484 }
485
486 pub fn complete_task(&mut self, task_id: TaskId) {
488 self.pending_tasks.remove(&task_id);
489 }
490}
491
492pub struct ActionHistory {
497 entries: VecDeque<HistoryEntry>,
499 max_entries: usize,
501}
502
503impl Default for ActionHistory {
504 fn default() -> Self {
505 Self::new(100) }
507}
508
509impl ActionHistory {
510 pub fn new(max_entries: usize) -> Self {
511 Self {
512 entries: VecDeque::with_capacity(max_entries),
513 max_entries,
514 }
515 }
516
517 pub fn push(&mut self, entry: HistoryEntry) {
519 if self.max_entries > 0 && self.entries.len() >= self.max_entries {
520 self.entries.pop_front(); }
522 self.entries.push_back(entry);
523 }
524
525 pub fn latest(&self) -> Option<&HistoryEntry> {
527 self.entries.back()
528 }
529
530 pub fn len(&self) -> usize {
532 self.entries.len()
533 }
534
535 pub fn is_empty(&self) -> bool {
537 self.entries.is_empty()
538 }
539
540 pub fn iter(&self) -> impl Iterator<Item = &HistoryEntry> {
542 self.entries.iter()
543 }
544}
545
/// One recorded action in a worker's history.
#[derive(Clone, Debug)]
pub struct HistoryEntry {
    /// Tick at which the action was recorded.
    pub tick: u64,
    /// Name of the action.
    pub action_name: String,
    /// Whether the action succeeded.
    pub success: bool,
}
553
/// Worker-local byte cache whose entries expire at a given tick.
#[derive(Default)]
pub struct LocalCache {
    /// Cached values keyed by name.
    data: HashMap<String, CacheEntry>,
}

impl LocalCache {
    /// Stores `value` under `key`, replacing any existing entry.
    ///
    /// NOTE: despite its name, `ttl_ticks` is stored verbatim as the tick at
    /// which the entry expires; it is not added to any "current" tick. Pass
    /// the absolute expiry tick.
    pub fn set(&mut self, key: impl Into<String>, value: Vec<u8>, ttl_ticks: u64) {
        let entry = CacheEntry {
            value,
            expires_at_tick: ttl_ticks,
        };
        self.data.insert(key.into(), entry);
    }

    /// Returns the bytes stored under `key` if the entry exists and its expiry
    /// tick is strictly greater than `current_tick`.
    ///
    /// Expired entries are not removed here; see [`LocalCache::cleanup`].
    pub fn get(&self, key: &str, current_tick: u64) -> Option<&[u8]> {
        self.data
            .get(key)
            .filter(|entry| entry.is_live_at(current_tick))
            .map(|entry| entry.value.as_slice())
    }

    /// Drops every entry whose expiry tick is not strictly greater than `current_tick`.
    pub fn cleanup(&mut self, current_tick: u64) {
        self.data
            .retain(|_, entry| entry.is_live_at(current_tick));
    }
}

/// A cached value together with the tick at which it expires.
struct CacheEntry {
    /// Stored bytes.
    value: Vec<u8>,
    /// The entry is considered live only while this is greater than the current tick.
    expires_at_tick: u64,
}

impl CacheEntry {
    /// `true` while the entry has not yet expired at `current_tick`.
    fn is_live_at(&self, current_tick: u64) -> bool {
        self.expires_at_tick > current_tick
    }
}
594
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `HistoryEntry` in the history test.
    fn entry(tick: u64, action_name: &str, success: bool) -> HistoryEntry {
        HistoryEntry {
            tick,
            action_name: action_name.to_string(),
            success,
        }
    }

    #[test]
    fn test_swarm_state_creation() {
        let swarm = SwarmState::new(3);
        assert_eq!(swarm.workers.len(), 3);
        assert_eq!(swarm.shared.tick, 0);
    }

    #[test]
    fn test_swarm_state_advance_tick() {
        let mut swarm = SwarmState::new(1);
        assert_eq!(swarm.shared.tick, 0);

        // Each call moves the tick forward by exactly one.
        for expected_tick in 1..=2u64 {
            swarm.advance_tick();
            assert_eq!(swarm.shared.tick, expected_tick);
        }
    }

    #[test]
    fn test_worker_states_access() {
        let mut workers = WorkerStates::new(3);
        assert_eq!(workers.len(), 3);
        assert!(!workers.is_empty());

        // In-range lookup returns the state created for that id.
        let second = workers.get_mut(AgentId(1)).unwrap();
        assert_eq!(second.id.0, 1);

        // Out-of-range lookup yields None.
        assert!(workers.get(AgentId(10)).is_none());
    }

    #[test]
    fn test_worker_state_internal() {
        let mut worker = WorkerState::new(AgentId(0));

        // Nothing stored yet.
        assert!(worker.get_state::<i32>().is_none());

        // Store and read back.
        worker.set_state(42i32);
        assert_eq!(worker.get_state::<i32>(), Some(&42));

        // Mutate in place.
        if let Some(value) = worker.get_state_mut::<i32>() {
            *value = 100;
        }
        assert_eq!(worker.get_state::<i32>(), Some(&100));

        // Asking for a different type yields None.
        assert!(worker.get_state::<String>().is_none());
    }

    #[test]
    fn test_worker_state_pending_tasks() {
        let mut worker = WorkerState::new(AgentId(0));
        assert!(worker.pending_tasks.is_empty());

        worker.add_pending_task(TaskId(1));
        worker.add_pending_task(TaskId(2));
        assert_eq!(worker.pending_tasks.len(), 2);
        assert!(worker.pending_tasks.contains(&TaskId(1)));
        assert!(worker.pending_tasks.contains(&TaskId(2)));

        worker.complete_task(TaskId(1));
        assert_eq!(worker.pending_tasks.len(), 1);
        assert!(!worker.pending_tasks.contains(&TaskId(1)));
        assert!(worker.pending_tasks.contains(&TaskId(2)));
    }

    #[test]
    fn test_action_history() {
        let mut history = ActionHistory::new(3);

        history.push(entry(0, "action1", true));
        history.push(entry(1, "action2", false));

        assert_eq!(history.len(), 2);
        assert_eq!(history.latest().unwrap().action_name, "action2");

        // Pushing past the limit evicts the oldest entry.
        history.push(entry(2, "action3", true));
        history.push(entry(3, "action4", true));

        assert_eq!(history.len(), 3);
        let remaining: Vec<_> = history.iter().collect();
        assert_eq!(remaining[0].action_name, "action2");
    }

    #[test]
    fn test_local_cache() {
        let mut cache = LocalCache::default();

        cache.set("key1", vec![1, 2, 3], 10);
        cache.set("key2", vec![4, 5, 6], 5);

        // Live while the expiry tick is greater than the current tick.
        assert_eq!(cache.get("key1", 0), Some([1u8, 2, 3].as_slice()));
        assert_eq!(cache.get("key2", 4), Some([4u8, 5, 6].as_slice()));

        // Expired at or after the expiry tick.
        assert!(cache.get("key2", 5).is_none());
        assert!(cache.get("key2", 10).is_none());

        assert_eq!(cache.get("key1", 9), Some([1u8, 2, 3].as_slice()));

        // Cleanup before key1's expiry keeps it.
        cache.cleanup(6);
        assert!(cache.get("key1", 0).is_some());

        // Cleanup after key1's expiry removes it.
        cache.cleanup(11);
        assert!(cache.get("key1", 0).is_none());
    }

    #[test]
    fn test_environment() {
        let mut environment = Environment::default();
        environment
            .variables
            .insert("PATH".to_string(), "/usr/bin".to_string());
        environment.flags.insert("debug".to_string(), true);

        assert_eq!(
            environment.variables.get("PATH"),
            Some(&"/usr/bin".to_string())
        );
        assert_eq!(environment.flags.get("debug"), Some(&true));
    }
}