Crate tsai_train

tsai_train

Training loop, callbacks, metrics, and checkpointing for tsai-rs.

This crate provides:

  • Learner for managing the training process
  • Callback system with lifecycle hooks
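  • Learning rate schedulers (one-cycle, cosine annealing, step, exponential, polynomial, reduce-on-plateau)
  • Metrics (accuracy, precision, recall, F1, AUC, MCC, MSE, RMSE, MAE, MAPE)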
  • Checkpointing and model saving
  • Compatibility facades (TSClassifier, TSRegressor, TSForecaster)
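
As an illustration of what a callback system with lifecycle hooks involves, here is a minimal, self-contained sketch. It does not use the tsai_train API: the `Hook` trait, the `EarlyStop` struct, and the `run_epochs` function are hypothetical names invented for this example, and the crate's actual `Callback` trait and `EarlyStoppingCallback` may expose different methods and signatures.

```rust
/// Hypothetical lifecycle hook trait (NOT the tsai_train `Callback` trait).
trait Hook {
    /// Called after each epoch; return `false` to request that training stop.
    fn after_epoch(&mut self, epoch: usize, valid_loss: f32) -> bool;
}

/// Hypothetical early-stopping hook: requests a stop once `patience`
/// consecutive epochs have passed without a new best validation loss.
struct EarlyStop {
    patience: usize,
    best: f32,
    stale: usize,
}

impl EarlyStop {
    fn new(patience: usize) -> Self {
        Self { patience, best: f32::INFINITY, stale: 0 }
    }
}

impl Hook for EarlyStop {
    fn after_epoch(&mut self, _epoch: usize, valid_loss: f32) -> bool {
        if valid_loss < self.best {
            self.best = valid_loss;
            self.stale = 0;
        } else {
            self.stale += 1;
        }
        self.stale < self.patience
    }
}

/// Drives the hooks over a precomputed list of per-epoch validation losses
/// (a stand-in for a real training loop). Returns the number of epochs run.
fn run_epochs(losses: &[f32], hooks: &mut [Box<dyn Hook>]) -> usize {
    for (epoch, &loss) in losses.iter().enumerate() {
        let mut keep_going = true;
        for hook in hooks.iter_mut() {
            // Every hook sees every epoch, even if an earlier hook asked to stop.
            keep_going &= hook.after_epoch(epoch, loss);
        }
        if !keep_going {
            return epoch + 1;
        }
    }
    losses.len()
}
```

With a patience of 2, feeding `run_epochs` a loss sequence that stops improving after the second epoch ends the loop two epochs later. The point of the design is that the training loop only knows about the hook trait, so behaviours such as early stopping, checkpointing, or progress reporting can be added without changing the loop itself.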

Example

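use tsai_train::{CrossEntropyLoss, Learner, LearnerConfig};
use tsai_data::TSDataLoaders;
use tsai_models::InceptionTimePlus;

// `config`, `device`, and `dls` (a `TSDataLoaders`) are assumed to be defined earlier.
// `Adam` is not re-exported by this crate and must be imported from wherever it is defined.
let model = InceptionTimePlus::new(config, &device);
let learner = Learner::new(model, dls)
    .with_optimizer(Adam::new(1e-3))
    .with_loss(CrossEntropyLoss::new());

// One-cycle training; the arguments are presumably the epoch count and the
// peak learning rate, as in fastai's `fit_one_cycle`.
learner.fit_one_cycle(10, 1e-3)?;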
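
To make the one-cycle policy behind `fit_one_cycle` more concrete, the following is a minimal sketch of a one-cycle learning rate schedule in the fastai style: a cosine warm-up from `lr_max / div` to `lr_max`, followed by a cosine decay to `lr_max / div_final`. The function names `cos_interp` and `one_cycle_lr` and all parameters are hypothetical and chosen for this illustration; the crate's `OneCycleLR` scheduler may use a different shape, different defaults, or a different interface.

```rust
use std::f32::consts::PI;

/// Cosine interpolation from `start` (at t = 0) to `end` (at t = 1).
fn cos_interp(start: f32, end: f32, t: f32) -> f32 {
    end + (start - end) * 0.5 * (1.0 + (PI * t).cos())
}

/// Hypothetical helper (not part of tsai_train): learning rate at `step`
/// out of `total_steps` under a one-cycle policy.
///
/// Assumes `0.0 < pct_start < 1.0`, where `pct_start` is the fraction of
/// training spent warming up.
fn one_cycle_lr(
    step: usize,
    total_steps: usize,
    lr_max: f32,
    pct_start: f32,
    div: f32,
    div_final: f32,
) -> f32 {
    // Progress through training, clamped to [0, 1].
    let p = (step as f32 / (total_steps.max(2) - 1) as f32).min(1.0);
    if p < pct_start {
        // Warm-up phase: lr_max / div -> lr_max.
        cos_interp(lr_max / div, lr_max, p / pct_start)
    } else {
        // Annealing phase: lr_max -> lr_max / div_final.
        cos_interp(lr_max, lr_max / div_final, (p - pct_start) / (1.0 - pct_start))
    }
}
```

For example, `one_cycle_lr(step, 101, 1e-3, 0.25, 25.0, 1e5)` starts near 4e-5, peaks at 1e-3 a quarter of the way through, and decays towards 1e-8 by the final step. The values 0.25, 25, and 1e5 are fastai's defaults and are used here only as an assumption.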

Re-exports

pub use callback::BatchSubsamplerCallback;
pub use callback::Callback;
pub use callback::CallbackContext;
pub use callback::CallbackList;
pub use callback::CheckpointMetadata;
pub use callback::EarlyStoppingCallback;
pub use callback::GradientClipCallback;
pub use callback::GradientClipMode;
pub use callback::HistoryCallback;
pub use callback::MixedPrecisionCallback;
pub use callback::NoiseInjection;
pub use callback::NoisyStudentCallback;
pub use callback::NoisyStudentStats;
pub use callback::PredictionDynamicsCallback;
pub use callback::PredictionDynamicsSummary;
pub use callback::PredictionTrackingMode;
pub use callback::ProgressCallback;
pub use callback::PseudoLabel;
pub use callback::PseudoLabelFilter;
pub use callback::SamplePredictionHistory;
pub use callback::SaveModelCallback;
pub use callback::SaveModelMode;
pub use callback::ShowGraphCallback;
pub use callback::SubsampleStrategy;
pub use callback::TerminateOnNanCallback;
pub use callback::TransformSchedule;
pub use callback::TransformSchedulerCallback;
pub use callback::WeightedPerSampleLossCallback;
pub use callback::WeightStrategy;
pub use error::Result;
pub use error::TrainError;
pub use learner::Learner;
pub use learner::LearnerConfig;
pub use learner::TrainingState;
pub use losses::CrossEntropyLoss;
pub use losses::FocalLoss;
pub use losses::HuberLoss;
pub use losses::LabelSmoothingLoss;
pub use losses::LogCoshLoss;
pub use losses::MSELoss;
pub use optimizer::RAdam;
pub use optimizer::RAdamConfig;
pub use optimizer::Ranger;
pub use optimizer::RangerConfig;
pub use metrics::Accuracy;
pub use metrics::F1Score;
pub use metrics::Metric;
pub use metrics::Precision;
pub use metrics::Recall;
pub use metrics::AUC;
pub use metrics::MCC;
pub use metrics::MAE;
pub use metrics::MAPE;
pub use metrics::MSE;
pub use metrics::RMSE;
pub use scheduler::ConstantLR;
pub use scheduler::CosineAnnealingLR;
pub use scheduler::CosineAnnealingWarmRestarts;
pub use scheduler::ExponentialLR;
pub use scheduler::LinearWarmup;
pub use scheduler::OneCycleLR;
pub use scheduler::PolynomialLR;
pub use scheduler::ReduceLROnPlateau;
pub use scheduler::ReduceMode;
pub use scheduler::Scheduler;
pub use scheduler::StepLR;
pub use training::train_classification;
pub use training::train_regression;
pub use training::ClassificationTrainer;
pub use training::ClassificationTrainerConfig;
pub use training::TrainingOutput;
pub use training::RegressionTrainer;
pub use training::RegressionTrainerConfig;
pub use training::RegressionOutput;
pub use evaluation::evaluate_classification;
pub use evaluation::ConfusionMatrix;
pub use evaluation::EvaluationResult;
pub use export::quick_load;
pub use export::quick_save;
pub use export::save_model_bundle;
pub use export::ExportMetadata;
pub use export::LearnerExport;
pub use hpo::GridSearch;
pub use hpo::HpoError;
pub use hpo::HyperparameterSpace;
pub use hpo::ParamSet;
pub use hpo::ParamValue;
pub use hpo::RandomSearch;
pub use hpo::SearchResult;
pub use hpo::SuccessiveHalving;
pub use hpo::TrialResult;
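
As background for the hyperparameter search types listed above, this is a minimal sketch of the idea behind a grid search: enumerating the Cartesian product of candidate values for each parameter. The `grid` function and its types are hypothetical and written only for this illustration; they do not reflect the actual interface of `GridSearch`, `HyperparameterSpace`, or `ParamSet`.

```rust
/// Hypothetical helper (not part of tsai_train): expands a list of
/// `(name, candidate values)` pairs into every combination of values.
fn grid(space: &[(&str, Vec<f64>)]) -> Vec<Vec<(String, f64)>> {
    // Start with a single empty combination and extend it one parameter at a time.
    let mut combos: Vec<Vec<(String, f64)>> = vec![Vec::new()];
    for (name, values) in space {
        let mut next = Vec::with_capacity(combos.len() * values.len());
        for combo in &combos {
            for &v in values {
                let mut extended = combo.clone();
                extended.push((name.to_string(), v));
                next.push(extended);
            }
        }
        combos = next;
    }
    combos
}
```

A space with two learning rates and three dropout values yields six combinations, each of which would be trained and scored as one trial. The number of trials grows multiplicatively with every added parameter, which is the usual motivation for alternatives such as random search or successive halving.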

Modules

callback
    Callback system for training hooks.
compat
    Compatibility facades for an sklearn-like API.
error
    Error types for training.
evaluation
    Model evaluation utilities.
export
    Learner export and import utilities.
hpo
    Hyperparameter optimization utilities.
learner
    Learner for managing training.
losses
    Loss functions.
metrics
    Training metrics.
optimizer
    Custom optimizers for time series deep learning.
scheduler
    Learning rate schedulers.
training
    Training loop implementation.
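
For readers unfamiliar with the quantities in the metrics module, here is a minimal sketch of three of them (accuracy, MSE, and MAE) written as plain functions over slices. These free functions are hypothetical and exist only for this illustration; the crate's `Accuracy`, `MSE`, and `MAE` types implement its `Metric` trait and will have a different interface.

```rust
/// Fraction of predicted class indices that equal the target class indices.
fn accuracy(pred: &[usize], target: &[usize]) -> f64 {
    assert_eq!(pred.len(), target.len(), "length mismatch");
    assert!(!pred.is_empty(), "empty input");
    let correct = pred.iter().zip(target).filter(|(p, t)| p == t).count();
    correct as f64 / pred.len() as f64
}

/// Mean squared error: average of (prediction - target)^2.
fn mse(pred: &[f64], target: &[f64]) -> f64 {
    assert_eq!(pred.len(), target.len(), "length mismatch");
    assert!(!pred.is_empty(), "empty input");
    let sum: f64 = pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum();
    sum / pred.len() as f64
}

/// Mean absolute error: average of |prediction - target|.
fn mae(pred: &[f64], target: &[f64]) -> f64 {
    assert_eq!(pred.len(), target.len(), "length mismatch");
    assert!(!pred.is_empty(), "empty input");
    let sum: f64 = pred.iter().zip(target).map(|(p, t)| (p - t).abs()).sum();
    sum / pred.len() as f64
}
```

MSE penalizes large errors more heavily than MAE because each error is squared, which is why the two can rank models differently on data with outliers.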

Traits

TSClassificationModel
    Trait for time series classification models.
TSForecastingModel
    Trait for time series forecasting models.
TSRegressionModel
    Trait for time series regression models.