
Crate tensorlogic_train


Training scaffolds: loss wiring, schedules, callbacks.

Version: 0.1.0 | Status: Production Ready

This crate provides comprehensive training infrastructure for Tensorlogic models:

  • Loss functions (standard and logical constraint-based)
  • Optimizer wrappers around SciRS2
  • Training loops with callbacks
  • Batch management
  • Validation and metrics
  • Regularization techniques
  • Data augmentation
  • Logging and monitoring
  • Curriculum learning strategies
  • Transfer learning utilities
  • Hyperparameter optimization (grid search, random search)
  • Cross-validation utilities
  • Model ensembling
  • Model pruning and compression
  • Model quantization (int8, int4, int2)
  • Mixed precision training (FP16, BF16)
  • Advanced sampling strategies
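As a taste of the "loss wiring" these pieces compose, here is a self-contained sketch (plain Rust against std only; no `tensorlogic_train` API is assumed) of a mean-squared-error loss driving SGD updates on a one-parameter linear model:

```rust
// Hedged sketch of loss wiring: MSE loss plus a plain SGD update.
// Names (mse, fit_slope) are illustrative, not this crate's API.
fn mse(preds: &[f64], targets: &[f64]) -> f64 {
    preds.iter()
        .zip(targets)
        .map(|(p, t)| (p - t).powi(2))
        .sum::<f64>() / preds.len() as f64
}

/// Fit y = w * x by gradient descent; returns the learned weight.
fn fit_slope(xs: &[f64], ys: &[f64], lr: f64, epochs: usize) -> f64 {
    let n = xs.len() as f64;
    let mut w = 0.0;
    for _ in 0..epochs {
        // dL/dw for L = mean((w*x - y)^2) is 2/n * sum((w*x - y) * x).
        let grad: f64 = xs.iter()
            .zip(ys)
            .map(|(x, y)| 2.0 * (w * x - y) * x)
            .sum::<f64>() / n;
        w -= lr * grad;
    }
    w
}

fn main() {
    let xs = [1.0, 2.0, 3.0];
    let ys = [2.0, 4.0, 6.0]; // true slope is 2
    let w = fit_slope(&xs, &ys, 0.05, 200);
    println!("learned w = {w:.4}");
}
```

The crate's `Trainer`, `Loss`, and `Optimizer` abstractions generalize exactly this loop to batched tensors, schedulers, and callbacks.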

Re-exports§

pub use early_stopping::EarlyStoppingConfig;
pub use early_stopping::EarlyStoppingDecision;
pub use early_stopping::EarlyStoppingMonitor;
pub use early_stopping::MonitorMode;
pub use early_stopping::MultiMetricMonitor;
pub use early_stopping::MultiMetricPolicy;
pub use early_stopping::PlateauDetector;
pub use early_stopping::TrainingProgress;
pub use checkpoint::deserialize_checkpoint;
pub use checkpoint::serialize_checkpoint;
pub use checkpoint::CheckpointError;
pub use checkpoint::CheckpointFormat;
pub use checkpoint::CheckpointManager;
pub use checkpoint::CheckpointMetadata;
pub use checkpoint::LossTracker;
pub use checkpoint::OptimizerCheckpoint;
pub use checkpoint::ParamState;
pub use weight_init::compute_fans;
pub use weight_init::constant_init;
pub use weight_init::gain_for_activation;
pub use weight_init::kaiming_normal;
pub use weight_init::kaiming_uniform;
pub use weight_init::lecun_normal;
pub use weight_init::lecun_uniform;
pub use weight_init::normal_init;
pub use weight_init::ones_init;
pub use weight_init::orthogonal_init;
pub use weight_init::uniform_init;
pub use weight_init::xavier_normal;
pub use weight_init::xavier_uniform;
pub use weight_init::zeros_init;
pub use weight_init::FanMode;
pub use weight_init::InitError;
pub use weight_init::InitRng;
pub use weight_init::InitStats;
pub use online_learning::online_evaluate;
pub use online_learning::Ftrl;
pub use online_learning::OGDLoss;
pub use online_learning::OnlineError;
pub use online_learning::OnlineGradientDescent;
pub use online_learning::OnlineLearner;
pub use online_learning::OnlineStats;
pub use online_learning::OnlineUpdateResult;
pub use online_learning::PAVariant;
pub use online_learning::PassiveAggressive;
pub use online_learning::Perceptron;
pub use adversarial::adversarial_training_loss;
pub use adversarial::fgsm;
pub use adversarial::pgd;
pub use adversarial::project_l1;
pub use adversarial::project_l2;
pub use adversarial::project_linf;
pub use adversarial::robustness_eval;
pub use adversarial::AdversarialError;
pub use adversarial::AdversarialExample;
pub use adversarial::AdversarialTrainStats;
pub use adversarial::AttackConfig;
pub use adversarial::AttackLoss;
pub use adversarial::AttackModel;
pub use adversarial::CrossEntropyAttackLoss;
pub use adversarial::LinearAttackModel;
pub use adversarial::MseAttackLoss;
pub use adversarial::PerturbNorm;
pub use neural_ode::dopri5_solve;
pub use neural_ode::rk4_solve;
pub use neural_ode::AdaptiveSolution;
pub use neural_ode::AdjointResult;
pub use neural_ode::NeuralOde;
pub use neural_ode::OdeError;
pub use neural_ode::OdeFunc;
pub use neural_ode::OdeSolution;
pub use neural_ode::OdeSolverConfig;
pub use lora::LoraAdapter;
pub use lora::LoraConfig;
pub use lora::LoraError;
pub use lora::LoraLayer;
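The `early_stopping` re-exports above center on patience-style monitoring. A minimal sketch of that logic (conceptual only; the actual `EarlyStoppingMonitor` interface may differ):

```rust
// Patience logic behind an early-stopping monitor: stop once the
// validation loss has failed to improve for `patience` consecutive epochs.
// Returns the number of epochs actually run.
fn run_with_patience(val_losses: &[f64], patience: usize) -> usize {
    let mut best = f64::INFINITY;
    let mut bad_epochs = 0;
    for (epoch, &loss) in val_losses.iter().enumerate() {
        if loss < best {
            best = loss;
            bad_epochs = 0;
        } else {
            bad_epochs += 1;
            if bad_epochs >= patience {
                return epoch + 1; // stop after this epoch
            }
        }
    }
    val_losses.len()
}

fn main() {
    // Improvement at epoch 1, then two stagnant epochs trip patience = 2.
    let n = run_with_patience(&[1.0, 0.9, 0.95, 0.96, 0.8], 2);
    println!("stopped after {n} epochs");
}
```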

Modules§

adversarial
Adversarial training utilities for TensorLogic.
checkpoint
Optimizer checkpointing: save/load optimizer state (momentum buffers, step counts, etc.)
early_stopping
Early stopping monitor for training loops.
lora
LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning.
neural_ode
Neural ODE (Neural Ordinary Differential Equations) implementation.
online_learning
Online learning algorithms: Perceptron, Passive-Aggressive, OGD, and FTRL.
weight_init
Weight initialization strategies for neural network parameters.
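The `weight_init` module's Xavier and Kaiming schemes boil down to closed-form scale factors over the layer fans. Sketched here from the published formulas rather than the crate's own `xavier_uniform`/`kaiming_normal` functions:

```rust
// Glorot & Bengio (2010): sample U(-b, b) with b = sqrt(6 / (fan_in + fan_out)).
fn xavier_uniform_bound(fan_in: usize, fan_out: usize) -> f64 {
    (6.0 / (fan_in + fan_out) as f64).sqrt()
}

// He et al. (2015), ReLU gain: sample N(0, s^2) with s = sqrt(2 / fan_in).
fn kaiming_normal_std(fan_in: usize) -> f64 {
    (2.0 / fan_in as f64).sqrt()
}

fn main() {
    println!("xavier bound for 256->128: {:.4}", xavier_uniform_bound(256, 128));
    println!("kaiming std for fan_in 256: {:.4}", kaiming_normal_std(256));
}
```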

Structs§

AccumulationConfig
Configuration for gradient accumulation.
AccumulationStats
Statistics from gradient accumulation.
Accuracy
Accuracy metric for classification.
AdaBeliefOptimizer
AdaBelief optimizer (NeurIPS 2020).
AdaMaxOptimizer
AdaMax optimizer (variant of Adam with infinity norm).
AdagradOptimizer
Adagrad optimizer (Adaptive Gradient).
AdamOptimizer
Adam optimizer.
AdamPOptimizer
AdamP optimizer with projection-based weight decay.
AdamWOptimizer
AdamW optimizer (Adam with decoupled weight decay).
AttentionTransferLoss
Attention transfer for distillation based on attention maps.
AugRng
A simple deterministic Linear Congruential Generator (LCG) RNG.
AugStats
Statistics comparing original and augmented data.
AugmentationPipeline
A composable, ordered sequence of augmentation steps.
AutocastContext
Automatic Mixed Precision (AMP) context manager.
AveragingEnsemble
Averaging ensemble for regression.
BCEWithLogitsLoss
Binary cross-entropy with logits loss (numerically stable).
BaggingHelper
Bagging (Bootstrap Aggregating) utilities.
BalancedAccuracy
Balanced accuracy metric. Average of recall per class, useful for imbalanced datasets.
BatchCallback
Callback that logs batch progress.
BatchConfig
Configuration for batch processing.
BatchIterator
Iterator over batches of data.
BatchReweighter
Batch reweighting based on sample importance.
BayesianOptimization
Bayesian Optimization for hyperparameter tuning.
CallbackList
List of callbacks to execute in order.
CheckpointCallback
Callback for model checkpointing with auto-cleanup.
ClassBalancedSampler
Class-balanced sampling for imbalanced datasets.
CohensKappa
Cohen’s Kappa statistic. Measures inter-rater agreement, accounting for chance agreement. Ranges from -1 to +1, where 1 is perfect agreement, 0 is random chance.
CompetenceCurriculum
Competence-based curriculum: adapts to model’s current competence level.
CompositeAugmenter
Composite augmenter that applies multiple augmentations sequentially.
CompositeRegularization
Composite regularization that combines multiple regularizers.
ConfusionMatrix
Confusion matrix for multi-class classification.
ConsoleLogger
Console logger that outputs to stdout.
ConstraintViolationLoss
Constraint violation loss - penalizes constraint violations.
ContrastiveLoss
Contrastive loss for metric learning. Used to learn embeddings where similar pairs are close and dissimilar pairs are far apart.
CosineAnnealingLrScheduler
Cosine annealing learning rate scheduler. Anneals learning rate using a cosine schedule.
CosineAnnealingScheduler
Cosine annealing with optional warm restarts (SGDR).
CrossEntropyLoss
Cross-entropy loss for classification.
CrossValidationResults
Cross-validation result aggregator.
CsvLoader
CSV data loader.
CsvLogger
CSV logger for easy data analysis.
CurriculumManager
Manager for curriculum learning that tracks training progress.
CurriculumSampler
Curriculum sampling for progressive difficulty.
CutMixAugmenter
CutMix augmentation (ICCV 2019).
CutOutAugmenter
CutOut augmentation.
CyclicLrScheduler
Cyclic learning rate scheduler.
CyclicalScheduler
Cyclical learning rates (CLR) — oscillates between min_lr and max_lr.
DataPreprocessor
Data preprocessor for normalization and standardization.
DataShuffler
Data shuffler for randomizing training data.
Dataset
Dataset container for training data.
DiceCoefficient
Dice Coefficient metric (F1 Score variant for segmentation).
DiceLoss
Dice loss for segmentation tasks.
DiscriminativeFineTuning
Discriminative fine-tuning: use different learning rates for different layers.
DistillationLoss
Knowledge distillation loss that combines student predictions with teacher soft targets.
DropBlock
DropBlock regularization.
DropPath
DropPath (Stochastic Depth) regularization.
DynamicRangeCalibrator
Dynamic range calibration for post-training quantization.
EarlyStoppingCallback
Callback for early stopping based on validation loss.
ElasticNetRegularization
Elastic Net regularization (combination of L1 and L2).
EpisodeSampler
Episode sampler for N-way K-shot tasks.
EpochCallback
Callback that logs training progress.
ExpectedCalibrationError
Expected Calibration Error (ECE) metric.
ExponentialCurriculum
Exponential curriculum: exponentially increase sample percentage.
ExponentialLrScheduler
Exponential learning rate scheduler. Decreases learning rate by a factor of gamma every epoch.
ExponentialStochasticDepth
Exponential stochastic depth scheduler.
F1Score
F1 score metric for classification.
FeatureDistillationLoss
Feature-based distillation that matches intermediate layer representations.
FeatureExtractorMode
Feature extraction mode: freeze entire feature extractor.
FewShotAccuracy
Few-shot accuracy evaluator.
FileLogger
File logger that writes logs to a file.
FocalLoss
Focal loss for addressing class imbalance. Reference: Lin et al., “Focal Loss for Dense Object Detection”
FocalSampler
Focal sampling strategy.
GaussianProcess
Gaussian Process regressor for Bayesian Optimization.
GcConfig
Configuration for gradient centralization.
GcStats
Statistics for gradient centralization.
GlobalPruner
Global pruning across multiple layers.
GradientAccumulationCallback
Gradient Accumulation callback with advanced features.
GradientAccumulationStats
Statistics for gradient accumulation.
GradientAccumulator
Gradient accumulator managing multiple parameter gradients.
GradientBuffer
A single gradient buffer for one parameter.
GradientCentralization
Gradient Centralization optimizer wrapper.
GradientCheckpointConfig
Gradient checkpointing configuration.
GradientMonitor
Gradient flow monitor for tracking gradient statistics during training.
GradientPruner
Gradient-based pruning (prune weights with smallest gradients).
GradientScaler
Gradient scaler for automatic mixed precision.
GradientStats
Gradient statistics for monitoring gradient flow.
GradientSummary
Summary of gradient statistics.
GridSearch
Grid search strategy for hyperparameter optimization.
GroupLassoRegularization
Group Lasso regularization.
HardNegativeMiner
Hard negative mining for handling imbalanced datasets.
HingeLoss
Hinge loss for maximum-margin classification (SVM-style).
HistogramCallback
Callback for tracking weight histograms during training.
HistogramStats
Weight histogram statistics for debugging and monitoring.
HuberLoss
Huber loss for robust regression.
HyperparamResult
Result of a hyperparameter evaluation.
ImportanceSampler
Importance sampling based on sample scores.
IoU
Intersection over Union (IoU) metric for segmentation tasks.
JsonlLogger
JSONL (JSON Lines) logger for machine-readable output.
KFold
K-fold cross-validation.
KLDivergenceLoss
Kullback-Leibler Divergence loss. Measures how one probability distribution diverges from a reference distribution.
L1Regularization
L1 regularization (Lasso).
L2Regularization
L2 regularization (Ridge / Weight Decay).
LabelEncoder
Label encoder for converting string labels to integers.
LabelSmoothingLoss
Label smoothing cross-entropy loss.
LambOptimizer
LAMB optimizer (Layer-wise Adaptive Moments optimizer for Batch training). Designed for large batch training, uses layer-wise adaptation.
LarsOptimizer
LARS optimizer (Layer-wise Adaptive Rate Scaling).
LayerFreezingConfig
Layer freezing configuration for transfer learning.
LayerPruningStats
Pruning statistics for a single layer.
LearningRateFinder
Learning rate finder callback using the LR range test.
LeaveOneOut
Leave-one-out cross-validation.
LinearCurriculum
Linear curriculum: gradually increase the percentage of samples used.
LinearDropBlockScheduler
Linear DropBlock scheduler.
LinearModel
A simple linear model for testing and demonstration.
LinearStochasticDepth
Linear stochastic depth scheduler.
LionConfig
Lion optimizer configuration.
LionOptimizer
Lion optimizer.
LogicalLoss
Logical loss combining multiple objectives.
LookaheadOptimizer
Lookahead optimizer (wrapper that uses slow and fast weights).
LossConfig
Configuration for loss functions.
LrRangeTestAnalyzer
Learning rate range test analyzer for finding optimal learning rates.
MAML
MAML (Model-Agnostic Meta-Learning) implementation.
MAMLConfig
MAML (Model-Agnostic Meta-Learning) configuration.
MagnitudePruner
Magnitude-based pruning (prune smallest weights).
MatchingNetwork
Matching network for few-shot learning.
MatthewsCorrelationCoefficient
Matthews Correlation Coefficient (MCC) metric. Ranges from -1 to +1, where +1 is perfect prediction, 0 is random, -1 is total disagreement. Particularly useful for imbalanced datasets.
MaxNormRegularization
MaxNorm constraint regularizer.
MaximumCalibrationError
Maximum Calibration Error (MCE) metric.
MeanAveragePrecision
Mean Average Precision (mAP) metric for object detection and retrieval.
MeanIoU
Mean Intersection over Union (mIoU) metric for multi-class segmentation.
MemoryBudgetManager
Memory budget manager for training.
MemoryEfficientTraining
Memory-efficient training utilities.
MemoryProfilerCallback
Memory profiler callback for tracking memory usage during training.
MemorySettings
Recommended memory settings.
MemoryStats
Memory statistics for a training session.
MetaStats
Meta-learning statistics tracker.
MetaTask
Meta-learning task representation.
MetricTracker
Metric tracker for managing multiple metrics.
MetricsLogger
Metrics logger that aggregates and logs training metrics.
MixedPrecisionStats
Statistics for mixed precision training.
MixedPrecisionTrainer
Mixed precision training manager.
MixupAugmenter
Mixup augmentation.
MixupLoss
Mixup data augmentation that mixes training examples and their labels.
ModelEMACallback
Model EMA (Exponential Moving Average) callback.
ModelSoup
Model Soup - Weight-space averaging for improved generalization.
ModelSummary
Model summary containing layer-wise parameter information.
MseLoss
Mean squared error loss for regression.
MultiStepLrScheduler
Multi-step learning rate scheduler.
MultiTaskLoss
Multi-task loss that combines multiple losses with configurable weighting.
NAdamOptimizer
NAdam optimizer (Nesterov-accelerated Adam).
NoAugmentation
No augmentation (identity transformation).
NoamScheduler
Noam scheduler (Transformer learning rate schedule).
NoiseAugmenter
Gaussian noise augmentation.
NormalizedDiscountedCumulativeGain
Normalized Discounted Cumulative Gain (NDCG) metric for ranking.
OneCycleLrScheduler
One-cycle learning rate scheduler. Increases LR from initial to max, then decreases to min.
OneCyclePolicyScheduler
One-cycle learning rate policy.
OneHotEncoder
One-hot encoder for categorical data.
OnlineHardExampleMiner
Online hard example mining during training.
OptimizerConfig
Configuration for optimizers.
OrthogonalRegularization
Orthogonal regularization.
PCGrad
PCGrad: Project conflicting gradients for multi-task learning.
ParameterDifference
Statistics about parameter differences between two models.
ParameterStats
Model parameter statistics for a single layer or the entire model.
PerClassMetrics
Per-class metrics report.
PolyLoss
Poly Loss - Polynomial Expansion of Cross-Entropy Loss.
PolynomialDecayLrScheduler
Polynomial decay learning rate scheduler.
Precision
Precision metric for classification.
ProdigyConfig
Configuration for the Prodigy optimizer.
ProdigyOptimizer
Prodigy optimizer.
ProfilingCallback
Callback for profiling training performance.
ProfilingStats
Performance profiling statistics.
ProgressiveUnfreezing
Progressive unfreezing strategy for transfer learning.
PrototypicalDistance
Prototypical distance calculator for few-shot learning.
PruningConfig
Configuration for pruning strategies.
PruningStats
Statistics about pruned model.
QuantizationAwareTraining
Quantization-aware training (QAT) utilities.
QuantizationConfig
Configuration for quantization.
QuantizationParams
Quantization parameters (scale and zero-point).
QuantizedTensor
Quantized tensor representation.
Quantizer
Main quantizer for model compression.
RAdamOptimizer
RAdam optimizer (Rectified Adam) with variance warmup (ICLR 2020).
RMSpropOptimizer
RMSprop optimizer (Root Mean Square Propagation).
RandomErasingAugmenter
Random Erasing augmentation.
RandomSearch
Random search strategy for hyperparameter optimization.
Recall
Recall metric for classification.
ReduceLROnPlateauScheduler
Reduce learning rate on plateau (metric-based adaptive scheduler).
ReduceLrOnPlateauCallback
Callback for learning rate reduction on plateau.
Reptile
Reptile meta-learning algorithm.
ReptileConfig
Reptile algorithm configuration.
RocCurve
ROC curve and AUC computation utilities.
RotationAugmenter
Rotation augmentation (placeholder for future implementation).
RuleSatisfactionLoss
Rule satisfaction loss - measures how well rules are satisfied.
SWACallback
SWA (Stochastic Weight Averaging) callback.
SamOptimizer
SAM optimizer (Sharpness Aware Minimization).
ScaleAugmenter
Scale augmentation.
ScheduleFreeAdamW
Schedule-free AdamW optimizer.
ScheduleFreeConfig
Configuration for schedule-free optimizers.
SchedulerConfig
Builder for creating scheduler configurations.
SelfPacedCurriculum
Self-paced learning: model determines its own learning pace.
SgdOptimizer
SGD optimizer with momentum.
SgdrScheduler
SGDR: Stochastic Gradient Descent with Warm Restarts scheduler.
SophiaConfig
Configuration for the Sophia optimizer, including Sophia-specific parameters.
SophiaOptimizer
Sophia optimizer - Second-order optimizer with Hessian diagonal estimation.
SpectralNormalization
Spectral Normalization regularizer.
StackingEnsemble
Stacking ensemble with a meta-learner.
StepDecayScheduler
Multiplies the learning rate by gamma every step_size steps.
StepLrScheduler
Step-based learning rate scheduler. Decreases learning rate by a factor every step_size epochs.
StratifiedKFold
Stratified K-fold cross-validation.
StructuredPruner
Structured pruning (remove entire neurons/channels/filters).
SupportSet
Support set for few-shot learning.
TaskCurriculum
Task-level curriculum for multi-task learning.
TensorBoardLogger
TensorBoard logger that writes real event files.
TimeEstimator
Training time estimation based on iteration timing.
TimeSeriesSplit
Time series split for temporal data.
TopKAccuracy
Top-K accuracy metric. Measures whether the correct class is in the top K predictions.
Trainer
Main trainer for model training.
TrainerConfig
Configuration for training.
TrainingCheckpoint
Comprehensive checkpoint data structure.
TrainingHistory
Training history containing losses and metrics.
TrainingState
Training state passed to callbacks.
TransferLearningManager
Transfer learning strategy manager.
TripletLoss
Triplet loss for metric learning. Learns embeddings where anchor-positive distance < anchor-negative distance + margin.
TverskyLoss
Tversky loss (generalization of Dice loss). Useful for handling class imbalance in segmentation.
ValidationCallback
Callback for validation during training.
VotingEnsemble
Voting ensemble configuration.
WarmupCosineLrScheduler
Warmup with cosine annealing scheduler.
WarmupScheduler
Linear warmup followed by another scheduler.
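Most of the scheduler structs above evaluate closed-form learning-rate curves. Cosine annealing, for instance, reduces to a one-liner (sketched independently of `CosineAnnealingLrScheduler`'s real interface):

```rust
// Cosine annealing: interpolate from max_lr down to min_lr along a
// half-cosine over total_steps steps.
fn cosine_annealing_lr(step: usize, total_steps: usize, min_lr: f64, max_lr: f64) -> f64 {
    let progress = step as f64 / total_steps as f64;
    min_lr + 0.5 * (max_lr - min_lr) * (1.0 + (std::f64::consts::PI * progress).cos())
}

fn main() {
    for step in [0, 50, 100] {
        println!("step {step:>3}: lr = {:.4}", cosine_annealing_lr(step, 100, 0.001, 0.1));
    }
}
```

Warm restarts (SGDR) repeat this curve with the step counter reset at each restart boundary.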

Enums§

AccumulationError
Errors that can occur during gradient accumulation.
AcquisitionFunction
Acquisition function type for Bayesian Optimization.
AugmentationError
Errors that can occur during augmentation operations.
AugmentationStep
A single step in an augmentation pipeline.
BitWidth
Bit-width for quantization.
CheckpointCompression
Compression method for checkpoints.
CheckpointStrategy
Gradient checkpointing strategy.
CyclicLrMode
Cyclic learning rate mode.
DistanceMetric
Distance metric for few-shot learning.
GcStrategy
Gradient centralization strategy.
GpKernel
Gaussian Process kernel for Bayesian Optimization.
GradClipMode
Gradient clipping mode.
GradientScalingStrategy
Gradient scaling strategy for accumulation.
Granularity
Quantization granularity (per-tensor or per-channel).
HyperparamSpace
Hyperparameter space definition.
HyperparamValue
Hyperparameter value type.
LossScaler
Loss scaling strategy for mixed precision training.
MiningStrategy
Strategy for mining hard examples.
PlateauMode
Mode for ReduceLROnPlateau scheduler.
PrecisionMode
Precision mode for mixed precision training.
PreprocessingMethod
Preprocessing method.
QuantizationMode
Quantization mode determines the quantization strategy.
ReweightingStrategy
Strategy for reweighting samples.
SchedulerError
Error types for scheduler operations.
SchedulerType
Enum identifying the scheduler algorithm.
ShotType
Type of shot configuration for few-shot learning.
SophiaVariant
Variant of the Sophia optimizer to use.
SoupRecipe
Recipe for creating model soups.
StructuredPruningAxis
Axis for structured pruning.
TaskWeightingStrategy
Strategy for weighting multiple tasks.
TrainError
Errors that can occur during training.
VotingMode
Voting mode for classification ensembles.
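To make the quantization vocabulary above (`BitWidth`, `QuantizationParams`) concrete: affine int8 quantization stores a scale and a zero-point so real values round-trip through `i8`. A sketch from the standard formulation, not the crate's `Quantizer` API, assuming `max_v > min_v`:

```rust
// Map the real range [min_v, max_v] onto the int8 range [-128, 127].
fn quant_params(min_v: f64, max_v: f64) -> (f64, i32) {
    let scale = (max_v - min_v) / 255.0;
    let zero_point = (-128.0 - min_v / scale).round() as i32;
    (scale, zero_point)
}

fn quantize(x: f64, scale: f64, zp: i32) -> i8 {
    ((x / scale).round() as i32 + zp).clamp(-128, 127) as i8
}

fn dequantize(q: i8, scale: f64, zp: i32) -> f64 {
    (q as i32 - zp) as f64 * scale
}

fn main() {
    let (scale, zp) = quant_params(0.0, 1.0);
    let q = quantize(0.5, scale, zp);
    println!("0.5 -> {q} -> {:.5}", dequantize(q, scale, zp));
}
```

Per-channel granularity simply keeps one `(scale, zero_point)` pair per output channel instead of per tensor.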

Traits§

AutodiffModel
Trait for models that support automatic differentiation via scirs2-autograd.
Callback
Trait for training callbacks.
CrossValidationSplit
Trait for cross-validation splitting strategies.
CurriculumStrategy
Trait for curriculum learning strategies.
DataAugmenter
Trait for data augmentation strategies.
DynamicModel
Trait for models with dynamic computation graphs.
Ensemble
Trait for ensemble methods.
LoggingBackend
Trait for logging backends.
Loss
Trait for loss functions.
LrScheduler
Trait for learning rate schedulers.
LrSchedulerV2
Trait for learning rate schedulers (v2 interface).
MetaLearner
Meta-learner trait for different meta-learning algorithms.
Metric
Trait for metrics.
Model
Trait for trainable models.
Optimizer
Trait for optimizers.
Pruner
Trait for pruning strategies.
Regularizer
Trait for regularization strategies.
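The essential shape of the `Loss` trait above: predictions and targets in, scalar out, object-safe so concrete losses like `MseLoss` and `CrossEntropyLoss` can be swapped behind one interface. The signature here is illustrative, not the crate's exact one:

```rust
// Illustrative Loss trait; the crate's real trait likely operates on
// tensors rather than slices.
trait Loss {
    fn compute(&self, preds: &[f64], targets: &[f64]) -> f64;
}

struct Mse;

impl Loss for Mse {
    fn compute(&self, preds: &[f64], targets: &[f64]) -> f64 {
        preds.iter().zip(targets).map(|(p, t)| (p - t).powi(2)).sum::<f64>()
            / preds.len() as f64
    }
}

fn main() {
    // Dynamic dispatch lets a trainer hold any loss uniformly.
    let loss: Box<dyn Loss> = Box::new(Mse);
    println!("mse = {}", loss.compute(&[0.0, 1.0], &[2.0, 1.0]));
}
```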

Functions§

center_crop_2d
Center crop: crop [crop_h, crop_w] from the center of the last two spatial dims.
clip
Clamp all elements to [min_val, max_val].
compare_models
Compare two models and report differences in parameters.
compute_gradient_stats
Compute gradient statistics for all layers in a gradient dictionary.
cutmix
CutMix: paste a random rectangular region from x2 into x1.
denormalize
Denormalize input: x * std[c] + mean[c] (inverse of normalize).
dropout
Apply inverted dropout: zero each element with probability p; scale survivors by 1/(1−p).
dropout_mask
Generate a binary dropout mask of the given shape.
extract_batch
Extract batches from data arrays.
format_duration
Format a duration in seconds to a human-readable string.
gaussian_noise
Add element-wise Gaussian noise: output = input + N(0, std²).
mixup
Mixup: λ·x1 + (1−λ)·x2 where λ ~ Beta(alpha, alpha).
normalize
Normalize input: (x − mean[c]) / std[c].
print_gradient_report
Print a formatted report of gradient statistics.
random_crop_2d
Random 2-D crop: extract a sub-array of size [.., crop_h, crop_w] at a random position.
random_hflip
Random horizontal flip of the last two spatial dimensions with probability p.
random_vflip
Random vertical flip of the last two spatial dimensions with probability p.
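Among the functions above, `mixup` is the simplest to state: a convex combination of two examples. The crate draws λ from Beta(alpha, alpha); this sketch takes λ as an argument to stay deterministic:

```rust
// mixup: lambda * x1 + (1 - lambda) * x2, elementwise.
// (The same lambda is applied to the two examples' labels.)
fn mixup(x1: &[f64], x2: &[f64], lambda: f64) -> Vec<f64> {
    x1.iter()
        .zip(x2)
        .map(|(a, b)| lambda * a + (1.0 - lambda) * b)
        .collect()
}

fn main() {
    println!("{:?}", mixup(&[1.0, 0.0], &[0.0, 1.0], 0.7));
}
```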

Type Aliases§

HyperparamConfig
Hyperparameter configuration (a single point in parameter space).
PruningMask
Pruning mask indicating which weights are kept (1.0) or removed (0.0).
TrainResult
Result type for training operations.