Crate tensorlogic_train

Training scaffolds: loss wiring, schedules, callbacks.

This crate provides comprehensive training infrastructure for Tensorlogic models (a minimal wiring sketch follows this list):

  • Loss functions (standard and logical constraint-based)
  • Optimizer wrappers around SciRS2
  • Training loops with callbacks
  • Batch management
  • Validation and metrics
  • Regularization techniques
  • Data augmentation
  • Logging and monitoring
  • Curriculum learning strategies
  • Transfer learning utilities
  • Hyperparameter optimization (grid search, random search)
  • Cross-validation utilities
  • Model ensembling
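
A typical wiring of these pieces: pick a loss and an optimizer, configure a Trainer, and attach callbacks. The types named below are all listed under Structs on this page, but the constructors and method signatures in this sketch are assumptions for illustration only; consult each item's page for the real API.

```rust
// Hypothetical wiring sketch: the types exist in this crate, but every
// constructor and method signature below is an assumption, not the real API.
use tensorlogic_train::{
    AdamOptimizer, EarlyStoppingCallback, MseLoss, Trainer, TrainerConfig,
};

fn sketch() {
    // Pick a loss and an optimizer (assumed Default impls).
    let loss = MseLoss::default();
    let optimizer = AdamOptimizer::default();

    // Configure and build the trainer (assumed constructor shape).
    let config = TrainerConfig::default();
    let mut trainer = Trainer::new(config, Box::new(loss), Box::new(optimizer));

    // Attach callbacks, e.g. early stopping on validation loss.
    trainer.add_callback(Box::new(EarlyStoppingCallback::default()));

    // A call like `trainer.fit(&train_data, &targets)` would then drive the
    // epoch/batch loop, invoking callbacks and any LrScheduler each step.
}
```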

Structs§

Accuracy
Accuracy metric for classification.
AdaBeliefOptimizer
AdaBelief optimizer (NeurIPS 2020).
AdaMaxOptimizer
AdaMax optimizer (variant of Adam with infinity norm).
AdagradOptimizer
Adagrad optimizer (Adaptive Gradient).
AdamOptimizer
Adam optimizer.
AdamWOptimizer
AdamW optimizer (Adam with decoupled weight decay).
AveragingEnsemble
Averaging ensemble for regression.
BCEWithLogitsLoss
Binary cross-entropy with logits loss (numerically stable).
BaggingHelper
Bagging (Bootstrap Aggregating) utilities.
BalancedAccuracy
Balanced accuracy metric. Average of per-class recall; useful for imbalanced datasets.
BatchCallback
Callback that logs batch progress.
BatchConfig
Configuration for batch processing.
BatchIterator
Iterator over batches of data.
CallbackList
List of callbacks to execute in order.
CheckpointCallback
Callback for model checkpointing.
CohensKappa
Cohen’s Kappa statistic. Measures inter-rater agreement, accounting for chance agreement. Ranges from -1 to +1, where 1 is perfect agreement, 0 is random chance.
CompetenceCurriculum
Competence-based curriculum: adapts to the model’s current competence level.
CompositeAugmenter
Composite augmenter that applies multiple augmentations sequentially.
CompositeRegularization
Composite regularization that combines multiple regularizers.
ConfusionMatrix
Confusion matrix for multi-class classification.
ConsoleLogger
Console logger that outputs to stdout.
ConstraintViolationLoss
Constraint violation loss - penalizes constraint violations.
ContrastiveLoss
Contrastive loss for metric learning. Used to learn embeddings where similar pairs are close and dissimilar pairs are far apart.
CosineAnnealingLrScheduler
Cosine annealing learning rate scheduler. Anneals learning rate using a cosine schedule.
CrossEntropyLoss
Cross-entropy loss for classification.
CrossValidationResults
Cross-validation result aggregator.
CurriculumManager
Manager for curriculum learning that tracks training progress.
CyclicLrScheduler
Cyclic learning rate scheduler.
DataShuffler
Data shuffler for randomizing training data.
DiceLoss
Dice loss for segmentation tasks.
DiscriminativeFineTuning
Discriminative fine-tuning: use different learning rates for different layers.
EarlyStoppingCallback
Callback for early stopping based on validation loss.
ElasticNetRegularization
Elastic Net regularization (combination of L1 and L2).
EpochCallback
Callback that logs training progress.
ExponentialCurriculum
Exponential curriculum: exponentially increase sample percentage.
ExponentialLrScheduler
Exponential learning rate scheduler. Decreases learning rate by a factor of gamma every epoch.
F1Score
F1 score metric for classification.
FeatureExtractorMode
Feature extraction mode: freeze entire feature extractor.
FileLogger
File logger that writes logs to a file.
FocalLoss
Focal loss for addressing class imbalance (formula after this list). Reference: Lin et al., “Focal Loss for Dense Object Detection”
GradientAccumulationCallback
Gradient Accumulation callback.
GradientMonitor
Gradient flow monitor for tracking gradient statistics during training.
GradientSummary
Summary of gradient statistics.
GridSearch
Grid search strategy for hyperparameter optimization.
HingeLoss
Hinge loss for maximum-margin classification (SVM-style).
HistogramCallback
Callback for tracking weight histograms during training.
HistogramStats
Weight histogram statistics for debugging and monitoring.
HuberLoss
Huber loss for robust regression.
HyperparamResult
Result of a hyperparameter evaluation.
KFold
K-fold cross-validation.
KLDivergenceLoss
Kullback-Leibler Divergence loss. Measures how one probability distribution diverges from a reference distribution.
L1Regularization
L1 regularization (Lasso).
L2Regularization
L2 regularization (Ridge / Weight Decay).
LambOptimizer
LAMB optimizer (Layer-wise Adaptive Moments optimizer for Batch training). Designed for large batch training, uses layer-wise adaptation.
LarsOptimizer
LARS optimizer (Layer-wise Adaptive Rate Scaling).
LayerFreezingConfig
Layer freezing configuration for transfer learning.
LearningRateFinder
Learning rate finder callback using the LR range test.
LeaveOneOut
Leave-one-out cross-validation.
LinearCurriculum
Linear curriculum: gradually increase the percentage of samples used.
LinearModel
A simple linear model for testing and demonstration.
LogicalLoss
Logical loss combining multiple objectives.
LookaheadOptimizer
Lookahead optimizer (wrapper that uses slow and fast weights).
LossConfig
Configuration for loss functions.
MatthewsCorrelationCoefficient
Matthews Correlation Coefficient (MCC) metric. Ranges from -1 to +1, where +1 is perfect prediction, 0 is random, -1 is total disagreement. Particularly useful for imbalanced datasets.
MetricTracker
Metric tracker for managing multiple metrics.
MetricsLogger
Metrics logger that aggregates and logs training metrics.
MixupAugmenter
Mixup augmentation.
ModelEMACallback
Model EMA (Exponential Moving Average) callback.
MseLoss
Mean squared error loss for regression.
MultiStepLrScheduler
Multi-step learning rate scheduler.
NAdamOptimizer
NAdam optimizer (Nesterov-accelerated Adam).
NoAugmentation
No augmentation (identity transformation).
NoamScheduler
Noam scheduler (Transformer learning rate schedule).
NoiseAugmenter
Gaussian noise augmentation.
OneCycleLrScheduler
One-cycle learning rate scheduler. Increases LR from initial to max, then decreases to min.
OptimizerConfig
Configuration for optimizers.
PerClassMetrics
Per-class metrics report.
PolynomialDecayLrScheduler
Polynomial decay learning rate scheduler.
Precision
Precision metric for classification.
ProfilingCallback
Callback for profiling training performance.
ProfilingStats
Performance profiling statistics.
ProgressiveUnfreezing
Progressive unfreezing strategy for transfer learning.
RAdamOptimizer
RAdam optimizer (Rectified Adam) with variance warmup (ICLR 2020).
RMSpropOptimizer
RMSprop optimizer (Root Mean Square Propagation).
RandomSearch
Random search strategy for hyperparameter optimization.
Recall
Recall metric for classification.
ReduceLROnPlateauScheduler
Reduce learning rate on plateau (metric-based adaptive scheduler).
ReduceLrOnPlateauCallback
Callback for learning rate reduction on plateau.
RocCurve
ROC curve and AUC computation utilities.
RotationAugmenter
Rotation augmentation (placeholder for future implementation).
RuleSatisfactionLoss
Rule satisfaction loss - measures how well rules are satisfied.
SWACallback
SWA (Stochastic Weight Averaging) callback.
SamOptimizer
SAM optimizer (Sharpness-Aware Minimization).
ScaleAugmenter
Scale augmentation.
SelfPacedCurriculum
Self-paced learning: the model determines its own learning pace.
SgdOptimizer
SGD optimizer with momentum.
StackingEnsemble
Stacking ensemble with a meta-learner.
StepLrScheduler
Step-based learning rate scheduler. Decreases learning rate by a factor every step_size epochs.
StratifiedKFold
Stratified K-fold cross-validation.
TaskCurriculum
Task-level curriculum for multi-task learning.
TensorBoardLogger
TensorBoard logger (placeholder for future implementation).
TimeSeriesSplit
Time series split for temporal data.
TopKAccuracy
Top-K accuracy metric. Measures whether the correct class is in the top K predictions.
Trainer
Main trainer for model training.
TrainerConfig
Configuration for training.
TrainingCheckpoint
Comprehensive checkpoint data structure.
TrainingHistory
Training history containing losses and metrics.
TrainingState
Training state passed to callbacks.
TransferLearningManager
Transfer learning strategy manager.
TripletLoss
Triplet loss for metric learning (formula after this list). Learns embeddings where anchor-positive distance < anchor-negative distance + margin.
TverskyLoss
Tversky loss (generalization of Dice loss). Useful for handling class imbalance in segmentation.
ValidationCallback
Callback for validation during training.
VotingEnsemble
Voting ensemble configuration.
WarmupCosineLrScheduler
Warmup with cosine annealing scheduler.
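
Two of the losses above have standard closed forms in the literature (the crate's exact parameterization may differ). The focal loss (Lin et al.) down-weights easy examples via a focusing parameter, and the triplet loss enforces a margin between anchor-positive and anchor-negative distances:

```latex
% Focal loss with class weight \alpha_t and focusing parameter \gamma,
% where p_t is the predicted probability of the true class:
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t)

% Triplet loss with margin m over anchor a, positive p, negative n,
% and embedding distance d(\cdot,\cdot):
\mathcal{L}_{\mathrm{triplet}} = \max\!\big( d(a, p) - d(a, n) + m,\; 0 \big)
```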

Enums§

CyclicLrMode
Cyclic learning rate mode.
GradClipMode
Gradient clipping mode.
HyperparamSpace
Hyperparameter space definition (see the grid-search sketch after this list).
HyperparamValue
Hyperparameter value type.
PlateauMode
Mode for ReduceLROnPlateau scheduler.
TrainError
Errors that can occur during training.
VotingMode
Voting mode for classification ensembles.
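
HyperparamSpace and HyperparamValue define the search space consumed by the GridSearch and RandomSearch structs listed earlier. Conceptually, grid search enumerates the cartesian product of the axes and keeps the best-scoring configuration, as in this self-contained sketch (the closure stands in for a real train-and-validate run; none of this is the crate's API):

```rust
// Conceptual grid search: enumerate a cartesian product of hyperparameters
// and keep the configuration with the best score.
fn main() {
    let learning_rates = [1e-3, 1e-2, 1e-1];
    let batch_sizes = [16usize, 32, 64];

    // Stand-in objective: real use would train a model and return a
    // validation metric for the given configuration.
    let evaluate =
        |lr: f64, bs: usize| -> f64 { -(lr - 1e-2).abs() - (bs as f64 - 32.0).abs() / 100.0 };

    let mut best: Option<(f64, usize, f64)> = None;
    for &lr in &learning_rates {
        for &bs in &batch_sizes {
            let score = evaluate(lr, bs);
            if best.map_or(true, |(_, _, s)| score > s) {
                best = Some((lr, bs, score));
            }
        }
    }
    if let Some((lr, bs, score)) = best {
        println!("best: lr={lr}, batch_size={bs}, score={score:.4}");
    }
}
```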

Traits§

AutodiffModel
Trait for models that support automatic differentiation via scirs2-autograd.
Callback
Trait for training callbacks.
CrossValidationSplit
Trait for cross-validation splitting strategies.
CurriculumStrategy
Trait for curriculum learning strategies.
DataAugmenter
Trait for data augmentation strategies.
DynamicModel
Trait for models with dynamic computation graphs.
Ensemble
Trait for ensemble methods.
LoggingBackend
Trait for logging backends.
Loss
Trait for loss functions.
LrScheduler
Trait for learning rate schedulers.
Metric
Trait for metrics (see the sketch after this list).
Model
Trait for trainable models.
Optimizer
Trait for optimizers.
Regularizer
Trait for regularization strategies.
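
Most of these traits are extension points: implementing Metric, Callback, Loss, or DataAugmenter lets user code plug into the Trainer. The trait signatures are not reproduced on this index page, so the sketch below uses an assumed, simplified streaming-metric interface to show the pattern; see the Metric item page for the real methods.

```rust
// Illustrative stand-in for a streaming metric trait; the crate's actual
// `Metric` signature may differ.
trait StreamingMetric {
    fn update(&mut self, prediction: f64, target: f64);
    fn compute(&self) -> f64;
    fn reset(&mut self);
}

/// Accuracy over thresholded binary predictions.
struct BinaryAccuracy {
    correct: usize,
    total: usize,
}

impl StreamingMetric for BinaryAccuracy {
    fn update(&mut self, prediction: f64, target: f64) {
        let predicted = if prediction >= 0.5 { 1.0 } else { 0.0 };
        if (predicted - target).abs() < f64::EPSILON {
            self.correct += 1;
        }
        self.total += 1;
    }

    fn compute(&self) -> f64 {
        if self.total == 0 {
            0.0
        } else {
            self.correct as f64 / self.total as f64
        }
    }

    fn reset(&mut self) {
        self.correct = 0;
        self.total = 0;
    }
}

fn main() {
    let mut acc = BinaryAccuracy { correct: 0, total: 0 };
    for (p, t) in [(0.9, 1.0), (0.2, 0.0), (0.7, 0.0)] {
        acc.update(p, t);
    }
    println!("accuracy = {:.3}", acc.compute()); // 2 of 3 correct -> 0.667
}
```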

Functions§

extract_batch
Extract batches from data arrays.
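
The sketch below shows the kind of row-wise slicing a batch-extraction helper performs, written against ndarray for concreteness. The function name and signature here are illustrative assumptions, not extract_batch's real signature.

```rust
// Row-wise batch extraction over an ndarray matrix (illustrative only).
use ndarray::{s, Array2, ArrayView2};

fn extract_batch_rows(data: &Array2<f64>, batch_idx: usize, batch_size: usize) -> ArrayView2<'_, f64> {
    let start = batch_idx * batch_size;
    let end = (start + batch_size).min(data.nrows());
    data.slice(s![start..end, ..])
}

fn main() {
    let data = Array2::<f64>::zeros((10, 3)); // 10 samples, 3 features
    let batch = extract_batch_rows(&data, 2, 4); // third batch -> rows 8..10
    assert_eq!(batch.nrows(), 2); // final batch may be short
}
```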

Type Aliases§

HyperparamConfig
Hyperparameter configuration (a single point in parameter space).
TrainResult
Result type for training operations.