scirs2_cluster/adaptive/
mod.rs1use crate::error::ClusteringError;
14
/// Tuning parameters for the adaptive batch-size controller.
#[derive(Clone, Debug)]
pub struct AdaptiveBatchConfig {
    /// Batch size used at construction and after a reset (bounded to
    /// `[min_batch, max_batch]` by the controller).
    pub initial_batch_size: usize,
    /// Lower bound applied to scaled batch sizes.
    pub min_batch: usize,
    /// Upper bound applied to scaled batch sizes.
    pub max_batch: usize,
    /// Multiplier applied to the batch size when recent losses are flat;
    /// expected to be greater than 1.
    pub growth_factor: f64,
    /// Multiplier applied to the batch size when recent losses rise;
    /// expected to lie strictly between 0 and 1.
    pub decay_factor: f64,
    /// Number of most recent losses inspected per recommendation
    /// (values below 2 behave as 2).
    pub window: usize,
}
31
32impl Default for AdaptiveBatchConfig {
33 fn default() -> Self {
34 Self {
35 initial_batch_size: 32,
36 min_batch: 16,
37 max_batch: 2048,
38 growth_factor: 1.5,
39 decay_factor: 0.8,
40 window: 6, }
42 }
43}
44
45pub struct BatchSizeController {
47 pub current_size: usize,
49 pub loss_history: Vec<f64>,
51 config: AdaptiveBatchConfig,
52}
53
54impl BatchSizeController {
55 pub fn new(config: AdaptiveBatchConfig) -> Self {
57 let initial = config
58 .initial_batch_size
59 .clamp(config.min_batch, config.max_batch);
60 Self {
61 current_size: initial,
62 loss_history: Vec::new(),
63 config,
64 }
65 }
66
67 pub fn record_loss(&mut self, loss: f64) {
69 self.loss_history.push(loss);
70 }
71
72 pub fn recommend_size(&self) -> usize {
82 let w = self.config.window.max(2);
83 let half = w / 2;
84
85 if self.loss_history.len() < w {
86 return self.current_size;
87 }
88
89 let recent: &[f64] = &self.loss_history[self.loss_history.len() - half..];
90 let prev: &[f64] =
91 &self.loss_history[self.loss_history.len() - w..self.loss_history.len() - half];
92
93 let mean_recent = mean(recent);
94 let mean_prev = mean(prev);
95 let std_recent = std_dev(recent);
96
97 let relative_std = if mean_recent.abs() > 1e-12 {
99 std_recent / mean_recent.abs()
100 } else {
101 std_recent
102 };
103
104 if relative_std < 0.01 {
105 let new_size =
106 ((self.current_size as f64) * self.config.growth_factor).round() as usize;
107 return new_size.clamp(self.config.min_batch, self.config.max_batch);
108 }
109
110 if mean_recent > mean_prev {
112 let new_size = ((self.current_size as f64) * self.config.decay_factor).round() as usize;
113 return new_size.clamp(self.config.min_batch, self.config.max_batch);
114 }
115
116 self.current_size
117 }
118
119 pub fn adapt(&mut self, loss: f64) -> usize {
121 self.record_loss(loss);
122 let new_size = self.recommend_size();
123 self.current_size = new_size;
124 new_size
125 }
126
127 pub fn reset(&mut self) {
129 self.current_size = self
130 .config
131 .initial_batch_size
132 .clamp(self.config.min_batch, self.config.max_batch);
133 self.loss_history.clear();
134 }
135
136 pub fn validate(&self) -> Result<(), ClusteringError> {
138 if self.config.growth_factor <= 1.0 {
139 return Err(ClusteringError::InvalidInput(
140 "growth_factor must be > 1".into(),
141 ));
142 }
143 if self.config.decay_factor <= 0.0 || self.config.decay_factor >= 1.0 {
144 return Err(ClusteringError::InvalidInput(
145 "decay_factor must be in (0, 1)".into(),
146 ));
147 }
148 if self.config.min_batch > self.config.max_batch {
149 return Err(ClusteringError::InvalidInput(
150 "min_batch must be ≤ max_batch".into(),
151 ));
152 }
153 Ok(())
154 }
155}
156
/// Arithmetic mean of `xs`; `0.0` for an empty slice.
fn mean(xs: &[f64]) -> f64 {
    match xs.len() {
        0 => 0.0,
        n => xs.iter().sum::<f64>() / n as f64,
    }
}
167
/// Population standard deviation of `xs` (divides by `n`, not `n - 1`);
/// `0.0` when fewer than two values are given.
fn std_dev(xs: &[f64]) -> f64 {
    let n = xs.len();
    if n < 2 {
        return 0.0;
    }
    let count = n as f64;
    let centre = xs.iter().sum::<f64>() / count;
    let sum_sq = xs
        .iter()
        .map(|&x| {
            let d = x - centre;
            d * d
        })
        .sum::<f64>();
    (sum_sq / count).sqrt()
}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Records each loss in `losses` on `ctrl`, in order.
    fn feed(ctrl: &mut BatchSizeController, losses: &[f64]) {
        for &loss in losses {
            ctrl.record_loss(loss);
        }
    }

    #[test]
    fn test_initial_size_clamped() {
        // An initial size below `min_batch` is raised to `min_batch`.
        let ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 4,
            min_batch: 16,
            max_batch: 2048,
            ..Default::default()
        });
        assert_eq!(ctrl.current_size, 16);
    }

    #[test]
    fn test_not_enough_history_returns_current() {
        // Two losses are fewer than the default window of 6.
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig::default());
        feed(&mut ctrl, &[1.0, 0.9]);
        let unchanged = ctrl.current_size;
        assert_eq!(ctrl.recommend_size(), unchanged);
    }

    #[test]
    fn test_decreasing_loss_grows_batch() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 64,
            min_batch: 16,
            max_batch: 2048,
            growth_factor: 2.0,
            decay_factor: 0.5,
            window: 6,
        });

        // Six slowly falling losses: 1.000, 0.999, ..., 0.995.
        (0..6).for_each(|i| ctrl.record_loss(1.0 - 0.001 * i as f64));

        let size = ctrl.recommend_size();
        assert!(
            size > 64,
            "Batch size should grow on stable decreasing loss, got {}",
            size
        );
    }

    #[test]
    fn test_increasing_loss_shrinks_batch() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 256,
            min_batch: 16,
            max_batch: 2048,
            growth_factor: 1.5,
            decay_factor: 0.5,
            window: 6,
        });

        // The recent half is both noisy and well above the older half.
        feed(&mut ctrl, &[0.1, 0.11, 0.12, 1.5, 1.6, 1.7]);

        let size = ctrl.recommend_size();
        assert!(
            size < 256,
            "Batch size should shrink on increasing loss, got {}",
            size
        );
    }

    #[test]
    fn test_adapt_updates_current_size() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 256,
            window: 6,
            ..Default::default()
        });

        for loss in [0.1, 0.11, 0.12, 1.5, 1.6] {
            ctrl.adapt(loss);
        }
        let final_size = ctrl.adapt(1.7);
        assert_eq!(
            final_size, ctrl.current_size,
            "adapt() should update current_size"
        );
    }

    #[test]
    fn test_bounds_respected() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 17,
            min_batch: 16,
            max_batch: 18,
            growth_factor: 1000.0,
            decay_factor: 0.001,
            window: 6,
        });

        // Extreme factors must still land inside [min_batch, max_batch].
        (0..6).for_each(|i| ctrl.record_loss(1.0 - 0.0001 * i as f64));
        let grown = ctrl.recommend_size();
        assert!(grown <= 18, "Must not exceed max_batch");

        ctrl.reset();
        feed(&mut ctrl, &[0.01, 0.01, 0.01, 10.0, 10.0, 10.0]);
        let shrunk = ctrl.recommend_size();
        assert!(shrunk >= 16, "Must not go below min_batch");
    }

    #[test]
    fn test_validate_config() {
        let good = BatchSizeController::new(AdaptiveBatchConfig::default());
        assert!(good.validate().is_ok());

        let bad = BatchSizeController::new(AdaptiveBatchConfig {
            growth_factor: 0.5,
            ..Default::default()
        });
        assert!(bad.validate().is_err());
    }
}
321}