§tsai_train
Training loop, callbacks, metrics, and checkpointing for tsai-rs.
This crate provides:
- Learner for managing the training process
- Callback system with lifecycle hooks
- Learning rate schedulers (one-cycle, etc.)
- Metrics (accuracy, MSE, MAE)
- Checkpointing and model saving
- Compatibility facades (TSClassifier, TSRegressor, TSForecaster)
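
To make the one-cycle scheduler in the list above concrete, here is a minimal standalone sketch of the idea. It is not the crate's `OneCycleLR` implementation: the function name `one_cycle_lr`, its parameters, and the divisor constants are all assumptions chosen for illustration. The learning rate warms up linearly to `max_lr`, then follows a cosine curve down to a much smaller final value.

```rust
use std::f64::consts::PI;

/// Learning rate at `step` (0-based) of a simplified one-cycle schedule.
///
/// Hypothetical helper, not part of tsai_train. The rate rises linearly from
/// `max_lr / 25` to `max_lr` over the first `pct_start` fraction of training,
/// then anneals along a cosine curve to `max_lr / 25 / 1e4`.
pub fn one_cycle_lr(step: usize, total_steps: usize, max_lr: f64, pct_start: f64) -> f64 {
    let initial_lr = max_lr / 25.0; // illustrative starting divisor
    let final_lr = initial_lr / 1e4; // illustrative final divisor
    if total_steps < 2 {
        return max_lr;
    }
    let last = (total_steps - 1) as f64;
    // Step index (possibly fractional) at which the rate peaks.
    let peak = pct_start.clamp(0.0, 1.0) * last;
    let s = step.min(total_steps - 1) as f64;
    if s < peak {
        // Warm-up phase: linear interpolation towards max_lr.
        initial_lr + (max_lr - initial_lr) * (s / peak)
    } else {
        // Annealing phase: cosine decay from max_lr to final_lr.
        let remaining = last - peak;
        let t = if remaining > 0.0 { (s - peak) / remaining } else { 0.0 };
        final_lr + 0.5 * (max_lr - final_lr) * (1.0 + (PI * t).cos())
    }
}
```

Calling `one_cycle_lr(step, 101, 1e-3, 0.25)` for `step` in `0..101` yields a rate that starts at `max_lr / 25`, peaks at step 25, and decays towards zero by step 100.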
§Example
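use tsai_train::{CrossEntropyLoss, Learner, LearnerConfig};
use tsai_data::TSDataLoaders;
use tsai_models::InceptionTimePlus;

// `config`, `device`, and `dls` are assumed to be defined earlier;
// `Adam` must also be in scope (its import is not shown in this example).
let model = InceptionTimePlus::new(config, &device);
let learner = Learner::new(model, dls)
    .with_optimizer(Adam::new(1e-3))
    .with_loss(CrossEntropyLoss::new());
// Train with the one-cycle policy.
learner.fit_one_cycle(10, 1e-3)?;
Re-exports§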
pub use callback::BatchSubsamplerCallback;
pub use callback::Callback;
pub use callback::CallbackContext;
pub use callback::CallbackList;
pub use callback::CheckpointMetadata;
pub use callback::EarlyStoppingCallback;
pub use callback::GradientClipCallback;
pub use callback::GradientClipMode;
pub use callback::HistoryCallback;
pub use callback::MixedPrecisionCallback;
pub use callback::NoiseInjection;
pub use callback::NoisyStudentCallback;
pub use callback::NoisyStudentStats;
pub use callback::PredictionDynamicsCallback;
pub use callback::PredictionDynamicsSummary;
pub use callback::PredictionTrackingMode;
pub use callback::ProgressCallback;
pub use callback::PseudoLabel;
pub use callback::PseudoLabelFilter;
pub use callback::SamplePredictionHistory;
pub use callback::SaveModelCallback;
pub use callback::SaveModelMode;
pub use callback::ShowGraphCallback;
pub use callback::SubsampleStrategy;
pub use callback::TerminateOnNanCallback;
pub use callback::TransformSchedule;
pub use callback::TransformSchedulerCallback;
pub use callback::WeightedPerSampleLossCallback;
pub use callback::WeightStrategy;
pub use error::Result;
pub use error::TrainError;
pub use learner::Learner;
pub use learner::LearnerConfig;
pub use learner::TrainingState;
pub use losses::CrossEntropyLoss;
pub use losses::FocalLoss;
pub use losses::HuberLoss;
pub use losses::LabelSmoothingLoss;
pub use losses::LogCoshLoss;
pub use losses::MSELoss;
pub use optimizer::RAdam;
pub use optimizer::RAdamConfig;
pub use optimizer::Ranger;
pub use optimizer::RangerConfig;
pub use metrics::Accuracy;
pub use metrics::F1Score;
pub use metrics::Metric;
pub use metrics::Precision;
pub use metrics::Recall;
pub use metrics::AUC;
pub use metrics::MCC;
pub use metrics::MAE;
pub use metrics::MAPE;
pub use metrics::MSE;
pub use metrics::RMSE;
pub use scheduler::ConstantLR;
pub use scheduler::CosineAnnealingLR;
pub use scheduler::CosineAnnealingWarmRestarts;
pub use scheduler::ExponentialLR;
pub use scheduler::LinearWarmup;
pub use scheduler::OneCycleLR;
pub use scheduler::PolynomialLR;
pub use scheduler::ReduceLROnPlateau;
pub use scheduler::ReduceMode;
pub use scheduler::Scheduler;
pub use scheduler::StepLR;
pub use training::train_classification;
pub use training::train_regression;
pub use training::ClassificationTrainer;
pub use training::ClassificationTrainerConfig;
pub use training::TrainingOutput;
pub use training::RegressionTrainer;
pub use training::RegressionTrainerConfig;
pub use training::RegressionOutput;
pub use evaluation::evaluate_classification;
pub use evaluation::ConfusionMatrix;
pub use evaluation::EvaluationResult;
pub use export::quick_load;
pub use export::quick_save;
pub use export::save_model_bundle;
pub use export::ExportMetadata;
pub use export::LearnerExport;
pub use hpo::GridSearch;
pub use hpo::HpoError;
pub use hpo::HyperparameterSpace;
pub use hpo::ParamSet;
pub use hpo::ParamValue;
pub use hpo::RandomSearch;
pub use hpo::SearchResult;
pub use hpo::SuccessiveHalving;
pub use hpo::TrialResult;
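
The re-exported metrics include `Accuracy`, `MSE`, and `MAE`. As a reminder of what those quantities measure, the following sketch computes them over plain slices. The free functions `accuracy`, `mse`, and `mae` are hypothetical stand-ins and do not reflect the interface of the crate's `Metric` trait, which is not shown on this page.

```rust
/// Fraction of predictions equal to their target label.
/// Hypothetical helper for illustration; returns 0.0 for empty input.
pub fn accuracy(preds: &[usize], targets: &[usize]) -> f64 {
    assert_eq!(preds.len(), targets.len(), "length mismatch");
    if preds.is_empty() {
        return 0.0;
    }
    let correct = preds.iter().zip(targets).filter(|(p, t)| p == t).count();
    correct as f64 / preds.len() as f64
}

/// Mean squared error: average of (prediction - target)^2.
pub fn mse(preds: &[f64], targets: &[f64]) -> f64 {
    assert_eq!(preds.len(), targets.len(), "length mismatch");
    if preds.is_empty() {
        return 0.0;
    }
    let sum: f64 = preds.iter().zip(targets).map(|(p, t)| (p - t).powi(2)).sum();
    sum / preds.len() as f64
}

/// Mean absolute error: average of |prediction - target|.
pub fn mae(preds: &[f64], targets: &[f64]) -> f64 {
    assert_eq!(preds.len(), targets.len(), "length mismatch");
    if preds.is_empty() {
        return 0.0;
    }
    let sum: f64 = preds.iter().zip(targets).map(|(p, t)| (p - t).abs()).sum();
    sum / preds.len() as f64
}
```

Because MSE squares each residual, a single large error dominates it far more than it dominates MAE, which is one reason to track both for regression.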
Modules§
- callback
- Callback system for training hooks.
- compat
- Compatibility facades for sklearn-like API.
- error
- Error types for training.
- evaluation
- Model evaluation utilities.
- export
- Learner export and import utilities.
- hpo
- Hyperparameter optimization utilities.
- learner
- Learner for managing training.
- losses
- Loss functions.
- metrics
- Training metrics.
- optimizer
- Custom optimizers for time series deep learning.
- scheduler
- Learning rate schedulers.
- training
- Training loop implementation.
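
The callback module is described as a system of training hooks. The sketch below shows the general pattern under stated assumptions: the trait `TrainingHook`, the struct `EarlyStopping`, and the driver `run_epochs` are hypothetical names invented for this example and are not the signatures of the crate's `Callback` or `EarlyStoppingCallback`. A hook observes each epoch and can ask the loop to stop.

```rust
/// Hypothetical lifecycle-hook trait; not tsai_train's `Callback`.
pub trait TrainingHook {
    /// Called before each epoch. Default: do nothing.
    fn on_epoch_start(&mut self, _epoch: usize) {}
    /// Called after each epoch; return `false` to request that training stop.
    fn on_epoch_end(&mut self, _epoch: usize, _valid_loss: f64) -> bool {
        true
    }
}

/// Stops training after `patience` consecutive epochs without improvement.
pub struct EarlyStopping {
    patience: usize,
    min_delta: f64,
    best: f64,
    bad_epochs: usize,
}

impl EarlyStopping {
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self { patience, min_delta, best: f64::INFINITY, bad_epochs: 0 }
    }
}

impl TrainingHook for EarlyStopping {
    fn on_epoch_end(&mut self, _epoch: usize, valid_loss: f64) -> bool {
        if valid_loss < self.best - self.min_delta {
            // Improvement: remember it and reset the counter.
            self.best = valid_loss;
            self.bad_epochs = 0;
        } else {
            self.bad_epochs += 1;
        }
        self.bad_epochs < self.patience
    }
}

/// Toy driver: replays precomputed validation losses instead of training,
/// invoking every hook around each epoch. Returns the number of epochs run.
pub fn run_epochs(losses: &[f64], hooks: &mut [Box<dyn TrainingHook>]) -> usize {
    for (epoch, &loss) in losses.iter().enumerate() {
        for hook in hooks.iter_mut() {
            hook.on_epoch_start(epoch);
        }
        let mut keep_going = true;
        for hook in hooks.iter_mut() {
            // Every hook still runs even if an earlier one asked to stop.
            keep_going &= hook.on_epoch_end(epoch, loss);
        }
        if !keep_going {
            return epoch + 1;
        }
    }
    losses.len()
}
```

Default method bodies let a hook override only the events it cares about, which keeps small callbacks such as a NaN check or a progress printer short.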
Traits§
- TSClassificationModel
- Trait for time series classification models.
- TSForecastingModel
- Trait for time series forecasting models.
- TSRegressionModel
- Trait for time series regression models.
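
This page does not show the methods of these traits. Purely as an illustration of how a model trait can back a sklearn-like facade, the sketch below defines a hypothetical `ClassifierSketch` trait with one required method and a provided `predict` built on top of it. The trait, its method names, the `[channel][time step]` input layout, and the toy `MeanSign` model are all assumptions; the real `TSClassificationModel` may look quite different.

```rust
/// Hypothetical trait for illustration; not tsai_train's `TSClassificationModel`.
pub trait ClassifierSketch {
    /// Class probabilities for one series laid out as [channel][time step].
    fn predict_proba(&self, series: &[Vec<f64>]) -> Vec<f64>;

    /// Index of the most probable class (argmax over `predict_proba`).
    /// Returns 0 if `predict_proba` yields no classes.
    fn predict(&self, series: &[Vec<f64>]) -> usize {
        self.predict_proba(series)
            .iter()
            .enumerate()
            .fold((0, f64::NEG_INFINITY), |best, (i, &p)| {
                if p > best.1 { (i, p) } else { best }
            })
            .0
    }
}

/// Toy two-class model: class 1 if the mean of the first channel is positive.
pub struct MeanSign;

impl ClassifierSketch for MeanSign {
    fn predict_proba(&self, series: &[Vec<f64>]) -> Vec<f64> {
        // Treat a missing or empty first channel as mean 0.0.
        let mean = match series.first() {
            Some(ch) if !ch.is_empty() => ch.iter().sum::<f64>() / ch.len() as f64,
            _ => 0.0,
        };
        if mean > 0.0 { vec![0.0, 1.0] } else { vec![1.0, 0.0] }
    }
}
```

With this shape, a facade only needs to call `predict` on anything implementing the trait, so swapping one architecture for another does not change the calling code.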