// tsai/lib.rs

//! # tsai
//!
//! Time series deep learning in Rust — a feature-parity port of the Python `tsai` library.
//!
//! tsai-rs provides a comprehensive toolkit for time series analysis using deep learning:
//!
//! - **Data handling**: datasets, dataloaders, and preprocessing
//! - **Transforms**: augmentations, label mixing, and imaging transforms
//! - **Models**: CNN, Transformer, ROCKET, RNN, and tabular architectures
//! - **Training**: learner, callbacks, metrics, and schedulers
//! - **Analysis**: confusion matrix, top losses, feature importance
//! - **Explainability**: attribution maps, activation capture
//!
//! ## Quick Start
//!
//! ```rust,ignore
//! use tsai::prelude::*;
//!
//! // `device` and `n_classes` are assumed to be defined elsewhere.
//!
//! // Load data
//! let x = read_npy("data/X_train.npy")?;
//! let y = read_npy("data/y_train.npy")?;
//! let dataset = TSDataset::from_arrays(x, Some(y))?;
//!
//! // Create dataloaders
//! let (train_ds, valid_ds) = train_test_split(&dataset, 0.2, Seed::new(42))?;
//! let dls = TSDataLoaders::builder(train_ds, valid_ds)
//!     .batch_size(64)
//!     .build()?;
//!
//! // Create model
//! let config = InceptionTimePlusConfig::new(dls.n_vars(), dls.seq_len(), n_classes);
//! let model = config.init(&device);
//!
//! // Train
//! let learner = Learner::new(model, dls, LearnerConfig::default(), &device);
//! learner.fit_one_cycle(25, 1e-3)?;
//! ```
//!
//! ## Feature Flags
//!
//! - `backend-ndarray` (default): CPU backend using ndarray
//! - `backend-wgpu`: GPU backend using WGPU (Metal on macOS, Vulkan on Linux/Windows)
//! - `backend-tch`: PyTorch backend via tch-rs
//! - `wandb`: Weights & Biases integration

46#![deny(unsafe_code)]
47#![warn(missing_docs)]
48#![warn(clippy::all)]
49
50// Re-export all crates
51pub use tsai_analysis as analysis;
52pub use tsai_core as core;
53pub use tsai_data as data;
54pub use tsai_explain as explain;
55pub use tsai_models as models;
56pub use tsai_train as train;
57pub use tsai_transforms as transforms;
58
59/// Prelude module for convenient imports.
60///
61/// ```rust,ignore
62/// use tsai::prelude::*;
63/// ```
64pub mod prelude {
65    // Core types
66    pub use tsai_core::{Result, Seed, Split, TSBatch, TSShape, TSTensor, Transform};
67
68    // Data
69    pub use tsai_data::{
70        read_csv, read_npy, read_npz, read_parquet, train_test_split, train_valid_test_split,
71        TSDataLoader, TSDataLoaders, TSDataset, TSDatasets,
72    };
73
74    // Transforms
75    pub use tsai_transforms::{
76        Compose, CutMix1d, CutOut, GaussianNoise, Identity, MagScale, MixUp1d, TimeWarp,
77    };
78
79    // Models
80    pub use tsai_models::{
81        InceptionTimePlus, InceptionTimePlusConfig, MiniRocket, MiniRocketConfig, PatchTST,
82        PatchTSTConfig, RNNPlus, RNNPlusConfig, ResNetPlus, ResNetPlusConfig, TSTPlus, TSTConfig,
83    };
84
85    // Training
86    pub use tsai_train::{
87        Accuracy, Callback, CrossEntropyLoss, Learner, LearnerConfig, MSE, MSELoss, Metric,
88        OneCycleLR, Scheduler,
89    };
90
91    // Analysis
92    pub use tsai_analysis::{confusion_matrix, top_losses, ConfusionMatrix};
93
94    // Explain
95    pub use tsai_explain::{AttributionMap, AttributionMethod};
96}
97
98/// All module for importing everything.
99///
100/// Mirrors the `from tsai.all import *` pattern from Python.
101pub mod all {
102    pub use super::prelude::*;
103
104    // Additional exports
105    pub use tsai_core::backend;
106    pub use tsai_data::{RandomSampler, SequentialSampler, StratifiedSampler};
107    pub use tsai_train::{
108        CallbackContext, CallbackList, EarlyStoppingCallback, OneCycleLR, ProgressCallback,
109    };
110    pub use tsai_transforms::{GAFType, RecurrencePlotConfig, TSToGADF, TSToGASF, TSToRP};
111}
112
113/// Compatibility module for sklearn-like API.
114pub mod compat {
115    pub use tsai_train::compat::{
116        TSClassifier, TSClassifierConfig, TSForecaster, TSForecasterConfig, TSRegressor,
117        TSRegressorConfig,
118    };
119}