1use crate::error::Result;
6use crate::types::FloatBounds;
7use futures_core::{Future, Stream};
8use std::pin::Pin;
9use std::time::Duration;
10
11pub type AsyncPartialFitFuture<'a, T> =
13 Pin<Box<dyn Future<Output = Result<Option<T>>> + Send + 'a>>;
14
15pub type AsyncPredictConfidenceFuture<'a, Output> =
17 Pin<Box<dyn Future<Output = Result<(Output, ConfidenceInterval)>> + Send + 'a>>;
18
19pub type AsyncScoreFuture<'a, Score> =
21 Pin<Box<dyn Future<Output = Result<Vec<Score>>> + Send + 'a>>;
22
23pub type AsyncFitFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T>> + Send + 'a>>;
25
26pub type AsyncTransformFuture<'a, Output> =
28 Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
29
30pub type AsyncCVStream<'a, Score> = Pin<Box<dyn Stream<Item = Result<(usize, Score)>> + Send + 'a>>;
32
33pub type AsyncEnsembleFitStream<'a, Model> =
35 Pin<Box<dyn Stream<Item = Result<(usize, Model)>> + Send + 'a>>;
36
37pub type AsyncEnsemblePredictStream<'a, Output> =
39 Pin<Box<dyn Stream<Item = Result<(usize, Output)>> + Send + 'a>>;
40
41pub type AsyncOptimizationStream<'a, Config, Score> =
43 Pin<Box<dyn Stream<Item = Result<OptimizationResult<Config, Score>>> + Send + 'a>>;
44
45pub type ConfigFactory<Config> =
47 Box<dyn Fn(&std::collections::HashMap<String, f64>) -> Config + Send + Sync>;
48
49pub type AsyncUnitFuture<'a> = Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
51
/// Snapshot of how far a long-running asynchronous operation has progressed.
///
/// Create one with [`ProgressInfo::new`] and refine it with the `with_*` builder methods.
/// The constructor and builders keep `progress` within `[0.0, 1.0]`.
#[derive(Debug, Clone)]
pub struct ProgressInfo {
    /// Completed fraction of the operation, in `[0.0, 1.0]` when set via this type's methods.
    pub progress: f64,
    /// Step counter; `with_total_steps` divides it by the total to derive `progress`.
    pub current_step: usize,
    /// Total number of steps, if known.
    pub total_steps: Option<usize>,
    /// Time elapsed so far (zero unless set with `with_elapsed`).
    pub elapsed: Duration,
    /// Estimated remaining time, if provided.
    pub eta: Option<Duration>,
    /// Latest value of a tracked metric, if provided.
    pub current_metric: Option<f64>,
    /// Free-form status message (empty by default).
    pub message: String,
}

impl ProgressInfo {
    /// Creates a snapshot at fraction `progress` and step `current_step`.
    ///
    /// `progress` is clamped into `[0.0, 1.0]`; NaN is treated as `0.0` so the
    /// invariant holds for every input. All optional fields start as `None`,
    /// `elapsed` is zero and `message` is empty.
    pub fn new(progress: f64, current_step: usize) -> Self {
        Self {
            progress: Self::clamp_fraction(progress),
            current_step,
            total_steps: None,
            elapsed: Duration::from_secs(0),
            eta: None,
            current_metric: None,
            message: String::new(),
        }
    }

    /// Records the total step count and derives `progress` from it.
    ///
    /// When `total > 0`, `progress` becomes `current_step / total`, clamped to at most
    /// `1.0` so a step counter that overshoots the total cannot report more than 100%.
    /// When `total == 0` the existing `progress` is kept (avoids dividing by zero).
    pub fn with_total_steps(mut self, total: usize) -> Self {
        self.total_steps = Some(total);
        if total > 0 {
            self.progress = Self::clamp_fraction(self.current_step as f64 / total as f64);
        }
        self
    }

    /// Sets the elapsed time.
    pub fn with_elapsed(mut self, elapsed: Duration) -> Self {
        self.elapsed = elapsed;
        self
    }

    /// Sets the estimated remaining time.
    pub fn with_eta(mut self, eta: Duration) -> Self {
        self.eta = Some(eta);
        self
    }

    /// Sets the current metric value.
    pub fn with_metric(mut self, metric: f64) -> Self {
        self.current_metric = Some(metric);
        self
    }

    /// Sets the status message.
    pub fn with_message<S: Into<String>>(mut self, message: S) -> Self {
        self.message = message.into();
        self
    }

    /// Clamps a raw fraction into `[0.0, 1.0]`, mapping NaN to `0.0`.
    ///
    /// `f64::clamp` returns NaN unchanged, which would break the range invariant.
    fn clamp_fraction(fraction: f64) -> f64 {
        if fraction.is_nan() {
            0.0
        } else {
            fraction.clamp(0.0, 1.0)
        }
    }
}
118
/// Tuning knobs shared by the asynchronous traits in this module.
#[derive(Debug, Clone)]
pub struct AsyncConfig {
    /// Batch size used when work is split into batches.
    pub batch_size: usize,
    /// Optional per-operation timeout (`None` presumably means no limit — confirm with implementors).
    pub operation_timeout: Option<Duration>,
    /// Whether progress reporting is enabled.
    pub enable_progress: bool,
    /// Interval between progress updates.
    pub progress_interval: Duration,
    /// Upper bound on concurrently running work.
    pub max_concurrency: usize,
}

impl Default for AsyncConfig {
    /// Returns the default configuration:
    /// - batch size 1000;
    /// - 300 s operation timeout;
    /// - progress enabled, reported every 1 s;
    /// - `max_concurrency` equal to the parallelism the OS reports via the standard
    ///   library, or 1 if that cannot be determined.
    fn default() -> Self {
        // `available_parallelism` can fail on exotic platforms; 1 is the safe floor.
        let max_concurrency = std::thread::available_parallelism()
            .map(std::num::NonZeroUsize::get)
            .unwrap_or(1);

        Self {
            batch_size: 1000,
            operation_timeout: Some(Duration::from_secs(300)),
            enable_progress: true,
            progress_interval: Duration::from_secs(1),
            max_concurrency,
        }
    }
}
145
146pub trait AsyncFitAdvanced<X, Y, State = crate::traits::Untrained> {
148 type Fitted;
150
151 type Error: std::error::Error + Send + Sync;
153
154 fn fit_async_with_progress<'a>(
156 self,
157 x: &'a X,
158 y: &'a Y,
159 config: &'a AsyncConfig,
160 ) -> AsyncFitFuture<'a, Self::Fitted>
161 where
162 Self: 'a;
163
164 fn fit_async_with_progress_stream<'a>(
166 self,
167 x: &'a X,
168 y: &'a Y,
169 config: &'a AsyncConfig,
170 ) -> Pin<Box<dyn Stream<Item = Result<ProgressInfo>> + Send + 'a>>
171 where
172 Self: 'a;
173
174 fn fit_async_cancellable<'a>(
176 self,
177 x: &'a X,
178 y: &'a Y,
179 cancel_token: CancellationToken,
180 ) -> AsyncPartialFitFuture<'a, Self::Fitted>
181 where
182 Self: 'a;
183}
184
185pub trait AsyncPredictAdvanced<X, Output> {
187 type Error: std::error::Error + Send + Sync;
189
190 fn predict_async_batched<'a>(
192 &'a self,
193 x: &'a X,
194 config: &'a AsyncConfig,
195 ) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
196
197 fn predict_stream<'a>(
199 &'a self,
200 x_stream: Pin<Box<dyn Stream<Item = X> + Send + 'a>>,
201 config: &'a AsyncConfig,
202 ) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>;
203
204 fn predict_async_with_uncertainty<'a>(
206 &'a self,
207 x: &'a X,
208 confidence_level: f64,
209 ) -> AsyncPredictConfidenceFuture<'a, Output>
210 where
211 Self: 'a;
212}
213
214pub trait AsyncTransformAdvanced<X, Output = X> {
216 type Error: std::error::Error + Send + Sync;
218
219 fn transform_async_with_progress<'a>(
221 &'a self,
222 x: &'a X,
223 config: &'a AsyncConfig,
224 ) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
225
226 fn transform_stream<'a>(
228 &'a self,
229 x_stream: Pin<Box<dyn Stream<Item = X> + Send + 'a>>,
230 config: &'a AsyncConfig,
231 ) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>;
232
233 fn transform_async_chunked<'a>(
235 &'a self,
236 x: &'a X,
237 chunk_size: usize,
238 ) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>;
239}
240
241pub trait AsyncPartialFit<X, Y> {
243 type Error: std::error::Error + Send + Sync;
245
246 fn partial_fit_async<'a>(
248 &'a mut self,
249 x: &'a X,
250 y: &'a Y,
251 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
252
253 fn partial_fit_stream<'a>(
255 &'a mut self,
256 data_stream: Pin<Box<dyn Stream<Item = (X, Y)> + Send + 'a>>,
257 config: &'a AsyncConfig,
258 ) -> Pin<Box<dyn Stream<Item = Result<ProgressInfo>> + Send + 'a>>;
259
260 fn adaptive_partial_fit<'a>(
262 &'a mut self,
263 data_stream: Pin<Box<dyn Stream<Item = (X, Y)> + Send + 'a>>,
264 adaptation_config: AdaptationConfig,
265 ) -> Pin<Box<dyn Stream<Item = Result<AdaptationInfo>> + Send + 'a>>;
266}
267
/// Settings for adaptive partial fitting (see `AsyncPartialFit::adaptive_partial_fit`).
#[derive(Debug, Clone)]
pub struct AdaptationConfig {
    /// Batch size to start with.
    pub initial_batch_size: usize,
    /// Lower limit for the batch size.
    pub min_batch_size: usize,
    /// Upper limit for the batch size.
    pub max_batch_size: usize,
    /// How aggressively to adapt (exact semantics are implementation-defined — TODO confirm).
    pub adaptation_rate: f64,
    /// Performance threshold consulted during adaptation.
    pub performance_threshold: f64,
    /// Memory threshold (units not stated here; presumably bytes — TODO confirm).
    pub memory_threshold: usize,
}
284
285#[derive(Debug, Clone)]
287pub struct AdaptationInfo {
288 pub current_batch_size: usize,
290 pub current_learning_rate: f64,
292 pub performance_metric: f64,
294 pub memory_usage: usize,
296 pub progress: ProgressInfo,
298}
299
/// Interval `[lower, upper]` reported at a given confidence level.
#[derive(Debug, Clone)]
pub struct ConfidenceInterval {
    /// Lower bound of the interval.
    pub lower: f64,
    /// Upper bound of the interval.
    pub upper: f64,
    /// Confidence level the interval was computed for (e.g. `0.95`).
    pub confidence_level: f64,
}
310
/// Cloneable, thread-safe flag used to request cancellation of in-flight work.
///
/// All clones share one underlying flag, so cancelling through any clone is observed
/// by every other clone. Cancellation is one-way: nothing in this type clears the flag.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    inner: std::sync::Arc<std::sync::atomic::AtomicBool>,
}

impl CancellationToken {
    /// Creates a token in the not-cancelled state.
    pub fn new() -> Self {
        Self {
            inner: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Requests cancellation; calling it more than once has no further effect.
    ///
    /// Uses `Release` so that writes made before cancelling are visible to any thread
    /// that subsequently observes `is_cancelled() == true`.
    pub fn cancel(&self) {
        self.inner.store(true, std::sync::atomic::Ordering::Release);
    }

    /// Returns `true` once `cancel` has been called on this token or any clone of it.
    ///
    /// Uses `Acquire`, pairing with the `Release` store in `cancel`.
    pub fn is_cancelled(&self) -> bool {
        self.inner.load(std::sync::atomic::Ordering::Acquire)
    }
}

impl Default for CancellationToken {
    /// Equivalent to [`CancellationToken::new`].
    fn default() -> Self {
        Self::new()
    }
}
341
342pub trait AsyncCrossValidation<X, Y> {
344 type Score: FloatBounds + Send;
345 type Model: Clone + Send + Sync;
346
347 fn cross_validate_async<'a>(
349 &'a self,
350 model: Self::Model,
351 x: &'a X,
352 y: &'a Y,
353 cv_folds: usize,
354 config: &'a AsyncConfig,
355 ) -> AsyncScoreFuture<'a, Self::Score>
356 where
357 X: Clone + Send + Sync,
358 Y: Clone + Send + Sync;
359
360 fn cross_validate_stream<'a>(
362 &'a self,
363 model: Self::Model,
364 x: &'a X,
365 y: &'a Y,
366 cv_folds: usize,
367 config: &'a AsyncConfig,
368 ) -> AsyncCVStream<'a, Self::Score>
369 where
370 X: Clone + Send + Sync,
371 Y: Clone + Send + Sync;
372}
373
374pub trait AsyncEnsemble<X, Y, Output> {
376 type Model: Send + Sync;
377 type Error: std::error::Error + Send + Sync;
378
379 fn fit_ensemble_async<'a>(
381 models: Vec<Self::Model>,
382 x: &'a X,
383 y: &'a Y,
384 config: &'a AsyncConfig,
385 ) -> AsyncEnsembleFitStream<'a, Self::Model>
386 where
387 X: Send + Sync,
388 Y: Send + Sync,
389 Self::Model: 'a;
390
391 fn predict_ensemble_async<'a>(
393 models: &'a [Self::Model],
394 x: &'a X,
395 config: &'a AsyncConfig,
396 ) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>
397 where
398 X: Send + Sync;
399
400 fn predict_ensemble_stream<'a>(
402 models: &'a [Self::Model],
403 x: &'a X,
404 config: &'a AsyncConfig,
405 ) -> AsyncEnsemblePredictStream<'a, Output>
406 where
407 X: Send + Sync;
408}
409
410pub trait AsyncHyperparameterOptimization<X, Y, Config> {
412 type Score: FloatBounds + Send;
413 type Error: std::error::Error + Send + Sync;
414
415 fn optimize_async<'a>(
417 &'a self,
418 x: &'a X,
419 y: &'a Y,
420 param_space: ParameterSpace<Config>,
421 optimization_config: OptimizationConfig,
422 ) -> AsyncOptimizationStream<'a, Config, Self::Score>
423 where
424 X: Send + Sync,
425 Y: Send + Sync,
426 Config: Send + Sync;
427}
428
429pub struct ParameterSpace<Config> {
431 pub parameters: std::collections::HashMap<String, ParameterRange>,
433 pub dependencies: Vec<ParameterDependency>,
435 pub config_factory: ConfigFactory<Config>,
437}
438
/// Domain of a single hyperparameter.
#[derive(Debug, Clone)]
pub enum ParameterRange {
    /// Real values between `min` and `max` (inclusivity not specified here).
    Continuous { min: f64, max: f64 },
    /// A fixed set of candidate values.
    Discrete { values: Vec<f64> },
    /// Real values between `min` and `max`, presumably sampled on a log scale — TODO confirm.
    LogContinuous { min: f64, max: f64 },
    /// Integer values between `min` and `max`.
    Integer { min: i64, max: i64 },
}
451
/// Conditional link between two parameters: `dependent` applies when `condition`
/// holds for the value of `parent` (presumed semantics — confirm with the optimizer).
pub struct ParameterDependency {
    /// Name of the parameter that is conditionally active.
    pub dependent: String,
    /// Name of the parameter whose value is tested.
    pub parent: String,
    /// Predicate evaluated on an `f64` value.
    pub condition: Box<dyn Fn(f64) -> bool + Send + Sync>,
}
461
462#[derive(Debug, Clone)]
464pub struct OptimizationConfig {
465 pub max_evaluations: usize,
467 pub algorithm: OptimizationAlgorithm,
469 pub early_stopping: Option<EarlyStoppingConfig>,
471 pub parallel_config: AsyncConfig,
473}
474
/// Hyperparameter search strategy.
#[derive(Debug, Clone)]
pub enum OptimizationAlgorithm {
    /// Random search.
    Random,
    /// Bayesian optimization driven by `acquisition_function`, after
    /// `n_initial_points` initial evaluations.
    BayesianOptimization {
        acquisition_function: AcquisitionFunction,
        n_initial_points: usize,
    },
    /// Tree-structured Parzen Estimator.
    TPE {
        n_startup_trials: usize,
        n_ei_candidates: usize,
    },
    /// Hyperband; `eta` is presumably the successive-halving reduction factor — TODO confirm.
    Hyperband { max_resource: usize, eta: f64 },
}

/// Acquisition function used by [`OptimizationAlgorithm::BayesianOptimization`].
#[derive(Debug, Clone)]
pub enum AcquisitionFunction {
    /// Expected improvement.
    ExpectedImprovement,
    /// Upper confidence bound with exploration weight `kappa`.
    UpperConfidenceBound { kappa: f64 },
    /// Probability of improvement.
    ProbabilityOfImprovement,
}
504
/// Early-stopping rule for a search or training run.
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Number of evaluations to tolerate without sufficient improvement.
    pub patience: usize,
    /// Smallest change that counts as an improvement.
    pub min_improvement: f64,
    /// Whether a larger score is better (`true`) or a smaller one (`false`).
    pub maximize: bool,
}
515
/// Outcome of evaluating one configuration during hyperparameter search.
#[derive(Debug, Clone)]
pub struct OptimizationResult<Config, Score> {
    /// Trial number of this evaluation.
    pub trial: usize,
    /// Configuration that was evaluated.
    pub config: Config,
    /// Score obtained by `config`.
    pub score: Score,
    /// Time spent on this evaluation.
    pub evaluation_time: Duration,
    /// Additional named metrics recorded for this trial.
    pub metrics: std::collections::HashMap<String, f64>,
}
530
531pub trait AsyncModelPersistence {
533 type Error: std::error::Error + Send + Sync;
534
535 fn save_async<'a>(
537 &'a self,
538 path: &'a std::path::Path,
539 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
540
541 fn load_async<'a>(
543 path: &'a std::path::Path,
544 ) -> Pin<Box<dyn Future<Output = Result<Self>> + Send + 'a>>
545 where
546 Self: Sized;
547
548 fn save_compressed_async<'a>(
550 &'a self,
551 path: &'a std::path::Path,
552 compression_level: u32,
553 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
554}
555
/// Unit tests for the plain data types and helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_progress_info() {
        let progress = ProgressInfo::new(0.5, 50)
            .with_total_steps(100)
            .with_elapsed(Duration::from_secs(30))
            .with_eta(Duration::from_secs(30))
            .with_metric(0.85)
            .with_message("Training in progress");

        assert_eq!(progress.progress, 0.5);
        assert_eq!(progress.current_step, 50);
        assert_eq!(progress.total_steps, Some(100));
        assert_eq!(progress.elapsed, Duration::from_secs(30));
        assert_eq!(progress.eta, Some(Duration::from_secs(30)));
        assert_eq!(progress.current_metric, Some(0.85));
        assert_eq!(progress.message, "Training in progress");
    }

    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());

        // Clones share the same underlying flag.
        let observer = token.clone();
        token.cancel();
        assert!(token.is_cancelled());
        assert!(observer.is_cancelled());
    }

    #[test]
    fn test_async_config_default() {
        let config = AsyncConfig::default();
        assert_eq!(config.batch_size, 1000);
        assert_eq!(config.operation_timeout, Some(Duration::from_secs(300)));
        assert!(config.enable_progress);
        assert_eq!(config.progress_interval, Duration::from_secs(1));
        assert!(config.max_concurrency >= 1);
    }

    #[test]
    fn test_confidence_interval() {
        let ci = ConfidenceInterval {
            lower: 0.1,
            upper: 0.9,
            confidence_level: 0.95,
        };

        assert_eq!(ci.lower, 0.1);
        assert_eq!(ci.upper, 0.9);
        assert_eq!(ci.confidence_level, 0.95);
    }
}