swarm_engine_core/learn/daemon/
applier.rs1use std::sync::Arc;
7
8use crate::learn::lora::{ApplicatorError, ModelApplicator, TrainedModel};
9
/// Controls what [`Applier::apply`] does with a freshly trained model.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ApplyMode {
    /// Only log that the model is ready; activation requires an explicit
    /// [`Applier::apply_now`] call. This is the default mode.
    #[default]
    Manual,
    /// Activate every model handed to [`Applier::apply`] immediately.
    Auto,
}
23
24#[derive(Debug, Clone)]
30pub struct ApplierConfig {
31 pub mode: ApplyMode,
33 pub max_history: usize,
35}
36
37impl Default for ApplierConfig {
38 fn default() -> Self {
39 Self {
40 mode: ApplyMode::Manual,
41 max_history: 5,
42 }
43 }
44}
45
46impl ApplierConfig {
47 pub fn auto_apply(mut self) -> Self {
49 self.mode = ApplyMode::Auto;
50 self
51 }
52
53 pub fn max_history(mut self, n: usize) -> Self {
55 self.max_history = n;
56 self
57 }
58}
59
60#[derive(Debug)]
66pub enum ApplierError {
67 Applicator(ApplicatorError),
69 Skipped(String),
71 Other(String),
73}
74
75impl std::fmt::Display for ApplierError {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 Self::Applicator(e) => write!(f, "Applicator error: {}", e),
79 Self::Skipped(msg) => write!(f, "Apply skipped: {}", msg),
80 Self::Other(msg) => write!(f, "{}", msg),
81 }
82 }
83}
84
85impl std::error::Error for ApplierError {}
86
87impl From<ApplicatorError> for ApplierError {
88 fn from(e: ApplicatorError) -> Self {
89 Self::Applicator(e)
90 }
91}
92
/// Successful outcome of [`Applier::apply`] or [`Applier::apply_now`].
#[derive(Debug)]
pub enum ApplyResult {
    /// The model was handed to the applicator and activated.
    Applied {
        /// ID of the model that was just applied.
        model_id: String,
        /// Previous model ID captured just before the apply ran, if the
        /// applicator reported one.
        previous_model_id: Option<String>,
    },
    /// The model was not activated.
    Skipped {
        /// ID of the model that was not applied.
        model_id: String,
        /// Human-readable explanation of why the apply was skipped.
        reason: String,
    },
}
108
109impl ApplyResult {
110 pub fn is_applied(&self) -> bool {
112 matches!(self, Self::Applied { .. })
113 }
114}
115
116pub struct Applier {
122 config: ApplierConfig,
124 applicator: Arc<dyn ModelApplicator>,
126 history: Vec<String>,
128}
129
130impl Applier {
131 pub fn new(config: ApplierConfig, applicator: Arc<dyn ModelApplicator>) -> Self {
133 Self {
134 config,
135 applicator,
136 history: Vec::new(),
137 }
138 }
139
140 pub fn config(&self) -> &ApplierConfig {
142 &self.config
143 }
144
145 pub fn history(&self) -> &[String] {
147 &self.history
148 }
149
150 pub async fn apply(&mut self, model: &TrainedModel) -> Result<ApplyResult, ApplierError> {
152 match self.config.mode {
153 ApplyMode::Manual => {
154 tracing::info!(
155 model_id = %model.id,
156 "Model ready for manual apply (auto-apply disabled)"
157 );
158 Ok(ApplyResult::Skipped {
159 model_id: model.id.to_string(),
160 reason: "Auto-apply disabled".into(),
161 })
162 }
163 ApplyMode::Auto => self.apply_now(model).await,
164 }
165 }
166
167 pub async fn apply_now(&mut self, model: &TrainedModel) -> Result<ApplyResult, ApplierError> {
169 let previous_model_id = self.applicator.previous_model_id().map(|id| id.to_string());
170
171 tracing::info!(
172 model_id = %model.id,
173 previous = ?previous_model_id,
174 "Applying trained model"
175 );
176
177 self.applicator.apply(model).await?;
178
179 self.history.push(model.id.to_string());
181 if self.history.len() > self.config.max_history {
182 self.history.remove(0);
183 }
184
185 tracing::info!(
186 model_id = %model.id,
187 "Model applied successfully"
188 );
189
190 Ok(ApplyResult::Applied {
191 model_id: model.id.to_string(),
192 previous_model_id,
193 })
194 }
195
196 pub async fn rollback(&self) -> Result<(), ApplierError> {
198 let previous_id = self
199 .applicator
200 .previous_model_id()
201 .ok_or_else(|| ApplierError::Other("No previous model to rollback to".into()))?;
202
203 tracing::info!(target_id = %previous_id, "Rolling back to previous model");
204 self.applicator.rollback(&previous_id).await?;
205
206 Ok(())
207 }
208
209 pub fn current_model(&self) -> Option<TrainedModel> {
211 self.applicator.current()
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::lora::{LoraModelId, NoOpApplicator};
    use std::path::PathBuf;

    /// Builds a minimal `TrainedModel` with the given ID for applier tests.
    fn create_test_model(id: &str) -> TrainedModel {
        TrainedModel {
            id: LoraModelId::from_str(id),
            base_model: "test-base".to_string(),
            adapter_path: PathBuf::from("/tmp/test"),
            learn_model_name: "test".to_string(),
            episode_ids: vec![],
            sample_count: 10,
            created_at: 0,
            metrics: None,
        }
    }

    #[tokio::test]
    async fn test_applier_manual_mode() {
        let config = ApplierConfig::default();
        let applicator = Arc::new(NoOpApplicator::new());
        let mut applier = Applier::new(config, applicator);

        let model = create_test_model("test-model-1");
        let result = applier.apply(&model).await.unwrap();

        assert!(!result.is_applied());
        match result {
            ApplyResult::Skipped { model_id, .. } => {
                assert_eq!(model_id, "test-model-1");
            }
            _ => panic!("Expected Skipped"),
        }
        // Manual mode must not record anything in the history.
        assert!(applier.history().is_empty());
    }

    #[tokio::test]
    async fn test_applier_auto_mode() {
        let config = ApplierConfig::default().auto_apply();
        let applicator = Arc::new(NoOpApplicator::new());
        let mut applier = Applier::new(config, applicator);

        let model = create_test_model("test-model-1");
        let result = applier.apply(&model).await.unwrap();

        assert!(result.is_applied());
        assert_eq!(applier.history().len(), 1);
    }

    #[tokio::test]
    async fn test_apply_now_ignores_manual_mode() {
        let config = ApplierConfig::default();
        let applicator = Arc::new(NoOpApplicator::new());
        let mut applier = Applier::new(config, applicator);

        let model = create_test_model("forced-model");
        let result = applier.apply_now(&model).await.unwrap();

        assert!(result.is_applied());
        assert_eq!(applier.history().len(), 1);
        assert_eq!(applier.history()[0], "forced-model");
    }

    #[tokio::test]
    async fn test_applier_history_limit() {
        let config = ApplierConfig::default().auto_apply().max_history(2);
        let applicator = Arc::new(NoOpApplicator::new());
        let mut applier = Applier::new(config, applicator);

        for i in 0..3 {
            let model = create_test_model(&format!("model-{}", i));
            applier.apply(&model).await.unwrap();
        }

        // Oldest entry ("model-0") is evicted; newest are kept in order.
        assert_eq!(applier.history().len(), 2);
        assert_eq!(applier.history()[0], "model-1");
        assert_eq!(applier.history()[1], "model-2");
    }

    #[tokio::test]
    async fn test_applier_zero_history_keeps_nothing() {
        let config = ApplierConfig::default().auto_apply().max_history(0);
        let applicator = Arc::new(NoOpApplicator::new());
        let mut applier = Applier::new(config, applicator);

        for i in 0..2 {
            let model = create_test_model(&format!("model-{}", i));
            let result = applier.apply(&model).await.unwrap();
            assert!(result.is_applied());
        }

        assert!(applier.history().is_empty());
    }
}