Skip to main content

swarm_engine_core/learn/daemon/
applier.rs

//! Applier - applies trained models.
//!
//! After training completes, applies a `TrainedModel` to llama-server.
//! In auto-apply mode the model is applied automatically; otherwise only a notification is emitted.
5
6use std::sync::Arc;
7
8use crate::learn::lora::{ApplicatorError, ModelApplicator, TrainedModel};
9
10// ============================================================================
11// ApplyMode
12// ============================================================================
13
/// How a freshly trained model is handled once training has finished.
///
/// The default is [`ApplyMode::Manual`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ApplyMode {
    /// Do not apply; only report that the model is ready (manual apply, notification only).
    #[default]
    Manual,
    /// Apply the model automatically.
    Auto,
}
23
24// ============================================================================
25// ApplierConfig
26// ============================================================================
27
28/// Applier の設定
29#[derive(Debug, Clone)]
30pub struct ApplierConfig {
31    /// 適用モード
32    pub mode: ApplyMode,
33    /// ロールバック履歴の最大数
34    pub max_history: usize,
35}
36
37impl Default for ApplierConfig {
38    fn default() -> Self {
39        Self {
40            mode: ApplyMode::Manual,
41            max_history: 5,
42        }
43    }
44}
45
46impl ApplierConfig {
47    /// Auto-apply を有効化
48    pub fn auto_apply(mut self) -> Self {
49        self.mode = ApplyMode::Auto;
50        self
51    }
52
53    /// 履歴の最大数を設定
54    pub fn max_history(mut self, n: usize) -> Self {
55        self.max_history = n;
56        self
57    }
58}
59
60// ============================================================================
61// ApplierError
62// ============================================================================
63
64/// Applier のエラー型
65#[derive(Debug)]
66pub enum ApplierError {
67    /// Applicator エラー
68    Applicator(ApplicatorError),
69    /// 適用がスキップされた(手動モード)
70    Skipped(String),
71    /// その他
72    Other(String),
73}
74
75impl std::fmt::Display for ApplierError {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            Self::Applicator(e) => write!(f, "Applicator error: {}", e),
79            Self::Skipped(msg) => write!(f, "Apply skipped: {}", msg),
80            Self::Other(msg) => write!(f, "{}", msg),
81        }
82    }
83}
84
85impl std::error::Error for ApplierError {}
86
87impl From<ApplicatorError> for ApplierError {
88    fn from(e: ApplicatorError) -> Self {
89        Self::Applicator(e)
90    }
91}
92
93// ============================================================================
94// ApplyResult
95// ============================================================================
96
/// Outcome of an apply request.
#[derive(Debug)]
pub enum ApplyResult {
    /// The model was applied.
    Applied {
        /// Id of the model that was applied.
        model_id: String,
        /// Id reported as the previously active model, if any.
        previous_model_id: Option<String>,
    },
    /// The model was not applied (manual mode).
    Skipped {
        /// Id of the model that was not applied.
        model_id: String,
        /// Why it was not applied.
        reason: String,
    },
}

impl ApplyResult {
    /// Returns `true` if the model was actually applied and `false` if it was skipped.
    pub fn is_applied(&self) -> bool {
        match self {
            ApplyResult::Applied { .. } => true,
            ApplyResult::Skipped { .. } => false,
        }
    }
}
115
116// ============================================================================
117// Applier
118// ============================================================================
119
120/// 学習済みモデルの適用を担当
121pub struct Applier {
122    /// 設定
123    config: ApplierConfig,
124    /// ModelApplicator
125    applicator: Arc<dyn ModelApplicator>,
126    /// 適用履歴(model_id のリスト)
127    history: Vec<String>,
128}
129
130impl Applier {
131    /// 新しい Applier を作成
132    pub fn new(config: ApplierConfig, applicator: Arc<dyn ModelApplicator>) -> Self {
133        Self {
134            config,
135            applicator,
136            history: Vec::new(),
137        }
138    }
139
140    /// 設定を取得
141    pub fn config(&self) -> &ApplierConfig {
142        &self.config
143    }
144
145    /// 適用履歴を取得
146    pub fn history(&self) -> &[String] {
147        &self.history
148    }
149
150    /// モデルを適用(設定に応じて、非同期)
151    pub async fn apply(&mut self, model: &TrainedModel) -> Result<ApplyResult, ApplierError> {
152        match self.config.mode {
153            ApplyMode::Manual => {
154                tracing::info!(
155                    model_id = %model.id,
156                    "Model ready for manual apply (auto-apply disabled)"
157                );
158                Ok(ApplyResult::Skipped {
159                    model_id: model.id.to_string(),
160                    reason: "Auto-apply disabled".into(),
161                })
162            }
163            ApplyMode::Auto => self.apply_now(model).await,
164        }
165    }
166
167    /// モデルを即座に適用(モード関係なく、非同期)
168    pub async fn apply_now(&mut self, model: &TrainedModel) -> Result<ApplyResult, ApplierError> {
169        let previous_model_id = self.applicator.previous_model_id().map(|id| id.to_string());
170
171        tracing::info!(
172            model_id = %model.id,
173            previous = ?previous_model_id,
174            "Applying trained model"
175        );
176
177        self.applicator.apply(model).await?;
178
179        // 履歴に追加
180        self.history.push(model.id.to_string());
181        if self.history.len() > self.config.max_history {
182            self.history.remove(0);
183        }
184
185        tracing::info!(
186            model_id = %model.id,
187            "Model applied successfully"
188        );
189
190        Ok(ApplyResult::Applied {
191            model_id: model.id.to_string(),
192            previous_model_id,
193        })
194    }
195
196    /// 前のモデルにロールバック(非同期)
197    pub async fn rollback(&self) -> Result<(), ApplierError> {
198        let previous_id = self
199            .applicator
200            .previous_model_id()
201            .ok_or_else(|| ApplierError::Other("No previous model to rollback to".into()))?;
202
203        tracing::info!(target_id = %previous_id, "Rolling back to previous model");
204        self.applicator.rollback(&previous_id).await?;
205
206        Ok(())
207    }
208
209    /// 現在適用中のモデルを取得
210    pub fn current_model(&self) -> Option<TrainedModel> {
211        self.applicator.current()
212    }
213}
214
215// ============================================================================
216// Tests
217// ============================================================================
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::lora::{LoraModelId, NoOpApplicator};
    use std::path::PathBuf;

    /// Builds a minimal `TrainedModel` whose id is `id`.
    fn create_test_model(id: &str) -> TrainedModel {
        TrainedModel {
            id: LoraModelId::from_str(id),
            base_model: String::from("test-base"),
            adapter_path: PathBuf::from("/tmp/test"),
            learn_model_name: String::from("test"),
            episode_ids: Vec::new(),
            sample_count: 10,
            created_at: 0,
            metrics: None,
        }
    }

    /// Creates an `Applier` backed by a `NoOpApplicator`.
    fn noop_applier(config: ApplierConfig) -> Applier {
        Applier::new(config, Arc::new(NoOpApplicator::new()))
    }

    #[tokio::test]
    async fn test_applier_manual_mode() {
        // The default configuration is manual mode, so `apply` only reports the model.
        let mut applier = noop_applier(ApplierConfig::default());
        let outcome = applier
            .apply(&create_test_model("test-model-1"))
            .await
            .unwrap();

        assert!(!outcome.is_applied());
        if let ApplyResult::Skipped { model_id, .. } = outcome {
            assert_eq!(model_id, "test-model-1");
        } else {
            panic!("Expected Skipped");
        }
    }

    #[tokio::test]
    async fn test_applier_auto_mode() {
        let mut applier = noop_applier(ApplierConfig::default().auto_apply());
        let outcome = applier
            .apply(&create_test_model("test-model-1"))
            .await
            .unwrap();

        assert!(outcome.is_applied());
        assert_eq!(applier.history().len(), 1);
    }

    #[tokio::test]
    async fn test_applier_history_limit() {
        let mut applier = noop_applier(ApplierConfig::default().auto_apply().max_history(2));

        for name in ["model-0", "model-1", "model-2"] {
            applier.apply(&create_test_model(name)).await.unwrap();
        }

        // Only the two most recent ids survive, oldest first.
        assert_eq!(applier.history().len(), 2);
        assert_eq!(applier.history(), ["model-1", "model-2"]);
    }
}