Skip to main content

swarm_engine_core/
state.rs

1//! SwarmEngine State - 2層メモリモデル
2//!
3//! 共有領域(SharedState)と個別領域(WorkerStates)を分離し、
4//! 並列Workerがそれぞれの個別Stateに高速アクセスできる構造。
5
6use std::any::Any;
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::time::Duration;
9
10use rayon::prelude::*;
11
12use crate::async_task::TaskStatus;
13use crate::extensions::Extensions;
14use crate::online_stats::SwarmStats;
15use crate::types::{AgentId, TaskId, WorkerId};
16
17// ============================================================================
18// LlmStats - LLM 呼び出し統計
19// ============================================================================
20
21/// LLM 呼び出し統計
22#[derive(Debug, Clone, Default)]
23pub struct LlmStats {
24    /// 総呼び出し回数
25    pub invocations: u64,
26    /// エラー回数
27    pub errors: u64,
28    /// 総実行時間
29    pub total_duration: Duration,
30}
31
32impl LlmStats {
33    /// 成功率
34    pub fn success_rate(&self) -> f64 {
35        if self.invocations == 0 {
36            1.0
37        } else {
38            (self.invocations - self.errors) as f64 / self.invocations as f64
39        }
40    }
41
42    /// LLM 呼び出しを記録
43    pub fn record(&mut self, success: bool, duration: Duration) {
44        self.invocations += 1;
45        self.total_duration += duration;
46        if !success {
47            self.errors += 1;
48        }
49    }
50}
51
52/// State 全体
53pub struct SwarmState {
54    /// 共有領域 - 全Worker が ReadOnly 参照
55    pub shared: SharedState,
56    /// 個別領域 - Worker ごとの State(並列アクセス用)
57    pub workers: WorkerStates,
58}
59
60impl SwarmState {
61    pub fn new(worker_count: usize) -> Self {
62        Self {
63            shared: SharedState::default(),
64            workers: WorkerStates::new(worker_count),
65        }
66    }
67
68    /// Tick を進める
69    pub fn advance_tick(&mut self) {
70        self.shared.tick += 1;
71    }
72}
73
74/// 共有領域
75#[derive(Default)]
76pub struct SharedState {
77    /// 環境情報
78    pub environment: Environment,
79    /// 統合統計(ActionEvent ベース)
80    pub stats: SwarmStats,
81    /// Tick番号
82    pub tick: u64,
83    /// Agent間で共有するデータ(ReadOnly)
84    pub shared_data: SharedData,
85    /// 動的リソース(DB接続、HTTPクライアント等)
86    pub extensions: Extensions,
87    /// 平均Tick時間(ナノ秒)- EMA で計算
88    pub avg_tick_duration_ns: u64,
89    /// Goal Action(Terminal Action)を達成した Worker
90    pub done_workers: HashSet<WorkerId>,
91    /// Environment が Done(success=true) を返したかどうか
92    pub environment_done: bool,
93    /// LLM 呼び出し統計
94    pub llm_stats: LlmStats,
95}
96
97
98impl SharedState {
99    /// Worker を完了済みとしてマーク
100    pub fn mark_worker_done(&mut self, worker_id: WorkerId) {
101        self.done_workers.insert(worker_id);
102        self.environment_done = true;
103    }
104
105    /// Worker が完了済みかどうか
106    pub fn is_worker_done(&self, worker_id: WorkerId) -> bool {
107        self.done_workers.contains(&worker_id)
108    }
109
110    /// 環境が完了したかどうか
111    pub fn is_environment_done(&self) -> bool {
112        self.environment_done
113    }
114
115    /// LLM 呼び出し回数
116    pub fn llm_invocations(&self) -> u64 {
117        self.llm_stats.invocations
118    }
119
120    /// LLM エラー回数
121    pub fn llm_errors(&self) -> u64 {
122        self.llm_stats.errors
123    }
124}
125
126/// 環境情報
127#[derive(Default)]
128pub struct Environment {
129    /// 環境変数
130    pub variables: HashMap<String, String>,
131    /// 設定フラグ
132    pub flags: HashMap<String, bool>,
133}
134
135// ============================================================================
136// Tick Snapshot - 完全履歴記録
137// ============================================================================
138
139/// Tick 毎のスナップショット(完全履歴)
140///
141/// 各Tickで何が起きたかを全て記録。Eval等での詳細分析に使用。
142#[derive(Debug, Clone)]
143pub struct TickSnapshot {
144    /// Tick 番号
145    pub tick: u64,
146    /// 処理時間
147    pub duration: std::time::Duration,
148    /// Manager フェーズの記録(起動した場合のみ)
149    pub manager_phase: Option<ManagerPhaseSnapshot>,
150    /// 全 Worker の結果
151    pub worker_results: Vec<WorkerResultSnapshot>,
152}
153
154/// Manager フェーズのスナップショット
155#[derive(Debug, Clone)]
156pub struct ManagerPhaseSnapshot {
157    /// Batch リクエスト(Manager が何を聞いたか)
158    pub batch_request: crate::agent::BatchDecisionRequest,
159    /// LLM からの返答
160    pub responses: Vec<(crate::types::WorkerId, crate::agent::DecisionResponse)>,
161    /// 発行された Guidance(Worker毎)
162    pub guidances: std::collections::HashMap<crate::types::WorkerId, crate::agent::Guidance>,
163    /// LLM呼び出しエラー数(サーバー停止等)
164    pub llm_errors: u64,
165}
166
167/// Worker の結果スナップショット
168#[derive(Debug, Clone)]
169pub struct WorkerResultSnapshot {
170    pub worker_id: crate::types::WorkerId,
171    /// 受け取った Guidance(あれば)
172    pub guidance_received: Option<crate::agent::Guidance>,
173    /// 実行結果
174    pub result: WorkResultSnapshot,
175}
176
177/// WorkResult のスナップショット版(Clone可能)
178#[derive(Debug, Clone)]
179pub enum WorkResultSnapshot {
180    /// 行動した
181    Acted {
182        action_result: ActionResultSnapshot,
183        state_delta: Option<crate::agent::WorkerStateDelta>,
184    },
185    /// 継続中
186    Continuing { progress: f32 },
187    /// Guidance 要求
188    NeedsGuidance {
189        reason: String,
190        context: crate::agent::GuidanceContext,
191    },
192    /// Escalation 要求
193    Escalate {
194        reason: crate::agent::EscalationReason,
195        context: Option<String>,
196    },
197    /// 待機
198    Idle,
199    /// タスク完了
200    Done {
201        success: bool,
202        message: Option<String>,
203    },
204}
205
206/// ActionResult のスナップショット版(Clone可能)
207#[derive(Debug, Clone)]
208pub struct ActionResultSnapshot {
209    pub success: bool,
210    /// output のテキスト表現
211    pub output_debug: Option<String>,
212    pub duration: std::time::Duration,
213    pub error: Option<String>,
214}
215
216impl ActionResultSnapshot {
217    /// ActionResult から変換
218    pub fn from_action_result(result: &crate::types::ActionResult) -> Self {
219        Self {
220            success: result.success,
221            output_debug: result.output.as_ref().map(|o| o.as_text()),
222            duration: result.duration,
223            error: result.error.clone(),
224        }
225    }
226}
227
228/// SharedData の env エントリ保持数(デフォルト)
229const DEFAULT_MAX_ENV_ENTRIES: usize = 500;
230
231/// Agent間共有データ
232pub struct SharedData {
233    /// key-value ストア
234    pub kv: HashMap<String, Vec<u8>>,
235    /// 完了した非同期タスク一覧(Meta のみ)
236    /// payload は AsyncTaskSystem に保持され、ReadPayload Action で取得
237    pub completed_async_tasks: Vec<CompletedAsyncTask>,
238    /// env エントリの最大保持数
239    max_env_entries: usize,
240}
241
242impl Default for SharedData {
243    fn default() -> Self {
244        Self {
245            kv: HashMap::new(),
246            completed_async_tasks: Vec::new(),
247            max_env_entries: DEFAULT_MAX_ENV_ENTRIES,
248        }
249    }
250}
251
252impl SharedData {
253    /// env:{worker_id}:{tick} 形式の古いエントリをクリーンアップ
254    ///
255    /// kv 内の env:* エントリが max_env_entries を超えた場合、
256    /// tick が古いものから削除する。
257    pub fn cleanup_env_entries(&mut self) {
258        // env:* エントリを収集
259        let mut env_entries: Vec<(String, u64)> = self
260            .kv
261            .keys()
262            .filter(|k| k.starts_with("env:"))
263            .filter_map(|k| {
264                // env:{worker_id}:{tick} から tick を抽出
265                k.rsplit(':')
266                    .next()?
267                    .parse::<u64>()
268                    .ok()
269                    .map(|tick| (k.clone(), tick))
270            })
271            .collect();
272
273        if env_entries.len() <= self.max_env_entries {
274            return;
275        }
276
277        // tick でソート(古い順)
278        env_entries.sort_by_key(|(_, tick)| *tick);
279
280        // 超過分を削除
281        let remove_count = env_entries.len() - self.max_env_entries;
282        for (key, _) in env_entries.into_iter().take(remove_count) {
283            self.kv.remove(&key);
284        }
285    }
286
287    /// env エントリの最大保持数を設定
288    pub fn set_max_env_entries(&mut self, max: usize) {
289        self.max_env_entries = max;
290    }
291}
292
293/// 完了した非同期タスクのメタ情報
294///
295/// payload は含まず、完了通知のみ。Worker は ReadPayload Action で payload を取得する。
296#[derive(Debug, Clone)]
297pub struct CompletedAsyncTask {
298    /// タスクID
299    pub task_id: TaskId,
300    /// 発行した Worker(None = Manager 等から発行)
301    pub worker_id: Option<WorkerId>,
302    /// タスク種別("web_search", "llm_call" など)
303    pub task_type: String,
304    /// 完了した Tick
305    pub completed_at_tick: u64,
306    /// ステータス
307    pub status: TaskStatus,
308    /// エラーメッセージ(失敗時)
309    pub error: Option<String>,
310}
311
312/// Worker の永続状態コンテナ - GPU的なメモリ配置
313pub struct WorkerStates {
314    /// 連続メモリ領域に全 Worker の State を配置
315    states: Vec<WorkerState>,
316}
317
318impl WorkerStates {
319    pub fn new(count: usize) -> Self {
320        let states = (0..count).map(|i| WorkerState::new(AgentId(i))).collect();
321        Self { states }
322    }
323
324    /// Worker は自分の State のみ mutable アクセス可能
325    pub fn get_mut(&mut self, id: AgentId) -> Option<&mut WorkerState> {
326        self.states.get_mut(id.0)
327    }
328
329    /// 他 Worker の State は ReadOnly
330    pub fn get(&self, id: AgentId) -> Option<&WorkerState> {
331        self.states.get(id.0)
332    }
333
334    /// Worker 数を取得
335    pub fn len(&self) -> usize {
336        self.states.len()
337    }
338
339    /// 空かどうか
340    pub fn is_empty(&self) -> bool {
341        self.states.is_empty()
342    }
343
344    /// イテレーション
345    pub fn iter(&self) -> impl Iterator<Item = &WorkerState> {
346        self.states.iter()
347    }
348
349    /// 可変イテレーション
350    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut WorkerState> {
351        self.states.iter_mut()
352    }
353
354    /// 並列イテレーション(Rayon等で並列処理)
355    pub fn par_iter_mut(&mut self) -> impl ParallelIterator<Item = &mut WorkerState> {
356        self.states.par_iter_mut()
357    }
358}
359
360/// Escalation 理由
361#[derive(Debug, Clone, PartialEq, Eq)]
362pub enum EscalationReason {
363    /// 連続失敗
364    ConsecutiveFailures(u32),
365    /// リソース不足
366    ResourceExhausted,
367    /// タイムアウト
368    Timeout,
369    /// Agent からの申告(Agent が自発的に介入を要求)
370    AgentRequested(String),
371    /// 不明なエラー
372    Unknown(String),
373}
374
375/// Escalation 情報
376#[derive(Debug, Clone)]
377pub struct Escalation {
378    /// 理由
379    pub reason: EscalationReason,
380    /// 発生 tick
381    pub raised_at_tick: u64,
382    /// 追加コンテキスト
383    pub context: Option<String>,
384}
385
386impl Escalation {
387    pub fn consecutive_failures(count: u32, tick: u64) -> Self {
388        Self {
389            reason: EscalationReason::ConsecutiveFailures(count),
390            raised_at_tick: tick,
391            context: None,
392        }
393    }
394
395    pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
396        self.context = Some(ctx.into());
397        self
398    }
399}
400
401/// Worker の永続状態(Tick間で保持される)
402///
403/// WorkerStateDelta 経由で更新される。
404pub struct WorkerState {
405    /// Agent ID(内部識別用)
406    pub id: AgentId,
407    /// 内部状態(Worker 固有)
408    internal_state: Option<Box<dyn Any + Send + Sync>>,
409    /// 行動履歴
410    pub history: ActionHistory,
411    /// ローカルキャッシュ
412    pub cache: LocalCache,
413    /// 保留中の非同期Task(HashSet で O(1) 検索)
414    pub pending_tasks: HashSet<TaskId>,
415    /// Escalation 情報(Worker が設定、Manager が読み取り)
416    pub escalation: Option<Escalation>,
417    /// 連続失敗カウント
418    pub consecutive_failures: u32,
419    /// 最新アクションの出力(Environment からの結果)
420    pub last_output: Option<String>,
421}
422
423impl WorkerState {
424    pub fn new(id: AgentId) -> Self {
425        Self {
426            id,
427            internal_state: None,
428            history: ActionHistory::default(),
429            cache: LocalCache::default(),
430            pending_tasks: HashSet::new(),
431            escalation: None,
432            consecutive_failures: 0,
433            last_output: None,
434        }
435    }
436
437    /// Escalation を発生させる
438    pub fn raise_escalation(&mut self, escalation: Escalation) {
439        self.escalation = Some(escalation);
440    }
441
442    /// Escalation をクリア
443    pub fn clear_escalation(&mut self) {
444        self.escalation = None;
445        self.consecutive_failures = 0;
446    }
447
448    /// 失敗を記録し、閾値を超えたら Escalation
449    pub fn record_failure(&mut self, tick: u64, threshold: u32) -> bool {
450        self.consecutive_failures += 1;
451        if self.consecutive_failures >= threshold && self.escalation.is_none() {
452            self.raise_escalation(Escalation::consecutive_failures(
453                self.consecutive_failures,
454                tick,
455            ));
456            true
457        } else {
458            false
459        }
460    }
461
462    /// 成功を記録(連続失敗をリセット)
463    pub fn record_success(&mut self) {
464        self.consecutive_failures = 0;
465    }
466
467    /// 内部状態を設定
468    pub fn set_state<T: Any + Send + Sync + 'static>(&mut self, state: T) {
469        self.internal_state = Some(Box::new(state));
470    }
471
472    /// 内部状態を取得
473    pub fn get_state<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
474        self.internal_state.as_ref()?.downcast_ref()
475    }
476
477    /// 内部状態を可変で取得
478    pub fn get_state_mut<T: Any + Send + Sync + 'static>(&mut self) -> Option<&mut T> {
479        self.internal_state.as_mut()?.downcast_mut()
480    }
481
482    /// 非同期タスクを追加(O(1))
483    pub fn add_pending_task(&mut self, task_id: TaskId) {
484        self.pending_tasks.insert(task_id);
485    }
486
487    /// 完了したタスクを削除(O(1))
488    pub fn complete_task(&mut self, task_id: TaskId) {
489        self.pending_tasks.remove(&task_id);
490    }
491}
492
493/// 行動履歴
494///
495/// VecDeque を使用したリングバッファ実装。
496/// push() が O(1) で効率的に動作する。
497pub struct ActionHistory {
498    /// 最近のアクション(リングバッファ)
499    entries: VecDeque<HistoryEntry>,
500    /// 最大保持数
501    max_entries: usize,
502}
503
504impl Default for ActionHistory {
505    fn default() -> Self {
506        Self::new(100) // デフォルト: 100エントリ
507    }
508}
509
510impl ActionHistory {
511    pub fn new(max_entries: usize) -> Self {
512        Self {
513            entries: VecDeque::with_capacity(max_entries),
514            max_entries,
515        }
516    }
517
518    /// エントリを追加(O(1))
519    pub fn push(&mut self, entry: HistoryEntry) {
520        if self.max_entries > 0 && self.entries.len() >= self.max_entries {
521            self.entries.pop_front(); // O(1) - VecDeque の利点
522        }
523        self.entries.push_back(entry);
524    }
525
526    /// 最新のエントリを取得
527    pub fn latest(&self) -> Option<&HistoryEntry> {
528        self.entries.back()
529    }
530
531    /// エントリ数を取得
532    pub fn len(&self) -> usize {
533        self.entries.len()
534    }
535
536    /// 空かどうか
537    pub fn is_empty(&self) -> bool {
538        self.entries.is_empty()
539    }
540
541    /// イテレータを取得
542    pub fn iter(&self) -> impl Iterator<Item = &HistoryEntry> {
543        self.entries.iter()
544    }
545}
546
547/// 履歴エントリ
548#[derive(Debug, Clone)]
549pub struct HistoryEntry {
550    pub tick: u64,
551    pub action_name: String,
552    pub success: bool,
553}
554
555/// ローカルキャッシュ
556#[derive(Default)]
557pub struct LocalCache {
558    /// key-value キャッシュ
559    data: HashMap<String, CacheEntry>,
560}
561
562impl LocalCache {
563    /// キャッシュに設定
564    pub fn set(&mut self, key: impl Into<String>, value: Vec<u8>, ttl_ticks: u64) {
565        self.data.insert(
566            key.into(),
567            CacheEntry {
568                value,
569                expires_at_tick: ttl_ticks,
570            },
571        );
572    }
573
574    /// キャッシュから取得
575    pub fn get(&self, key: &str, current_tick: u64) -> Option<&[u8]> {
576        let entry = self.data.get(key)?;
577        if entry.expires_at_tick > current_tick {
578            Some(&entry.value)
579        } else {
580            None
581        }
582    }
583
584    /// 期限切れエントリを削除
585    pub fn cleanup(&mut self, current_tick: u64) {
586        self.data.retain(|_, v| v.expires_at_tick > current_tick);
587    }
588}
589
590/// キャッシュエントリ
591struct CacheEntry {
592    value: Vec<u8>,
593    expires_at_tick: u64,
594}
595
596#[cfg(test)]
597mod tests {
598    use super::*;
599
600    #[test]
601    fn test_swarm_state_creation() {
602        let state = SwarmState::new(3);
603        assert_eq!(state.workers.len(), 3);
604        assert_eq!(state.shared.tick, 0);
605    }
606
607    #[test]
608    fn test_swarm_state_advance_tick() {
609        let mut state = SwarmState::new(1);
610        assert_eq!(state.shared.tick, 0);
611
612        state.advance_tick();
613        assert_eq!(state.shared.tick, 1);
614
615        state.advance_tick();
616        assert_eq!(state.shared.tick, 2);
617    }
618
619    #[test]
620    fn test_worker_states_access() {
621        let mut states = WorkerStates::new(3);
622        assert_eq!(states.len(), 3);
623        assert!(!states.is_empty());
624
625        // get_mut でアクセス
626        let ws = states.get_mut(AgentId(1)).unwrap();
627        assert_eq!(ws.id.0, 1);
628
629        // 存在しない ID
630        assert!(states.get(AgentId(10)).is_none());
631    }
632
633    #[test]
634    fn test_worker_state_internal() {
635        let mut ws = WorkerState::new(AgentId(0));
636
637        // 初期状態は None
638        assert!(ws.get_state::<i32>().is_none());
639
640        // 状態を設定
641        ws.set_state(42i32);
642        assert_eq!(ws.get_state::<i32>(), Some(&42));
643
644        // 可変アクセス
645        if let Some(state) = ws.get_state_mut::<i32>() {
646            *state = 100;
647        }
648        assert_eq!(ws.get_state::<i32>(), Some(&100));
649
650        // 型が違う場合は None
651        assert!(ws.get_state::<String>().is_none());
652    }
653
654    #[test]
655    fn test_worker_state_pending_tasks() {
656        let mut ws = WorkerState::new(AgentId(0));
657        assert!(ws.pending_tasks.is_empty());
658
659        ws.add_pending_task(TaskId(1));
660        ws.add_pending_task(TaskId(2));
661        assert_eq!(ws.pending_tasks.len(), 2);
662        assert!(ws.pending_tasks.contains(&TaskId(1)));
663        assert!(ws.pending_tasks.contains(&TaskId(2)));
664
665        ws.complete_task(TaskId(1));
666        assert_eq!(ws.pending_tasks.len(), 1);
667        assert!(!ws.pending_tasks.contains(&TaskId(1)));
668        assert!(ws.pending_tasks.contains(&TaskId(2)));
669    }
670
671    #[test]
672    fn test_action_history() {
673        let mut history = ActionHistory::new(3);
674
675        history.push(HistoryEntry {
676            tick: 0,
677            action_name: "action1".to_string(),
678            success: true,
679        });
680        history.push(HistoryEntry {
681            tick: 1,
682            action_name: "action2".to_string(),
683            success: false,
684        });
685
686        assert_eq!(history.len(), 2);
687        assert_eq!(history.latest().unwrap().action_name, "action2");
688
689        // 最大数を超えると古いものが削除される
690        history.push(HistoryEntry {
691            tick: 2,
692            action_name: "action3".to_string(),
693            success: true,
694        });
695        history.push(HistoryEntry {
696            tick: 3,
697            action_name: "action4".to_string(),
698            success: true,
699        });
700
701        assert_eq!(history.len(), 3);
702        // action1 は削除され、action2 が先頭になる
703        let entries: Vec<_> = history.iter().collect();
704        assert_eq!(entries[0].action_name, "action2");
705    }
706
707    #[test]
708    fn test_local_cache() {
709        let mut cache = LocalCache::default();
710
711        cache.set("key1", vec![1, 2, 3], 10);
712        cache.set("key2", vec![4, 5, 6], 5);
713
714        // 有効期限内
715        assert_eq!(cache.get("key1", 0), Some([1u8, 2, 3].as_slice()));
716        assert_eq!(cache.get("key2", 4), Some([4u8, 5, 6].as_slice()));
717
718        // 有効期限切れ
719        assert!(cache.get("key2", 5).is_none());
720        assert!(cache.get("key2", 10).is_none());
721
722        // key1 はまだ有効
723        assert_eq!(cache.get("key1", 9), Some([1u8, 2, 3].as_slice()));
724
725        // cleanup
726        cache.cleanup(6);
727        assert!(cache.get("key1", 0).is_some()); // まだ残っている
728        cache.cleanup(11);
729        assert!(cache.get("key1", 0).is_none()); // 削除された
730    }
731
732    #[test]
733    fn test_environment() {
734        let mut env = Environment::default();
735        env.variables
736            .insert("PATH".to_string(), "/usr/bin".to_string());
737        env.flags.insert("debug".to_string(), true);
738
739        assert_eq!(env.variables.get("PATH"), Some(&"/usr/bin".to_string()));
740        assert_eq!(env.flags.get("debug"), Some(&true));
741    }
742}