// sklears_core/async_traits.rs

//! Async trait definitions for non-blocking machine-learning operations.
//!
//! This module provides async support for machine learning operations,
//! including streaming data processing, batch operations, and progress tracking.

use std::fmt;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;

use futures_core::{Future, Stream};

use crate::error::Result;
use crate::types::FloatBounds;

11/// Type alias for async partial fit future
12pub type AsyncPartialFitFuture<'a, T> =
13    Pin<Box<dyn Future<Output = Result<Option<T>>> + Send + 'a>>;
14
15/// Type alias for async predict with confidence future
16pub type AsyncPredictConfidenceFuture<'a, Output> =
17    Pin<Box<dyn Future<Output = Result<(Output, ConfidenceInterval)>> + Send + 'a>>;
18
19/// Type alias for async score future
20pub type AsyncScoreFuture<'a, Score> =
21    Pin<Box<dyn Future<Output = Result<Vec<Score>>> + Send + 'a>>;
22
23/// Type alias for async fit future
24pub type AsyncFitFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T>> + Send + 'a>>;
25
26/// Type alias for async transform future
27pub type AsyncTransformFuture<'a, Output> =
28    Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
29
30/// Type alias for async cross-validation stream
31pub type AsyncCVStream<'a, Score> = Pin<Box<dyn Stream<Item = Result<(usize, Score)>> + Send + 'a>>;
32
33/// Type alias for async ensemble fit stream
34pub type AsyncEnsembleFitStream<'a, Model> =
35    Pin<Box<dyn Stream<Item = Result<(usize, Model)>> + Send + 'a>>;
36
37/// Type alias for async ensemble predict stream
38pub type AsyncEnsemblePredictStream<'a, Output> =
39    Pin<Box<dyn Stream<Item = Result<(usize, Output)>> + Send + 'a>>;
40
41/// Type alias for async optimization stream
42pub type AsyncOptimizationStream<'a, Config, Score> =
43    Pin<Box<dyn Stream<Item = Result<OptimizationResult<Config, Score>>> + Send + 'a>>;
44
45/// Type alias for config factory function
46pub type ConfigFactory<Config> =
47    Box<dyn Fn(&std::collections::HashMap<String, f64>) -> Config + Send + Sync>;
48
49/// Type alias for async unit future
50pub type AsyncUnitFuture<'a> = Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
51
/// Snapshot of progress for a long-running operation.
///
/// Create one with [`ProgressInfo::new`] and refine it with the `with_*` builder methods.
/// `progress` is always kept inside `0.0..=1.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct ProgressInfo {
    /// Fraction of the work completed, always within `0.0..=1.0`.
    pub progress: f64,
    /// Current step/iteration.
    pub current_step: usize,
    /// Total number of steps, if known.
    pub total_steps: Option<usize>,
    /// Time elapsed so far.
    pub elapsed: Duration,
    /// Estimated time remaining, if known.
    pub eta: Option<Duration>,
    /// Current metric value (e.g. loss or accuracy), if any.
    pub current_metric: Option<f64>,
    /// Free-form status message (empty by default).
    pub message: String,
}

impl ProgressInfo {
    /// Create a progress snapshot at `current_step`.
    ///
    /// `progress` is clamped to `0.0..=1.0`. NaN is mapped to `0.0` so the documented
    /// range always holds (`f64::clamp` alone would let NaN through).
    #[must_use]
    pub fn new(progress: f64, current_step: usize) -> Self {
        Self {
            progress: Self::normalize(progress),
            current_step,
            total_steps: None,
            elapsed: Duration::from_secs(0),
            eta: None,
            current_metric: None,
            message: String::new(),
        }
    }

    /// Record the total number of steps and derive `progress` from it.
    ///
    /// When `total > 0`, `progress` becomes `current_step / total`, clamped to
    /// `0.0..=1.0`. A step count past the total therefore reports completion instead of
    /// exceeding 1.0. A `total` of zero records the total but leaves `progress` unchanged.
    #[must_use]
    pub fn with_total_steps(mut self, total: usize) -> Self {
        self.total_steps = Some(total);
        if total > 0 {
            self.progress = Self::normalize(self.current_step as f64 / total as f64);
        }
        self
    }

    /// Set the elapsed time.
    #[must_use]
    pub fn with_elapsed(mut self, elapsed: Duration) -> Self {
        self.elapsed = elapsed;
        self
    }

    /// Set the estimated time remaining.
    #[must_use]
    pub fn with_eta(mut self, eta: Duration) -> Self {
        self.eta = Some(eta);
        self
    }

    /// Set the current metric value.
    #[must_use]
    pub fn with_metric(mut self, metric: f64) -> Self {
        self.current_metric = Some(metric);
        self
    }

    /// Set the status message.
    #[must_use]
    pub fn with_message<S: Into<String>>(mut self, message: S) -> Self {
        self.message = message.into();
        self
    }

    /// Clamp a raw fraction into `0.0..=1.0`, treating NaN as `0.0`.
    fn normalize(progress: f64) -> f64 {
        if progress.is_nan() {
            0.0
        } else {
            progress.clamp(0.0, 1.0)
        }
    }
}

119/// Configuration for async operations
120#[derive(Debug, Clone)]
121pub struct AsyncConfig {
122    /// Batch size for streaming operations
123    pub batch_size: usize,
124    /// Timeout for individual operations
125    pub operation_timeout: Option<Duration>,
126    /// Whether to report progress
127    pub enable_progress: bool,
128    /// Progress reporting interval
129    pub progress_interval: Duration,
130    /// Maximum concurrent operations
131    pub max_concurrency: usize,
132}
133
134impl Default for AsyncConfig {
135    fn default() -> Self {
136        Self {
137            batch_size: 1000,
138            operation_timeout: Some(Duration::from_secs(300)), // 5 minutes
139            enable_progress: true,
140            progress_interval: Duration::from_secs(1),
141            max_concurrency: num_cpus::get(),
142        }
143    }
144}
145
146/// Enhanced async fit trait with progress tracking and cancellation
147pub trait AsyncFitAdvanced<X, Y, State = crate::traits::Untrained> {
148    /// The fitted model type
149    type Fitted;
150
151    /// Error type
152    type Error: std::error::Error + Send + Sync;
153
154    /// Fit the model asynchronously with progress tracking
155    fn fit_async_with_progress<'a>(
156        self,
157        x: &'a X,
158        y: &'a Y,
159        config: &'a AsyncConfig,
160    ) -> AsyncFitFuture<'a, Self::Fitted>
161    where
162        Self: 'a;
163
164    /// Fit the model with progress reporting via a stream
165    fn fit_async_with_progress_stream<'a>(
166        self,
167        x: &'a X,
168        y: &'a Y,
169        config: &'a AsyncConfig,
170    ) -> Pin<Box<dyn Stream<Item = Result<ProgressInfo>> + Send + 'a>>
171    where
172        Self: 'a;
173
174    /// Fit the model with cancellation support
175    fn fit_async_cancellable<'a>(
176        self,
177        x: &'a X,
178        y: &'a Y,
179        cancel_token: CancellationToken,
180    ) -> AsyncPartialFitFuture<'a, Self::Fitted>
181    where
182        Self: 'a;
183}
184
185/// Enhanced async predict trait with batch processing
186pub trait AsyncPredictAdvanced<X, Output> {
187    /// Error type
188    type Error: std::error::Error + Send + Sync;
189
190    /// Make predictions asynchronously with batching
191    fn predict_async_batched<'a>(
192        &'a self,
193        x: &'a X,
194        config: &'a AsyncConfig,
195    ) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
196
197    /// Stream predictions for large datasets
198    fn predict_stream<'a>(
199        &'a self,
200        x_stream: Pin<Box<dyn Stream<Item = X> + Send + 'a>>,
201        config: &'a AsyncConfig,
202    ) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>;
203
204    /// Predict with confidence intervals (if supported)
205    fn predict_async_with_uncertainty<'a>(
206        &'a self,
207        x: &'a X,
208        confidence_level: f64,
209    ) -> AsyncPredictConfidenceFuture<'a, Output>
210    where
211        Self: 'a;
212}
213
214/// Enhanced async transform trait with streaming support
215pub trait AsyncTransformAdvanced<X, Output = X> {
216    /// Error type
217    type Error: std::error::Error + Send + Sync;
218
219    /// Transform data asynchronously with progress tracking
220    fn transform_async_with_progress<'a>(
221        &'a self,
222        x: &'a X,
223        config: &'a AsyncConfig,
224    ) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
225
226    /// Stream data transformation
227    fn transform_stream<'a>(
228        &'a self,
229        x_stream: Pin<Box<dyn Stream<Item = X> + Send + 'a>>,
230        config: &'a AsyncConfig,
231    ) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>;
232
233    /// Transform with memory-efficient chunking
234    fn transform_async_chunked<'a>(
235        &'a self,
236        x: &'a X,
237        chunk_size: usize,
238    ) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>;
239}
240
241/// Async partial fit trait for online learning
242pub trait AsyncPartialFit<X, Y> {
243    /// Error type
244    type Error: std::error::Error + Send + Sync;
245
246    /// Perform partial fit asynchronously
247    fn partial_fit_async<'a>(
248        &'a mut self,
249        x: &'a X,
250        y: &'a Y,
251    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
252
253    /// Stream partial fit for continuous learning
254    fn partial_fit_stream<'a>(
255        &'a mut self,
256        data_stream: Pin<Box<dyn Stream<Item = (X, Y)> + Send + 'a>>,
257        config: &'a AsyncConfig,
258    ) -> Pin<Box<dyn Stream<Item = Result<ProgressInfo>> + Send + 'a>>;
259
260    /// Adaptive learning with dynamic batch sizing
261    fn adaptive_partial_fit<'a>(
262        &'a mut self,
263        data_stream: Pin<Box<dyn Stream<Item = (X, Y)> + Send + 'a>>,
264        adaptation_config: AdaptationConfig,
265    ) -> Pin<Box<dyn Stream<Item = Result<AdaptationInfo>> + Send + 'a>>;
266}
267
268/// Configuration for adaptive learning
269#[derive(Debug, Clone)]
270pub struct AdaptationConfig {
271    /// Initial batch size
272    pub initial_batch_size: usize,
273    /// Minimum batch size
274    pub min_batch_size: usize,
275    /// Maximum batch size
276    pub max_batch_size: usize,
277    /// Learning rate adaptation factor
278    pub adaptation_rate: f64,
279    /// Performance threshold for batch size increase
280    pub performance_threshold: f64,
281    /// Memory usage threshold (bytes)
282    pub memory_threshold: usize,
283}
284
285/// Information about adaptive learning progress
286#[derive(Debug, Clone)]
287pub struct AdaptationInfo {
288    /// Current batch size
289    pub current_batch_size: usize,
290    /// Current learning rate
291    pub current_learning_rate: f64,
292    /// Performance metric
293    pub performance_metric: f64,
294    /// Memory usage
295    pub memory_usage: usize,
296    /// General progress information
297    pub progress: ProgressInfo,
298}
299
/// Interval `[lower, upper]` attached to a prediction at a given confidence level.
#[derive(Clone, Debug)]
pub struct ConfidenceInterval {
    /// Lower end of the interval.
    pub lower: f64,
    /// Upper end of the interval.
    pub upper: f64,
    /// Confidence level, expressed as a fraction between 0.0 and 1.0.
    pub confidence_level: f64,
}

/// Shared, cloneable flag used to request cooperative cancellation of async work.
///
/// Clones share the same underlying flag, so cancelling through any clone is observed
/// by all of them. Cancellation is one-way: this type offers no way to reset a token.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    inner: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Create a token in the not-cancelled state.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Request cancellation. Calling it more than once has no further effect.
    pub fn cancel(&self) {
        // Release pairs with the Acquire load in `is_cancelled`. A task that sees the
        // token as cancelled also sees every write the canceller made before `cancel`.
        // `Relaxed` gave no such guarantee.
        self.inner.store(true, Ordering::Release);
    }

    /// Return `true` once cancellation has been requested on this token or any clone.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.inner.load(Ordering::Acquire)
    }
}

impl Default for CancellationToken {
    /// Equivalent to [`CancellationToken::new`]: a token that is not cancelled.
    fn default() -> Self {
        Self::new()
    }
}

342/// Async cross-validation with parallel fold execution
343pub trait AsyncCrossValidation<X, Y> {
344    type Score: FloatBounds + Send;
345    type Model: Clone + Send + Sync;
346
347    /// Perform k-fold cross-validation asynchronously
348    fn cross_validate_async<'a>(
349        &'a self,
350        model: Self::Model,
351        x: &'a X,
352        y: &'a Y,
353        cv_folds: usize,
354        config: &'a AsyncConfig,
355    ) -> AsyncScoreFuture<'a, Self::Score>
356    where
357        X: Clone + Send + Sync,
358        Y: Clone + Send + Sync;
359
360    /// Stream cross-validation results as they complete
361    fn cross_validate_stream<'a>(
362        &'a self,
363        model: Self::Model,
364        x: &'a X,
365        y: &'a Y,
366        cv_folds: usize,
367        config: &'a AsyncConfig,
368    ) -> AsyncCVStream<'a, Self::Score>
369    where
370        X: Clone + Send + Sync,
371        Y: Clone + Send + Sync;
372}
373
374/// Async ensemble methods
375pub trait AsyncEnsemble<X, Y, Output> {
376    type Model: Send + Sync;
377    type Error: std::error::Error + Send + Sync;
378
379    /// Train ensemble members asynchronously
380    fn fit_ensemble_async<'a>(
381        models: Vec<Self::Model>,
382        x: &'a X,
383        y: &'a Y,
384        config: &'a AsyncConfig,
385    ) -> AsyncEnsembleFitStream<'a, Self::Model>
386    where
387        X: Send + Sync,
388        Y: Send + Sync,
389        Self::Model: 'a;
390
391    /// Make ensemble predictions asynchronously
392    fn predict_ensemble_async<'a>(
393        models: &'a [Self::Model],
394        x: &'a X,
395        config: &'a AsyncConfig,
396    ) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>
397    where
398        X: Send + Sync;
399
400    /// Stream ensemble predictions with individual model results
401    fn predict_ensemble_stream<'a>(
402        models: &'a [Self::Model],
403        x: &'a X,
404        config: &'a AsyncConfig,
405    ) -> AsyncEnsemblePredictStream<'a, Output>
406    where
407        X: Send + Sync;
408}
409
410/// Async hyperparameter optimization
411pub trait AsyncHyperparameterOptimization<X, Y, Config> {
412    type Score: FloatBounds + Send;
413    type Error: std::error::Error + Send + Sync;
414
415    /// Optimize hyperparameters asynchronously
416    fn optimize_async<'a>(
417        &'a self,
418        x: &'a X,
419        y: &'a Y,
420        param_space: ParameterSpace<Config>,
421        optimization_config: OptimizationConfig,
422    ) -> AsyncOptimizationStream<'a, Config, Self::Score>
423    where
424        X: Send + Sync,
425        Y: Send + Sync,
426        Config: Send + Sync;
427}
428
429/// Parameter space definition for optimization
430pub struct ParameterSpace<Config> {
431    /// Parameter ranges and distributions
432    pub parameters: std::collections::HashMap<String, ParameterRange>,
433    /// Parameter dependencies
434    pub dependencies: Vec<ParameterDependency>,
435    /// Configuration factory function
436    pub config_factory: ConfigFactory<Config>,
437}
438
439/// Parameter range definition
440#[derive(Debug, Clone)]
441pub enum ParameterRange {
442    /// Continuous range [min, max]
443    Continuous { min: f64, max: f64 },
444    /// Discrete choices
445    Discrete { values: Vec<f64> },
446    /// Log-scale continuous range
447    LogContinuous { min: f64, max: f64 },
448    /// Integer range [min, max]
449    Integer { min: i64, max: i64 },
450}
451
452/// Parameter dependency definition
453pub struct ParameterDependency {
454    /// Dependent parameter name
455    pub dependent: String,
456    /// Parent parameter name
457    pub parent: String,
458    /// Condition for dependency
459    pub condition: Box<dyn Fn(f64) -> bool + Send + Sync>,
460}
461
462/// Optimization configuration
463#[derive(Debug, Clone)]
464pub struct OptimizationConfig {
465    /// Maximum number of evaluations
466    pub max_evaluations: usize,
467    /// Optimization algorithm
468    pub algorithm: OptimizationAlgorithm,
469    /// Early stopping configuration
470    pub early_stopping: Option<EarlyStoppingConfig>,
471    /// Parallel evaluation configuration
472    pub parallel_config: AsyncConfig,
473}
474
475/// Optimization algorithm selection
476#[derive(Debug, Clone)]
477pub enum OptimizationAlgorithm {
478    /// Random search
479    Random,
480    /// Bayesian optimization with Gaussian processes
481    BayesianOptimization {
482        acquisition_function: AcquisitionFunction,
483        n_initial_points: usize,
484    },
485    /// Tree-structured Parzen estimators
486    TPE {
487        n_startup_trials: usize,
488        n_ei_candidates: usize,
489    },
490    /// Hyperband algorithm
491    Hyperband { max_resource: usize, eta: f64 },
492}
493
494/// Acquisition function for Bayesian optimization
495#[derive(Debug, Clone)]
496pub enum AcquisitionFunction {
497    /// Expected improvement
498    ExpectedImprovement,
499    /// Upper confidence bound
500    UpperConfidenceBound { kappa: f64 },
501    /// Probability of improvement
502    ProbabilityOfImprovement,
503}
504
505/// Early stopping configuration
506#[derive(Debug, Clone)]
507pub struct EarlyStoppingConfig {
508    /// Patience (iterations without improvement)
509    pub patience: usize,
510    /// Minimum improvement threshold
511    pub min_improvement: f64,
512    /// Direction of optimization (maximize or minimize)
513    pub maximize: bool,
514}
515
/// Outcome of evaluating one configuration during a hyperparameter search.
#[derive(Clone, Debug)]
pub struct OptimizationResult<Config, Score> {
    /// Number of the trial that produced this result.
    pub trial: usize,
    /// Configuration that was evaluated.
    pub config: Config,
    /// Score the configuration achieved.
    pub score: Score,
    /// Time spent evaluating the configuration.
    pub evaluation_time: Duration,
    /// Additional named metrics recorded alongside `score`.
    pub metrics: std::collections::HashMap<String, f64>,
}

531/// Async model persistence
532pub trait AsyncModelPersistence {
533    type Error: std::error::Error + Send + Sync;
534
535    /// Save model asynchronously
536    fn save_async<'a>(
537        &'a self,
538        path: &'a std::path::Path,
539    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
540
541    /// Load model asynchronously
542    fn load_async<'a>(
543        path: &'a std::path::Path,
544    ) -> Pin<Box<dyn Future<Output = Result<Self>> + Send + 'a>>
545    where
546        Self: Sized;
547
548    /// Save model with compression
549    fn save_compressed_async<'a>(
550        &'a self,
551        path: &'a std::path::Path,
552        compression_level: u32,
553    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
554}
555
// Unit tests for the concrete helper types in this module. The traits themselves have
// no default implementations to exercise here.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_progress_info() {
        let progress = ProgressInfo::new(0.5, 50)
            .with_total_steps(100)
            .with_elapsed(Duration::from_secs(30))
            .with_eta(Duration::from_secs(30))
            .with_metric(0.85)
            .with_message("Training in progress");

        assert_eq!(progress.progress, 0.5);
        assert_eq!(progress.current_step, 50);
        assert_eq!(progress.total_steps, Some(100));
        assert_eq!(progress.elapsed, Duration::from_secs(30));
        assert_eq!(progress.eta, Some(Duration::from_secs(30)));
        assert_eq!(progress.current_metric, Some(0.85));
        assert_eq!(progress.message, "Training in progress");
    }

    #[test]
    fn test_progress_info_new_clamps_out_of_range_values() {
        assert_eq!(ProgressInfo::new(1.5, 0).progress, 1.0);
        assert_eq!(ProgressInfo::new(-0.5, 0).progress, 0.0);
    }

    #[test]
    fn test_progress_info_zero_total_keeps_progress() {
        // A zero total must not trigger a division; progress stays as constructed.
        let progress = ProgressInfo::new(0.25, 3).with_total_steps(0);
        assert_eq!(progress.progress, 0.25);
        assert_eq!(progress.total_steps, Some(0));
    }

    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());

        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_clones_share_state() {
        let token = CancellationToken::default();
        let clone = token.clone();
        assert!(!clone.is_cancelled());

        token.cancel();
        assert!(clone.is_cancelled());
    }

    #[test]
    fn test_async_config_default() {
        let config = AsyncConfig::default();
        assert_eq!(config.batch_size, 1000);
        assert_eq!(config.operation_timeout, Some(Duration::from_secs(300)));
        assert!(config.enable_progress);
        assert_eq!(config.progress_interval, Duration::from_secs(1));
        // `num_cpus::get()` always reports at least one CPU.
        assert!(config.max_concurrency >= 1);
    }

    #[test]
    fn test_confidence_interval() {
        let ci = ConfidenceInterval {
            lower: 0.1,
            upper: 0.9,
            confidence_level: 0.95,
        };

        assert_eq!(ci.lower, 0.1);
        assert_eq!(ci.upper, 0.9);
        assert_eq!(ci.confidence_level, 0.95);
    }
}