Skip to main content

swarm_engine_core/learn/daemon/
processor.rs

1//! Processor - 学習実行
2//!
3//! Trigger 発火時に呼び出され、以下の処理を実行:
4//! - Offline 分析 → OfflineModel 生成
5//! - LoRA 学習 → TrainedModel 生成
6
7use std::sync::Arc;
8
9use crate::learn::learn_model::LearnModel;
10use crate::learn::lora::{LoraTrainer, LoraTrainerError, TrainedModel};
11use crate::learn::offline::OfflineModel;
12use crate::learn::snapshot::LearningStore;
13use crate::learn::store::{EpisodeStore, StoreError};
14
15// ============================================================================
16// ProcessorMode
17// ============================================================================
18
/// Selects which learning pass(es) a processor run performs.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ProcessorMode {
    /// Offline analysis only (produces an `OfflineModel`).
    #[default]
    OfflineOnly,
    /// LoRA training only (produces a `TrainedModel`).
    LoraOnly,
    /// Both offline analysis and LoRA training.
    Full,
}

impl std::str::FromStr for ProcessorMode {
    type Err = String;

    /// Parses a mode name, ignoring ASCII/Unicode case.
    ///
    /// Accepted spellings: `offline` / `offline_only`, `lora` / `lora_only`,
    /// `full` / `both`. Any other input yields an error message that quotes
    /// the original (un-lowercased) input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let parsed = match lowered.as_str() {
            "offline" | "offline_only" => Self::OfflineOnly,
            "lora" | "lora_only" => Self::LoraOnly,
            "full" | "both" => Self::Full,
            _ => return Err(format!("Unknown processor mode: {}", s)),
        };
        Ok(parsed)
    }
}

impl std::fmt::Display for ProcessorMode {
    /// Writes the short canonical name (`offline`, `lora`, `full`); each of
    /// these is accepted by `FromStr` and maps back to the same variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::OfflineOnly => "offline",
            Self::LoraOnly => "lora",
            Self::Full => "full",
        };
        f.write_str(label)
    }
}
53
54// ============================================================================
55// ProcessResult
56// ============================================================================
57
58/// 処理結果
59#[derive(Debug)]
60pub enum ProcessResult {
61    /// Offline 分析結果
62    Offline(OfflineModel),
63    /// LoRA 学習結果
64    Lora(TrainedModel),
65    /// 両方の結果
66    Full {
67        offline: OfflineModel,
68        lora: TrainedModel,
69    },
70}
71
72impl ProcessResult {
73    /// LoRA モデルを取得(あれば)
74    pub fn lora_model(&self) -> Option<&TrainedModel> {
75        match self {
76            Self::Lora(m) => Some(m),
77            Self::Full { lora, .. } => Some(lora),
78            Self::Offline(_) => None,
79        }
80    }
81
82    /// Offline モデルを取得(あれば)
83    pub fn offline_model(&self) -> Option<&OfflineModel> {
84        match self {
85            Self::Offline(m) => Some(m),
86            Self::Full { offline, .. } => Some(offline),
87            Self::Lora(_) => None,
88        }
89    }
90}
91
92// ============================================================================
93// ProcessorError
94// ============================================================================
95
96/// Processor のエラー型
97#[derive(Debug)]
98pub enum ProcessorError {
99    /// Store エラー
100    Store(StoreError),
101    /// LoRA Trainer エラー
102    LoraTrainer(LoraTrainerError),
103    /// IO エラー
104    Io(std::io::Error),
105    /// データ不足
106    InsufficientData(String),
107    /// その他
108    Other(String),
109}
110
111impl std::fmt::Display for ProcessorError {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        match self {
114            Self::Store(e) => write!(f, "Store error: {}", e),
115            Self::LoraTrainer(e) => write!(f, "LoRA trainer error: {}", e),
116            Self::Io(e) => write!(f, "IO error: {}", e),
117            Self::InsufficientData(msg) => write!(f, "Insufficient data: {}", msg),
118            Self::Other(msg) => write!(f, "{}", msg),
119        }
120    }
121}
122
123impl std::error::Error for ProcessorError {}
124
125impl From<StoreError> for ProcessorError {
126    fn from(e: StoreError) -> Self {
127        Self::Store(e)
128    }
129}
130
131impl From<LoraTrainerError> for ProcessorError {
132    fn from(e: LoraTrainerError) -> Self {
133        Self::LoraTrainer(e)
134    }
135}
136
137impl From<std::io::Error> for ProcessorError {
138    fn from(e: std::io::Error) -> Self {
139        Self::Io(e)
140    }
141}
142
143// ============================================================================
144// ProcessorConfig
145// ============================================================================
146
147/// Processor の設定
148#[derive(Debug, Clone)]
149pub struct ProcessorConfig {
150    /// 処理モード
151    pub mode: ProcessorMode,
152    /// シナリオ名(Offline 分析用)
153    pub scenario: String,
154    /// Offline 分析に使用するセッション数
155    pub max_sessions: usize,
156}
157
158impl Default for ProcessorConfig {
159    fn default() -> Self {
160        Self {
161            mode: ProcessorMode::OfflineOnly,
162            scenario: "default".to_string(),
163            max_sessions: 20,
164        }
165    }
166}
167
168impl ProcessorConfig {
169    /// 新しい設定を作成
170    pub fn new(scenario: impl Into<String>) -> Self {
171        Self {
172            scenario: scenario.into(),
173            ..Default::default()
174        }
175    }
176
177    /// 処理モードを設定
178    pub fn mode(mut self, mode: ProcessorMode) -> Self {
179        self.mode = mode;
180        self
181    }
182
183    /// 最大セッション数を設定
184    pub fn max_sessions(mut self, n: usize) -> Self {
185        self.max_sessions = n;
186        self
187    }
188}
189
190// ============================================================================
191// Processor
192// ============================================================================
193
194/// 学習処理を実行
195pub struct Processor {
196    /// 設定
197    config: ProcessorConfig,
198    /// LearningStore(Offline 分析用)
199    learning_store: Option<LearningStore>,
200    /// LoRA Trainer(LoRA 学習用)
201    lora_trainer: Option<LoraTrainer>,
202    /// LearnModel(LoRA 学習用)
203    learn_model: Option<Arc<dyn LearnModel>>,
204}
205
206impl Processor {
207    /// 新しい Processor を作成
208    pub fn new(config: ProcessorConfig) -> Self {
209        Self {
210            config,
211            learning_store: None,
212            lora_trainer: None,
213            learn_model: None,
214        }
215    }
216
217    /// LearningStore を設定(Offline 分析用)
218    pub fn with_learning_store(mut self, store: LearningStore) -> Self {
219        self.learning_store = Some(store);
220        self
221    }
222
223    /// LoRA Trainer を設定
224    pub fn with_lora_trainer(mut self, trainer: LoraTrainer) -> Self {
225        self.lora_trainer = Some(trainer);
226        self
227    }
228
229    /// LearnModel を設定(LoRA 学習用)
230    pub fn with_learn_model(mut self, model: Arc<dyn LearnModel>) -> Self {
231        self.learn_model = Some(model);
232        self
233    }
234
235    /// 設定を取得
236    pub fn config(&self) -> &ProcessorConfig {
237        &self.config
238    }
239
240    /// 学習処理を実行
241    pub async fn run(
242        &self,
243        episode_store: &dyn EpisodeStore,
244    ) -> Result<ProcessResult, ProcessorError> {
245        tracing::info!(
246            mode = %self.config.mode,
247            scenario = %self.config.scenario,
248            "Starting learning process"
249        );
250
251        match self.config.mode {
252            ProcessorMode::OfflineOnly => {
253                let model = self.run_offline()?;
254                Ok(ProcessResult::Offline(model))
255            }
256            ProcessorMode::LoraOnly => {
257                let model = self.run_lora(episode_store).await?;
258                Ok(ProcessResult::Lora(model))
259            }
260            ProcessorMode::Full => {
261                let offline = self.run_offline()?;
262                let lora = self.run_lora(episode_store).await?;
263                Ok(ProcessResult::Full { offline, lora })
264            }
265        }
266    }
267
268    /// Offline 分析を実行
269    fn run_offline(&self) -> Result<OfflineModel, ProcessorError> {
270        let store = self.learning_store.as_ref().ok_or_else(|| {
271            ProcessorError::Other("LearningStore not configured for offline analysis".into())
272        })?;
273
274        tracing::info!(
275            scenario = %self.config.scenario,
276            max_sessions = self.config.max_sessions,
277            "Running offline analysis"
278        );
279
280        let model = store.run_offline_learning(&self.config.scenario, self.config.max_sessions)?;
281
282        tracing::info!(
283            analyzed_sessions = model.analyzed_sessions,
284            ucb1_c = model.parameters.ucb1_c,
285            "Offline analysis completed"
286        );
287
288        Ok(model)
289    }
290
291    /// LoRA 学習を実行
292    async fn run_lora(
293        &self,
294        episode_store: &dyn EpisodeStore,
295    ) -> Result<TrainedModel, ProcessorError> {
296        let trainer = self
297            .lora_trainer
298            .as_ref()
299            .ok_or_else(|| ProcessorError::Other("LoraTrainer not configured".into()))?;
300
301        let learn_model = self.learn_model.as_ref().ok_or_else(|| {
302            ProcessorError::Other("LearnModel not configured for LoRA training".into())
303        })?;
304
305        // Episode 数を確認
306        let episode_count = episode_store.count(None)?;
307        if episode_count == 0 {
308            return Err(ProcessorError::InsufficientData(
309                "No episodes available for LoRA training".into(),
310            ));
311        }
312
313        tracing::info!(
314            episode_count,
315            learn_model = learn_model.name(),
316            "Running LoRA training"
317        );
318
319        let model = trainer.train(learn_model.as_ref(), None).await?;
320
321        tracing::info!(
322            model_id = %model.id,
323            sample_count = model.sample_count,
324            "LoRA training completed"
325        );
326
327        Ok(model)
328    }
329}
330
331// ============================================================================
332// Tests
333// ============================================================================
334
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_processor_mode_from_str() {
        assert_eq!(
            "offline".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::OfflineOnly
        );
        assert_eq!(
            "lora".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::LoraOnly
        );
        assert_eq!(
            "full".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::Full
        );
        assert!("invalid".parse::<ProcessorMode>().is_err());
    }

    /// Every documented alias parses, and matching ignores case.
    #[test]
    fn test_processor_mode_aliases_and_case() {
        assert_eq!(
            "offline_only".parse::<ProcessorMode>(),
            Ok(ProcessorMode::OfflineOnly)
        );
        assert_eq!(
            "LORA_ONLY".parse::<ProcessorMode>(),
            Ok(ProcessorMode::LoraOnly)
        );
        assert_eq!("Both".parse::<ProcessorMode>(), Ok(ProcessorMode::Full));
        assert_eq!(
            "bogus".parse::<ProcessorMode>(),
            Err("Unknown processor mode: bogus".to_string())
        );
    }

    /// `Display` output parses back to the same variant.
    #[test]
    fn test_processor_mode_display_round_trips() {
        for &mode in &[
            ProcessorMode::OfflineOnly,
            ProcessorMode::LoraOnly,
            ProcessorMode::Full,
        ] {
            assert_eq!(mode.to_string().parse::<ProcessorMode>(), Ok(mode));
        }
        assert_eq!(ProcessorMode::default(), ProcessorMode::OfflineOnly);
    }

    #[test]
    fn test_processor_config_builder() {
        let config = ProcessorConfig::new("test-scenario")
            .mode(ProcessorMode::Full)
            .max_sessions(50);

        assert_eq!(config.scenario, "test-scenario");
        assert_eq!(config.mode, ProcessorMode::Full);
        assert_eq!(config.max_sessions, 50);
    }

    /// `new` only overrides the scenario; everything else comes from `Default`.
    #[test]
    fn test_processor_config_defaults() {
        let config = ProcessorConfig::default();
        assert_eq!(config.mode, ProcessorMode::OfflineOnly);
        assert_eq!(config.scenario, "default");
        assert_eq!(config.max_sessions, 20);

        let named = ProcessorConfig::new("s");
        assert_eq!(named.scenario, "s");
        assert_eq!(named.mode, ProcessorMode::OfflineOnly);
        assert_eq!(named.max_sessions, 20);
    }

    #[test]
    fn test_process_result_accessors() {
        // Offline only
        let offline_model = OfflineModel::default();
        let result = ProcessResult::Offline(offline_model);
        assert!(result.offline_model().is_some());
        assert!(result.lora_model().is_none());
    }

    #[test]
    fn test_processor_error_display() {
        assert_eq!(
            ProcessorError::InsufficientData("no episodes".into()).to_string(),
            "Insufficient data: no episodes"
        );
        assert_eq!(ProcessorError::Other("boom".into()).to_string(), "boom");
    }

    /// Offline analysis without a `LearningStore` reports a configuration error.
    #[test]
    fn test_run_offline_without_learning_store_errors() {
        let processor = Processor::new(ProcessorConfig::default());
        assert!(matches!(
            processor.run_offline(),
            Err(ProcessorError::Other(_))
        ));
    }
}