Crate tensorlogic_train

Training scaffolds: loss wiring, schedules, callbacks.

Version: 0.1.0-beta.1 | Status: Production Ready

This crate provides comprehensive training infrastructure for Tensorlogic models (a minimal wiring sketch follows the list):

  • Loss functions (standard and logical constraint-based)
  • Optimizer wrappers around SciRS2
  • Training loops with callbacks
  • Batch management
  • Validation and metrics
  • Regularization techniques
  • Data augmentation
  • Logging and monitoring
  • Curriculum learning strategies
  • Transfer learning utilities
  • Hyperparameter optimization (grid search, random search)
  • Cross-validation utilities
  • Model ensembling
  • Model pruning and compression
  • Model quantization (int8, int4, int2)
  • Mixed precision training (FP16, BF16)
  • Advanced sampling strategies
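
A typical run wires a loss, an optimizer, and a few callbacks into the `Trainer`. The sketch below shows that shape using items from this index; the constructors and method names (`default()`, `Trainer::new`, `add_callback`, `fit`) are assumptions rather than verified signatures, so treat it as orientation and consult each item's page before use:

```rust
// Wiring sketch only: `Trainer::new`, `add_callback`, and the
// `default()` constructors below are assumptions about this crate's
// API, not verified signatures.
use tensorlogic_train::{
    AdamOptimizer, CrossEntropyLoss, EarlyStoppingCallback, Trainer, TrainerConfig,
};

fn main() {
    let config = TrainerConfig::default();    // training configuration
    let loss = CrossEntropyLoss::default();   // standard classification loss
    let optimizer = AdamOptimizer::default(); // SciRS2-backed optimizer wrapper

    let mut trainer = Trainer::new(config, loss, optimizer);

    // Stop when validation loss stops improving.
    trainer.add_callback(Box::new(EarlyStoppingCallback::default()));

    // A `fit`-style entry point over a `Dataset` is assumed:
    // trainer.fit(&dataset)?;
}
```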

Structs§

Accuracy
Accuracy metric for classification.
AdaBeliefOptimizer
AdaBelief optimizer (NeurIPS 2020).
AdaMaxOptimizer
AdaMax optimizer (variant of Adam with infinity norm).
AdagradOptimizer
Adagrad optimizer (Adaptive Gradient).
AdamOptimizer
Adam optimizer.
AdamPOptimizer
AdamP optimizer with projection-based weight decay.
AdamWOptimizer
AdamW optimizer (Adam with decoupled weight decay).
AttentionTransferLoss
Attention transfer for distillation based on attention maps.
AutocastContext
Automatic Mixed Precision (AMP) context manager.
AveragingEnsemble
Averaging ensemble for regression.
BCEWithLogitsLoss
Binary cross-entropy with logits loss (numerically stable).
BaggingHelper
Bagging (Bootstrap Aggregating) utilities.
BalancedAccuracy
Balanced accuracy metric. Average of recall per class, useful for imbalanced datasets.
BatchCallback
Callback that logs batch progress.
BatchConfig
Configuration for batch processing.
BatchIterator
Iterator over batches of data.
BatchReweighter
Batch reweighting based on sample importance.
BayesianOptimization
Bayesian Optimization for hyperparameter tuning.
CallbackList
List of callbacks to execute in order.
CheckpointCallback
Callback for model checkpointing with auto-cleanup.
ClassBalancedSampler
Class-balanced sampling for imbalanced datasets.
CohensKappa
Cohen’s Kappa statistic. Measures inter-rater agreement, accounting for chance agreement. Ranges from -1 to +1, where 1 is perfect agreement, 0 is random chance.
CompetenceCurriculum
Competence-based curriculum: adapts to model’s current competence level.
CompositeAugmenter
Composite augmenter that applies multiple augmentations sequentially.
CompositeRegularization
Composite regularization that combines multiple regularizers.
ConfusionMatrix
Confusion matrix for multi-class classification.
ConsoleLogger
Console logger that outputs to stdout.
ConstraintViolationLoss
Constraint violation loss - penalizes constraint violations.
ContrastiveLoss
Contrastive loss for metric learning. Used to learn embeddings where similar pairs are close and dissimilar pairs are far apart.
CosineAnnealingLrScheduler
Cosine annealing learning rate scheduler. Anneals learning rate using a cosine schedule.
CrossEntropyLoss
Cross-entropy loss for classification.
CrossValidationResults
Cross-validation result aggregator.
CsvLoader
CSV data loader.
CsvLogger
CSV logger for easy data analysis.
CurriculumManager
Manager for curriculum learning that tracks training progress.
CurriculumSampler
Curriculum sampling for progressive difficulty.
CutMixAugmenter
CutMix augmentation (ICCV 2019).
CutOutAugmenter
CutOut augmentation.
CyclicLrScheduler
Cyclic learning rate scheduler.
DataPreprocessor
Data preprocessor for normalization and standardization.
DataShuffler
Data shuffler for randomizing training data.
Dataset
Dataset container for training data.
DiceCoefficient
Dice Coefficient metric (F1 Score variant for segmentation).
DiceLoss
Dice loss for segmentation tasks.
DiscriminativeFineTuning
Discriminative fine-tuning: use different learning rates for different layers.
DistillationLoss
Knowledge distillation loss that combines student predictions with teacher soft targets.
DropBlock
DropBlock regularization.
DropPath
DropPath (Stochastic Depth) regularization.
DynamicRangeCalibrator
Dynamic range calibration for post-training quantization.
EarlyStoppingCallback
Callback for early stopping based on validation loss.
ElasticNetRegularization
Elastic Net regularization (combination of L1 and L2).
EpisodeSampler
Episode sampler for N-way K-shot tasks.
EpochCallback
Callback that logs training progress.
ExpectedCalibrationError
Expected Calibration Error (ECE) metric.
ExponentialCurriculum
Exponential curriculum: exponentially increase sample percentage.
ExponentialLrScheduler
Exponential learning rate scheduler. Decreases learning rate by a factor of gamma every epoch.
ExponentialStochasticDepth
Exponential stochastic depth scheduler.
F1Score
F1 score metric for classification.
FeatureDistillationLoss
Feature-based distillation that matches intermediate layer representations.
FeatureExtractorMode
Feature extraction mode: freeze entire feature extractor.
FewShotAccuracy
Few-shot accuracy evaluator.
FileLogger
File logger that writes logs to a file.
FocalLoss
Focal loss for addressing class imbalance. Reference: Lin et al., “Focal Loss for Dense Object Detection”
FocalSampler
Focal sampling strategy.
GaussianProcess
Gaussian Process regressor for Bayesian Optimization.
GcConfig
Configuration for gradient centralization.
GcStats
Statistics for gradient centralization.
GlobalPruner
Global pruning across multiple layers.
GradientAccumulationCallback
Gradient Accumulation callback with advanced features.
GradientAccumulationStats
Statistics for gradient accumulation.
GradientCentralization
Gradient Centralization optimizer wrapper.
GradientCheckpointConfig
Gradient checkpointing configuration.
GradientMonitor
Gradient flow monitor for tracking gradient statistics during training.
GradientPruner
Gradient-based pruning (prune weights with smallest gradients).
GradientScaler
Gradient scaler for automatic mixed precision.
GradientStats
Gradient statistics for monitoring gradient flow.
GradientSummary
Summary of gradient statistics.
GridSearch
Grid search strategy for hyperparameter optimization.
GroupLassoRegularization
Group Lasso regularization.
HardNegativeMiner
Hard negative mining for handling imbalanced datasets.
HingeLoss
Hinge loss for maximum-margin classification (SVM-style).
HistogramCallback
Callback for tracking weight histograms during training.
HistogramStats
Weight histogram statistics for debugging and monitoring.
HuberLoss
Huber loss for robust regression.
HyperparamResult
Result of a hyperparameter evaluation.
ImportanceSampler
Importance sampling based on sample scores.
IoU
Intersection over Union (IoU) metric for segmentation tasks.
JsonlLogger
JSONL (JSON Lines) logger for machine-readable output.
KFold
K-fold cross-validation.
KLDivergenceLoss
Kullback-Leibler Divergence loss. Measures how one probability distribution diverges from a reference distribution.
L1Regularization
L1 regularization (Lasso).
L2Regularization
L2 regularization (Ridge / Weight Decay).
LabelEncoder
Label encoder for converting string labels to integers.
LabelSmoothingLoss
Label smoothing cross-entropy loss.
LambOptimizer
LAMB optimizer (Layer-wise Adaptive Moments optimizer for Batch training). Designed for large batch training, uses layer-wise adaptation.
LarsOptimizer
LARS optimizer (Layer-wise Adaptive Rate Scaling).
LayerFreezingConfig
Layer freezing configuration for transfer learning.
LayerPruningStats
Pruning statistics for a single layer.
LearningRateFinder
Learning rate finder callback using the LR range test.
LeaveOneOut
Leave-one-out cross-validation.
LinearCurriculum
Linear curriculum: gradually increase the percentage of samples used.
LinearDropBlockScheduler
Linear DropBlock scheduler.
LinearModel
A simple linear model for testing and demonstration.
LinearStochasticDepth
Linear stochastic depth scheduler.
LionConfig
Lion optimizer configuration.
LionOptimizer
Lion optimizer.
LogicalLoss
Logical loss combining multiple objectives.
LookaheadOptimizer
Lookahead optimizer (wrapper that uses slow and fast weights).
LossConfig
Configuration for loss functions.
LrRangeTestAnalyzer
Learning rate range test analyzer for finding optimal learning rates.
MAML
MAML (Model-Agnostic Meta-Learning) implementation.
MAMLConfig
MAML (Model-Agnostic Meta-Learning) configuration.
MagnitudePruner
Magnitude-based pruning (prune smallest weights).
MatchingNetwork
Matching network for few-shot learning.
MatthewsCorrelationCoefficient
Matthews Correlation Coefficient (MCC) metric. Ranges from -1 to +1, where +1 is perfect prediction, 0 is random, -1 is total disagreement. Particularly useful for imbalanced datasets.
MaxNormRegularization
MaxNorm constraint regularizer.
MaximumCalibrationError
Maximum Calibration Error (MCE) metric.
MeanAveragePrecision
Mean Average Precision (mAP) metric for object detection and retrieval.
MeanIoU
Mean Intersection over Union (mIoU) metric for multi-class segmentation.
MemoryBudgetManager
Memory budget manager for training.
MemoryEfficientTraining
Memory-efficient training utilities.
MemoryProfilerCallback
Memory profiler callback for tracking memory usage during training.
MemorySettings
Recommended memory settings.
MemoryStats
Memory statistics for a training session.
MetaStats
Meta-learning statistics tracker.
MetaTask
Meta-learning task representation.
MetricTracker
Metric tracker for managing multiple metrics.
MetricsLogger
Metrics logger that aggregates and logs training metrics.
MixedPrecisionStats
Statistics for mixed precision training.
MixedPrecisionTrainer
Mixed precision training manager.
MixupAugmenter
Mixup augmentation.
MixupLoss
Mixup data augmentation that mixes training examples and their labels.
ModelEMACallback
Model EMA (Exponential Moving Average) callback.
ModelSoup
Model Soup - Weight-space averaging for improved generalization.
ModelSummary
Model summary containing layer-wise parameter information.
MseLoss
Mean squared error loss for regression.
MultiStepLrScheduler
Multi-step learning rate scheduler.
MultiTaskLoss
Multi-task loss that combines multiple losses with configurable weighting.
NAdamOptimizer
NAdam optimizer (Nesterov-accelerated Adam).
NoAugmentation
No augmentation (identity transformation).
NoamScheduler
Noam scheduler (Transformer learning rate schedule).
NoiseAugmenter
Gaussian noise augmentation.
NormalizedDiscountedCumulativeGain
Normalized Discounted Cumulative Gain (NDCG) metric for ranking.
OneCycleLrScheduler
One-cycle learning rate scheduler. Increases LR from initial to max, then decreases to min.
OneHotEncoder
One-hot encoder for categorical data.
OnlineHardExampleMiner
Online hard example mining during training.
OptimizerConfig
Configuration for optimizers.
OrthogonalRegularization
Orthogonal regularization.
PCGrad
PCGrad: Project conflicting gradients for multi-task learning.
ParameterDifference
Statistics about parameter differences between two models.
ParameterStats
Model parameter statistics for a single layer or the entire model.
PerClassMetrics
Per-class metrics report.
PolyLoss
Poly Loss - Polynomial Expansion of Cross-Entropy Loss.
PolynomialDecayLrScheduler
Polynomial decay learning rate scheduler.
Precision
Precision metric for classification.
ProdigyConfig
Configuration for the Prodigy optimizer.
ProdigyOptimizer
Prodigy optimizer.
ProfilingCallback
Callback for profiling training performance.
ProfilingStats
Performance profiling statistics.
ProgressiveUnfreezing
Progressive unfreezing strategy for transfer learning.
PrototypicalDistance
Prototypical distance calculator for few-shot learning.
PruningConfig
Configuration for pruning strategies.
PruningStats
Statistics about pruned model.
QuantizationAwareTraining
Quantization-aware training (QAT) utilities.
QuantizationConfig
Configuration for quantization.
QuantizationParams
Quantization parameters (scale and zero-point).
QuantizedTensor
Quantized tensor representation.
Quantizer
Main quantizer for model compression.
RAdamOptimizer
RAdam optimizer (Rectified Adam) with variance warmup (ICLR 2020).
RMSpropOptimizer
RMSprop optimizer (Root Mean Square Propagation).
RandomErasingAugmenter
Random Erasing augmentation.
RandomSearch
Random search strategy for hyperparameter optimization.
Recall
Recall metric for classification.
ReduceLROnPlateauScheduler
Reduce learning rate on plateau (metric-based adaptive scheduler).
ReduceLrOnPlateauCallback
Callback for learning rate reduction on plateau.
Reptile
Reptile meta-learning algorithm.
ReptileConfig
Reptile algorithm configuration.
RocCurve
ROC curve and AUC computation utilities.
RotationAugmenter
Rotation augmentation (placeholder for future implementation).
RuleSatisfactionLoss
Rule satisfaction loss - measures how well rules are satisfied.
SWACallback
SWA (Stochastic Weight Averaging) callback.
SamOptimizer
SAM optimizer (Sharpness Aware Minimization).
ScaleAugmenter
Scale augmentation.
ScheduleFreeAdamW
Schedule-free AdamW optimizer.
ScheduleFreeConfig
Configuration for schedule-free optimizers.
SelfPacedCurriculum
Self-paced learning: model determines its own learning pace.
SgdOptimizer
SGD optimizer with momentum.
SgdrScheduler
SGDR: Stochastic Gradient Descent with Warm Restarts scheduler.
SophiaConfig
Configuration for the Sophia optimizer, with additional Sophia-specific parameters.
SophiaOptimizer
Sophia optimizer: a second-order optimizer with Hessian diagonal estimation.
SpectralNormalization
Spectral Normalization regularizer.
StackingEnsemble
Stacking ensemble with a meta-learner.
StepLrScheduler
Step-based learning rate scheduler. Decreases learning rate by a factor every step_size epochs.
StratifiedKFold
Stratified K-fold cross-validation.
StructuredPruner
Structured pruning (remove entire neurons/channels/filters).
SupportSet
Support set for few-shot learning.
TaskCurriculum
Task-level curriculum for multi-task learning.
TensorBoardLogger
TensorBoard logger that writes real event files.
TimeEstimator
Training time estimation based on iteration timing.
TimeSeriesSplit
Time series split for temporal data.
TopKAccuracy
Top-K accuracy metric. Measures whether the correct class is in the top K predictions.
Trainer
Main trainer for model training.
TrainerConfig
Configuration for training.
TrainingCheckpoint
Comprehensive checkpoint data structure.
TrainingHistory
Training history containing losses and metrics.
TrainingState
Training state passed to callbacks.
TransferLearningManager
Transfer learning strategy manager.
TripletLoss
Triplet loss for metric learning. Learns embeddings where anchor-positive distance < anchor-negative distance + margin.
TverskyLoss
Tversky loss (generalization of Dice loss). Useful for handling class imbalance in segmentation.
ValidationCallback
Callback for validation during training.
VotingEnsemble
Voting ensemble configuration.
WarmupCosineLrScheduler
Warmup with cosine annealing scheduler.
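
Many of the schedulers above are defined by a one-line rule; `StepLrScheduler`, for instance, multiplies the learning rate by a factor every `step_size` epochs. The standalone sketch below reproduces that arithmetic in plain Rust without touching the crate's API (the struct's real constructor and method names are on its item page):

```rust
// Step decay as documented for `StepLrScheduler`: multiply the learning
// rate by `gamma` once every `step_size` epochs.
fn step_lr(initial_lr: f64, gamma: f64, step_size: usize, epoch: usize) -> f64 {
    initial_lr * gamma.powi((epoch / step_size) as i32)
}

fn main() {
    // lr0 = 0.1, gamma = 0.5, step_size = 10:
    // epochs 0..=9 -> 0.1, 10..=19 -> 0.05, 20..=29 -> 0.025, ...
    for epoch in [0, 9, 10, 19, 20, 30] {
        println!("epoch {:>2}: lr = {}", epoch, step_lr(0.1, 0.5, 10, epoch));
    }
}
```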

Enums§

AcquisitionFunction
Acquisition function type for Bayesian Optimization.
BitWidth
Bit-width for quantization.
CheckpointCompression
Compression method for checkpoints.
CheckpointStrategy
Gradient checkpointing strategy.
CyclicLrMode
Cyclic learning rate mode.
DistanceMetric
Distance metric for few-shot learning.
GcStrategy
Gradient centralization strategy.
GpKernel
Gaussian Process kernel for Bayesian Optimization.
GradClipMode
Gradient clipping mode.
GradientScalingStrategy
Gradient scaling strategy for accumulation.
Granularity
Quantization granularity (per-tensor or per-channel).
HyperparamSpace
Hyperparameter space definition.
HyperparamValue
Hyperparameter value type.
LossScaler
Loss scaling strategy for mixed precision training.
MiningStrategy
Strategy for mining hard examples.
PlateauMode
Mode for ReduceLROnPlateau scheduler.
PrecisionMode
Precision mode for mixed precision training.
PreprocessingMethod
Preprocessing method.
QuantizationMode
Quantization mode determines the quantization strategy.
ReweightingStrategy
Strategy for reweighting samples.
ShotType
Type of shot configuration for few-shot learning.
SophiaVariant
Variant of the Sophia optimizer to use.
SoupRecipe
Recipe for creating model soups.
StructuredPruningAxis
Axis for structured pruning.
TaskWeightingStrategy
Strategy for weighting multiple tasks.
TrainError
Errors that can occur during training.
VotingMode
Voting mode for classification ensembles.

Traits§

AutodiffModel
Trait for models that support automatic differentiation via scirs2-autograd.
Callback
Trait for training callbacks.
CrossValidationSplit
Trait for cross-validation splitting strategies.
CurriculumStrategy
Trait for curriculum learning strategies.
DataAugmenter
Trait for data augmentation strategies.
DynamicModel
Trait for models with dynamic computation graphs.
Ensemble
Trait for ensemble methods.
LoggingBackend
Trait for logging backends.
Loss
Trait for loss functions.
LrScheduler
Trait for learning rate schedulers.
MetaLearner
Meta-learner trait for different meta-learning algorithms.
Metric
Trait for metrics.
Model
Trait for trainable models.
Optimizer
Trait for optimizers.
Pruner
Trait for pruning strategies.
Regularizer
Trait for regularization strategies.
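
These traits are the crate's extension points: custom losses, metrics, callbacks, and schedulers all plug in by implementing the matching trait. This index does not show the trait methods, so the sketch below only mirrors the accumulate-then-compute shape a `Metric` implementation would need; check the `Metric` item page for the actual signatures before writing an `impl`:

```rust
// The accumulate-then-compute pattern a custom `Metric` would follow.
// Standalone struct; the real `Metric` trait's method names are not
// shown on this page, so no `impl Metric for ...` is attempted here.
struct MeanAbsoluteError {
    sum: f64,
    count: usize,
}

impl MeanAbsoluteError {
    fn new() -> Self {
        Self { sum: 0.0, count: 0 }
    }

    /// Fold one batch of predictions/targets into the running total.
    fn update(&mut self, predictions: &[f64], targets: &[f64]) {
        for (p, t) in predictions.iter().zip(targets) {
            self.sum += (p - t).abs();
            self.count += 1;
        }
    }

    /// Metric value over everything seen so far.
    fn compute(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.sum / self.count as f64
        }
    }
}

fn main() {
    let mut mae = MeanAbsoluteError::new();
    mae.update(&[0.9, 0.2], &[1.0, 0.0]);
    assert!((mae.compute() - 0.15).abs() < 1e-12);
}
```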

Functions§

compare_models
Compare two models and report differences in parameters.
compute_gradient_stats
Compute gradient statistics for all layers in a gradient dictionary.
extract_batch
Extract batches from data arrays.
format_duration
Format a duration in seconds to a human-readable string.
print_gradient_report
Print a formatted report of gradient statistics.

Type Aliases§

HyperparamConfig
Hyperparameter configuration (a single point in parameter space).
PruningMask
Pruning mask indicating which weights are kept (1.0) or removed (0.0).
TrainResult
Result type for training operations.
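
Presumably `TrainResult<T>` abbreviates `Result<T, TrainError>` (the alias's exact definition is on its item page); if so, fallible training code composes with `?` in the usual way:

```rust
use tensorlogic_train::TrainResult;

// Hypothetical driver: assuming the alias expands to
// `Result<T, TrainError>`, fallible training steps propagate with `?`.
fn run_experiment() -> TrainResult<()> {
    // ... build a Trainer, call its fit entry point, `?` on each step ...
    Ok(())
}
```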