Skip to main content

swarm_engine_core/learn/
record.rs

//! Record - an abstraction over raw events.
//!
//! ## Design
//!
//! An abstraction layer that lets `ActionEvent` and `LlmDebugEvent` be handled
//! uniformly. An `Episode` is built from a collection of `Record`s.
//!
//! ```text
//! ActionEvent ──┐
//!               ├──▶ Record ──▶ Episode
//! LlmDebugEvent ┘
//! ```
13
14use serde::{Deserialize, Serialize};
15
16use crate::events::ActionEvent;
17use crate::util::epoch_millis;
18
// ============================================================================
// Record
// ============================================================================
22
23/// 生イベントから変換された Record
24///
25/// ActionEvent と LlmDebugEvent を統一的に扱うための抽象化。
26/// EpisodeContext は Record のリストを保持する。
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub enum Record {
29    /// ActionEvent から変換
30    Action(ActionRecord),
31    /// LlmDebugEvent から変換
32    Llm(LlmCallRecord),
33}
34
35impl Record {
36    /// Action Record かどうか
37    pub fn is_action(&self) -> bool {
38        matches!(self, Self::Action(_))
39    }
40
41    /// Llm Record かどうか
42    pub fn is_llm(&self) -> bool {
43        matches!(self, Self::Llm(_))
44    }
45
46    /// ActionRecord を取得
47    pub fn as_action(&self) -> Option<&ActionRecord> {
48        match self {
49            Self::Action(r) => Some(r),
50            _ => None,
51        }
52    }
53
54    /// LlmCallRecord を取得
55    pub fn as_llm(&self) -> Option<&LlmCallRecord> {
56        match self {
57            Self::Llm(r) => Some(r),
58            _ => None,
59        }
60    }
61
62    /// Worker ID を取得(両方のレコードタイプから)
63    pub fn worker_id(&self) -> Option<usize> {
64        match self {
65            Self::Action(r) => Some(r.worker_id),
66            Self::Llm(r) => r.worker_id,
67        }
68    }
69
70    /// タイムスタンプを取得(ソート用)
71    pub fn timestamp_ms(&self) -> u64 {
72        match self {
73            Self::Action(r) => r.tick,
74            Self::Llm(r) => r.timestamp_ms,
75        }
76    }
77}
78
79impl From<ActionRecord> for Record {
80    fn from(record: ActionRecord) -> Self {
81        Self::Action(record)
82    }
83}
84
85impl From<LlmCallRecord> for Record {
86    fn from(record: LlmCallRecord) -> Self {
87        Self::Llm(record)
88    }
89}
90
91impl From<&ActionEvent> for Record {
92    fn from(event: &ActionEvent) -> Self {
93        Self::Action(ActionRecord::from(event))
94    }
95}
96
// ============================================================================
// FromRecord - trait for type-safe queries
// ============================================================================
100
101/// Record から特定の型を抽出するための Trait
102///
103/// 新しい Record 種別を追加したら、この Trait を実装することで
104/// EpisodeContext::iter::<T>() でクエリ可能になる。
105pub trait FromRecord: Sized {
106    fn from_record(record: &Record) -> Option<&Self>;
107}
108
109impl FromRecord for ActionRecord {
110    fn from_record(record: &Record) -> Option<&Self> {
111        record.as_action()
112    }
113}
114
115impl FromRecord for LlmCallRecord {
116    fn from_record(record: &Record) -> Option<&Self> {
117        record.as_llm()
118    }
119}
120
// ============================================================================
// ActionRecord
// ============================================================================
124
125/// アクション実行の記録
126///
127/// ActionEvent から変換可能。
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct ActionRecord {
130    /// Tick
131    pub tick: u64,
132    /// Worker ID
133    pub worker_id: usize,
134    /// アクション名
135    pub action: String,
136    /// ターゲット
137    pub target: Option<String>,
138    /// 成功/失敗
139    pub success: bool,
140    /// エラーメッセージ
141    pub error: Option<String>,
142    /// 実行時間(ミリ秒)
143    pub duration_ms: u64,
144    /// 選択ロジック(UCB1, Greedy 等)
145    pub selection_logic: Option<String>,
146    /// Guidance からの指示だったか
147    pub from_guidance: bool,
148    /// 前回のアクション
149    pub previous_action: Option<String>,
150    /// 使用した LoRA
151    pub lora: Option<String>,
152}
153
154impl ActionRecord {
155    pub fn new(tick: u64, worker_id: usize, action: impl Into<String>) -> Self {
156        Self {
157            tick,
158            worker_id,
159            action: action.into(),
160            target: None,
161            success: true,
162            error: None,
163            duration_ms: 0,
164            selection_logic: None,
165            from_guidance: false,
166            previous_action: None,
167            lora: None,
168        }
169    }
170
171    pub fn target(mut self, target: impl Into<String>) -> Self {
172        self.target = Some(target.into());
173        self
174    }
175
176    pub fn success(mut self, success: bool) -> Self {
177        self.success = success;
178        self
179    }
180
181    pub fn error(mut self, error: impl Into<String>) -> Self {
182        self.error = Some(error.into());
183        self.success = false;
184        self
185    }
186
187    pub fn duration_ms(mut self, duration_ms: u64) -> Self {
188        self.duration_ms = duration_ms;
189        self
190    }
191
192    pub fn selection_logic(mut self, logic: impl Into<String>) -> Self {
193        self.selection_logic = Some(logic.into());
194        self
195    }
196
197    pub fn from_guidance(mut self) -> Self {
198        self.from_guidance = true;
199        self
200    }
201
202    pub fn previous_action(mut self, action: impl Into<String>) -> Self {
203        self.previous_action = Some(action.into());
204        self
205    }
206
207    pub fn lora(mut self, lora: impl Into<String>) -> Self {
208        self.lora = Some(lora.into());
209        self
210    }
211
212    /// 終端アクション(タスク完了)かどうか
213    pub fn is_terminal(&self) -> bool {
214        self.action == "done"
215    }
216}
217
218impl From<&ActionEvent> for ActionRecord {
219    fn from(event: &ActionEvent) -> Self {
220        Self {
221            tick: event.tick,
222            worker_id: event.worker_id.0,
223            action: event.action.clone(),
224            target: event.target.clone(),
225            success: event.result.success,
226            error: event.result.error.clone(),
227            duration_ms: event.duration.as_millis() as u64,
228            selection_logic: event.context.selection_logic.clone(),
229            from_guidance: event.context.from_guidance,
230            previous_action: event.context.previous_action.clone(),
231            lora: event.context.lora.as_ref().and_then(|l| l.name.clone()),
232        }
233    }
234}
235
// ============================================================================
// LlmCallRecord
// ============================================================================
239
240/// LLM呼び出しの記録
241///
242/// LlmDebugEvent から変換可能(swarm-engine-llm crate で実装)。
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct LlmCallRecord {
245    /// 呼び出し種別("decide", "call_raw" 等)
246    pub call_type: String,
247    /// プロンプト
248    pub prompt: String,
249    /// レスポンス
250    pub response: String,
251    /// モデル名
252    pub model: String,
253    /// エンドポイント
254    pub endpoint: String,
255    /// LoRAアダプター名
256    pub lora: Option<String>,
257    /// レイテンシ(ミリ秒)
258    pub latency_ms: u64,
259    /// タイムスタンプ(Unix epoch ms)
260    pub timestamp_ms: u64,
261    /// Worker ID
262    pub worker_id: Option<usize>,
263    /// エラー(あれば)
264    pub error: Option<String>,
265}
266
267impl LlmCallRecord {
268    pub fn new(call_type: impl Into<String>, model: impl Into<String>) -> Self {
269        Self {
270            call_type: call_type.into(),
271            prompt: String::new(),
272            response: String::new(),
273            model: model.into(),
274            endpoint: String::new(),
275            lora: None,
276            latency_ms: 0,
277            timestamp_ms: epoch_millis(),
278            worker_id: None,
279            error: None,
280        }
281    }
282
283    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
284        self.prompt = prompt.into();
285        self
286    }
287
288    pub fn response(mut self, response: impl Into<String>) -> Self {
289        self.response = response.into();
290        self
291    }
292
293    pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
294        self.endpoint = endpoint.into();
295        self
296    }
297
298    pub fn lora(mut self, lora: impl Into<String>) -> Self {
299        self.lora = Some(lora.into());
300        self
301    }
302
303    pub fn latency_ms(mut self, latency: u64) -> Self {
304        self.latency_ms = latency;
305        self
306    }
307
308    pub fn worker_id(mut self, id: usize) -> Self {
309        self.worker_id = Some(id);
310        self
311    }
312
313    pub fn error(mut self, error: impl Into<String>) -> Self {
314        self.error = Some(error.into());
315        self
316    }
317
318    /// 成功したか(エラーがなく、レスポンスがある)
319    pub fn is_success(&self) -> bool {
320        self.error.is_none() && !self.response.is_empty()
321    }
322}
323
// ============================================================================
// RecordStream - operations over a collection of Records
// ============================================================================
327
328/// Record のストリームを操作するためのヘルパー
329pub struct RecordStream<'a> {
330    records: &'a [Record],
331}
332
333impl<'a> RecordStream<'a> {
334    pub fn new(records: &'a [Record]) -> Self {
335        Self { records }
336    }
337
338    /// Action Records のみ抽出
339    pub fn actions(&self) -> impl Iterator<Item = &ActionRecord> {
340        self.records.iter().filter_map(Record::as_action)
341    }
342
343    /// Llm Records のみ抽出
344    pub fn llm_calls(&self) -> impl Iterator<Item = &LlmCallRecord> {
345        self.records.iter().filter_map(Record::as_llm)
346    }
347
348    /// Worker ID でフィルタ
349    pub fn by_worker(&self, worker_id: usize) -> impl Iterator<Item = &Record> {
350        self.records
351            .iter()
352            .filter(move |r| r.worker_id() == Some(worker_id))
353    }
354
355    /// Worker ID ごとにグルーピング
356    pub fn group_by_worker(&self) -> std::collections::HashMap<usize, Vec<&Record>> {
357        let mut groups = std::collections::HashMap::new();
358        for record in self.records {
359            if let Some(worker_id) = record.worker_id() {
360                groups
361                    .entry(worker_id)
362                    .or_insert_with(Vec::new)
363                    .push(record);
364            }
365        }
366        groups
367    }
368
369    /// タイムスタンプでソートした Iterator
370    pub fn sorted_by_time(&self) -> Vec<&Record> {
371        let mut sorted: Vec<_> = self.records.iter().collect();
372        sorted.sort_by_key(|r| r.timestamp_ms());
373        sorted
374    }
375
376    /// Record 数
377    pub fn len(&self) -> usize {
378        self.records.len()
379    }
380
381    /// 空かどうか
382    pub fn is_empty(&self) -> bool {
383        self.records.is_empty()
384    }
385}
386
// ============================================================================
// Tests
// ============================================================================
390
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_record_from_action_record() {
        let action = ActionRecord::new(1, 0, "CheckStatus").success(true);
        let record = Record::from(action);

        assert!(record.is_action());
        assert!(!record.is_llm());
        assert_eq!(record.worker_id(), Some(0));
    }

    #[test]
    fn test_record_from_llm_call_record() {
        let llm = LlmCallRecord::new("decide", "qwen2.5")
            .worker_id(1)
            .prompt("test")
            .response("ok");
        let record = Record::from(llm);

        assert!(!record.is_action());
        assert!(record.is_llm());
        assert_eq!(record.worker_id(), Some(1));
    }

    #[test]
    fn test_record_stream_filtering() {
        let records = vec![
            Record::from(ActionRecord::new(1, 0, "A").success(true)),
            Record::from(LlmCallRecord::new("decide", "model").worker_id(0)),
            Record::from(ActionRecord::new(2, 1, "B").success(true)),
        ];

        let stream = RecordStream::new(&records);

        assert_eq!(stream.actions().count(), 2);
        assert_eq!(stream.llm_calls().count(), 1);
        assert_eq!(stream.by_worker(0).count(), 2);
        assert_eq!(stream.by_worker(1).count(), 1);
    }

    #[test]
    fn test_action_record_error_marks_failure() {
        let action = ActionRecord::new(1, 0, "Fetch").error("timeout");

        assert!(!action.success);
        assert_eq!(action.error.as_deref(), Some("timeout"));
    }

    #[test]
    fn test_action_record_is_terminal() {
        assert!(ActionRecord::new(1, 0, "done").is_terminal());
        assert!(!ActionRecord::new(1, 0, "Done").is_terminal());
        assert!(!ActionRecord::new(1, 0, "CheckStatus").is_terminal());
    }

    #[test]
    fn test_llm_call_record_is_success() {
        assert!(!LlmCallRecord::new("decide", "model").is_success());
        assert!(LlmCallRecord::new("decide", "model")
            .response("ok")
            .is_success());
        assert!(!LlmCallRecord::new("decide", "model")
            .response("ok")
            .error("boom")
            .is_success());
    }

    #[test]
    fn test_from_record_dispatch() {
        let action = Record::from(ActionRecord::new(1, 0, "A"));
        let llm = Record::from(LlmCallRecord::new("decide", "model"));

        assert!(ActionRecord::from_record(&action).is_some());
        assert!(ActionRecord::from_record(&llm).is_none());
        assert!(LlmCallRecord::from_record(&llm).is_some());
        assert!(LlmCallRecord::from_record(&action).is_none());
    }

    #[test]
    fn test_group_by_worker_skips_records_without_worker() {
        let records = vec![
            Record::from(ActionRecord::new(1, 0, "A")),
            Record::from(LlmCallRecord::new("decide", "model")),
            Record::from(ActionRecord::new(2, 1, "B")),
            Record::from(LlmCallRecord::new("decide", "model").worker_id(0)),
        ];

        let groups = RecordStream::new(&records).group_by_worker();

        assert_eq!(groups.len(), 2);
        assert_eq!(groups[&0].len(), 2);
        assert_eq!(groups[&1].len(), 1);
    }

    #[test]
    fn test_sorted_by_time_orders_actions_by_tick() {
        let records = vec![
            Record::from(ActionRecord::new(3, 0, "C")),
            Record::from(ActionRecord::new(1, 0, "A")),
            Record::from(ActionRecord::new(2, 0, "B")),
        ];
        let stream = RecordStream::new(&records);

        let names: Vec<&str> = stream
            .sorted_by_time()
            .into_iter()
            .filter_map(Record::as_action)
            .map(|a| a.action.as_str())
            .collect();

        assert_eq!(names, ["A", "B", "C"]);
    }

    #[test]
    fn test_record_stream_len_and_is_empty() {
        let records = vec![Record::from(ActionRecord::new(1, 0, "A"))];
        let stream = RecordStream::new(&records);

        assert_eq!(stream.len(), 1);
        assert!(!stream.is_empty());
        assert!(RecordStream::new(&[]).is_empty());
    }
}