Model definitions, losses, checkpoints, and training helpers for yscv.
Modules
- tcp_transport - TCP-based transport for multi-node gradient exchange.
Structs
- AdaptiveAvgPool2dLayer - Adaptive average pooling: output a fixed spatial size regardless of input size.
- AdaptiveMaxPool2dLayer - Adaptive max pooling: output a fixed spatial size.
- AllReduceAggregator - All-reduce aggregator: averages gradients across all workers via ring reduce.
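What the aggregator computes can be sketched without any networking: an all-reduce over worker gradients leaves every worker holding the elementwise mean. A ring implementation pipelines this in chunked reduce-scatter/all-gather steps, but the result is the same. The sketch below is illustrative only, not the crate's `AllReduceAggregator` API.

```rust
// Standalone sketch: every worker's gradient vector is replaced by the
// elementwise mean across all workers (the net effect of an all-reduce).
fn all_reduce_mean(workers: &mut [Vec<f32>]) {
    let n = workers.len() as f32;
    let len = workers[0].len();
    // "reduce": elementwise sum across workers
    let mut sum = vec![0.0f32; len];
    for w in workers.iter() {
        for (s, v) in sum.iter_mut().zip(w) {
            *s += v;
        }
    }
    // "broadcast": every worker receives the averaged gradient
    for w in workers.iter_mut() {
        for (x, s) in w.iter_mut().zip(&sum) {
            *x = s / n;
        }
    }
}

fn main() {
    let mut grads = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
    all_reduce_mean(&mut grads);
    assert_eq!(grads[0], vec![3.0, 4.0]); // mean of (1,3,5) and (2,4,6)
    assert_eq!(grads[1], grads[2]);       // all workers agree afterwards
}
```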
- AnchorFreeHead - Anchor-free detection head (FCOS-style, inference-mode, NHWC).
- ArchitectureConfig - Describes the shape of a model architecture.
- AvgPool2dLayer - 2D average-pooling layer (NHWC layout).
- Batch - One deterministic batch from a dataset iterator.
- BatchCollector - Collects individual samples into batches for efficient processing.
- BatchIterOptions - Controls mini-batch order, truncation behavior, and optional per-batch regularization.
- BatchNorm2dLayer - 2D batch normalization layer (NHWC layout).
- BestModelCheckpoint - Saves model weights when the monitored metric improves.
- CenterCrop - Crop the center region of an image tensor. Input shape: [H, W, C]. Output shape: [size, size, C].
- CnnTrainConfig - Configuration for high-level CNN training.
- Compose - Chains multiple transforms sequentially.
- CompressedGradient - Compressed gradient: stores only the top-k elements by magnitude.
- Conv1dLayer - 1D convolution layer (NLC layout: [batch, length, channels]).
- Conv2dLayer - 2D convolution layer (NHWC layout).
- Conv3dLayer - 3D convolution layer (BDHWC layout).
- ConvTranspose2dLayer - Transposed 2D convolution layer (NHWC layout).
- CrossAttention - Cross-attention: query from decoder, key/value from encoder output.
- CutMixConfig - Controls per-batch region-replacement interpolation for image tensors.
- DataLoader - Parallel data loader that prefetches batches using worker threads.
- DataLoaderBatch - A batch of samples produced by the data loader.
- DataLoaderConfig - Configuration for the parallel data loader.
- DataLoaderIter - Iterator over batches produced by worker threads.
- DataParallelConfig - Configuration for data-parallel distributed training.
- DatasetSplit - Deterministic dataset split produced by split_by_counts / split_by_ratio.
- DeformableConv2dLayer - Deformable 2D convolution layer (NHWC layout).
- DepthwiseConv2dLayer - Depthwise 2D convolution layer (NHWC layout).
- DistributedConfig - Identifies a worker inside a distributed training group.
- DropoutLayer - Dropout layer (training vs eval mode).
- DynamicBatchConfig - Dynamic batching configuration for inference.
- DynamicLossScaler - State for dynamic loss scaling during mixed-precision training.
- EarlyStopping - Early stopping to halt training when a metric stops improving.
- EmbeddingLayer - Embedding lookup table: maps integer indices to dense vectors.
- EpochMetrics - Metrics for one training epoch.
- EpochTrainOptions - Epoch-level training controls for batch order and preprocessing.
- ExponentialMovingAverage - Exponential moving average of model parameters.
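The standard EMA update rule, sketched standalone. The decay-weighted form below (`shadow <- decay * shadow + (1 - decay) * param`) is the conventional definition; the struct name `Ema` and its methods here are illustrative, not the crate's `ExponentialMovingAverage` API.

```rust
// Minimal EMA-of-parameters sketch (assumed conventional update rule).
struct Ema {
    decay: f32,
    shadow: Vec<f32>,
}

impl Ema {
    fn new(params: &[f32], decay: f32) -> Self {
        Ema { decay, shadow: params.to_vec() }
    }

    // shadow <- decay * shadow + (1 - decay) * param, per element
    fn update(&mut self, params: &[f32]) {
        for (s, p) in self.shadow.iter_mut().zip(params) {
            *s = self.decay * *s + (1.0 - self.decay) * p;
        }
    }
}

fn main() {
    let mut ema = Ema::new(&[0.0], 0.9);
    ema.update(&[1.0]); // 0.9 * 0.0 + 0.1 * 1.0 = 0.1
    assert!((ema.shadow[0] - 0.1).abs() < 1e-6);
}
```

A high decay (e.g. 0.999) makes the shadow weights a slow-moving average that is often evaluated instead of the raw parameters.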
- FeedForward - Feed-forward network: Linear(d_model, d_ff) -> ReLU -> Linear(d_ff, d_model).
- FeedForwardLayer - Feed-forward layer wrapping FeedForward.
- FlattenLayer - Flatten layer: reshapes NHWC [N, H, W, C] to [N, H*W*C] for dense-layer input.
- FpnNeck - Feature Pyramid Network lateral + top-down pathway (inference-mode, NHWC).
- GELULayer - GELU activation layer.
- GaussianBlur - Apply Gaussian blur to an image tensor. Input shape: [H, W, C].
- GlobalAvgPool2dLayer - Global average pooling: NHWC [N, H, W, C] -> [N, 1, 1, C].
- GroupNormLayer - Group normalization: divides channels into groups and normalizes within each group.
- GruCell - GRU cell: update and reset gates.
- GruLayer - GRU layer wrapping gru_forward_sequence.
- HubEntry - Registry entry for a pretrained model.
- ImageAugmentationPipeline - Ordered per-sample augmentation pipeline for NHWC mini-batch data.
- InProcessTransport - In-process transport backed by mpsc channels (for testing).
- InferencePipeline - Builder-style inference pipeline that wraps a SequentialModel with optional pre- and post-processing closures.
- InstanceNormLayer - Instance normalization (normalizes per-sample, per-channel).
- LayerNormLayer - Layer normalization over the last dimension.
- LeakyReLULayer - Stateless LeakyReLU layer with configurable negative slope.
- LinearLayer - Dense linear layer: y = x @ weight + bias.
- LocalAggregator - No-op aggregator for single-machine training (API uniformity).
- LoraConfig - LoRA configuration.
- LoraLinear - A LoRA adapter for a linear layer.
- LrFinderConfig - Configuration for an LR range test.
- LrFinderResult - Result of an LR range test.
- LstmCell - LSTM cell: standard gates (input, forget, cell, output).
- LstmLayer - LSTM layer wrapping lstm_forward_sequence.
- MaskHead - Mask prediction head for instance segmentation (Mask R-CNN style).
- MaxPool2dLayer - 2D max-pooling layer (NHWC layout).
- MbConvBlock - MBConv block (EfficientNet / MobileNetV2 inverted residual, inference-mode).
- MetricsLogger - Logs training metrics to a CSV file and prints a summary line to stdout.
- MiniBatchIter - Deterministic sequential mini-batch iterator.
- MishLayer - Mish activation layer.
- MixUpConfig - Controls per-batch sample/label interpolation for regularized training.
- MixedPrecisionConfig - Mixed-precision training configuration.
- ModelHub - Model hub for downloading and caching pretrained weights.
- ModelZoo - File-based pretrained model registry.
- MultiHeadAttention - Multi-head attention weights.
- MultiHeadAttentionConfig - Multi-head attention configuration.
- MultiHeadAttentionLayer - Multi-head attention layer wrapping MultiHeadAttention.
- Normalize - Normalize channels: (x - mean) / std.
- PReLULayer - PReLU activation layer. Uses per-channel or single alpha for the negative slope.
- ParameterServer - Centralized parameter server: rank 0 collects, averages, and broadcasts gradients (or parameters).
- PatchEmbedding - Patch embedding layer for Vision Transformer.
- PerChannelQuantResult - Per-channel symmetric quantization for conv weights [KH, KW, C_in, C_out].
- PermuteDims - Permute dimensions.
- PipelineParallelConfig - Configuration for pipeline-parallel training.
- PipelineStage - Pipeline parallelism: split a sequential model across multiple stages.
- PixelShuffleLayer - Pixel shuffle / sub-pixel convolution: rearranges [N, H, W, C*r^2] -> [N, H*r, W*r, C].
- PrunedTensor - Result of magnitude-based weight pruning.
- QuantizedTensor - Quantized tensor representation: INT8 values + per-tensor scale + zero-point.
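The INT8 values + scale + zero-point representation can be sketched end to end. The affine (asymmetric) mapping below is one common convention and an assumption here; the crate may use a symmetric variant, and the struct shown is illustrative, not the crate's `QuantizedTensor`.

```rust
// Hypothetical per-tensor affine INT8 quantization sketch:
// real = (int8 - zero_point) * scale.
struct QuantizedTensor {
    values: Vec<i8>,
    scale: f32,
    zero_point: i32,
}

fn quantize(data: &[f32]) -> QuantizedTensor {
    // range must include 0 so that 0.0 is exactly representable
    let min = data.iter().cloned().fold(f32::INFINITY, f32::min).min(0.0);
    let max = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max).max(0.0);
    let scale = (max - min) / 255.0;
    let zero_point = (-128.0 - min / scale).round() as i32;
    let values = data
        .iter()
        .map(|&x| ((x / scale).round() as i32 + zero_point).clamp(-128, 127) as i8)
        .collect();
    QuantizedTensor { values, scale, zero_point }
}

fn dequantize(q: &QuantizedTensor) -> Vec<f32> {
    q.values.iter().map(|&v| (v as i32 - q.zero_point) as f32 * q.scale).collect()
}

fn main() {
    let data = [0.0, 0.5, 1.0, -0.25];
    let q = quantize(&data);
    let back = dequantize(&q);
    for (a, b) in data.iter().zip(&back) {
        // round-trip error is bounded by one quantization step
        assert!((a - b).abs() <= q.scale);
    }
}
```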
- RandomHorizontalFlip - Randomly flip horizontally with probability p. Uses a xorshift64 PRNG seeded at construction. Input shape: [H, W, C].
- RandomSampler - A sampler that yields indices in a random (deterministic) order.
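The "seeded xorshift64, hence deterministic" contract can be illustrated with the classic Marsaglia xorshift64 step. The 13/7/17 shift triple is the textbook variant; the crate's exact shift constants and seeding are an assumption.

```rust
// Standard xorshift64 step (13/7/17 shifts); same seed -> same sequence.
struct XorShift64 {
    state: u64,
}

impl XorShift64 {
    fn next(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    // Map the top 53 bits to a uniform f64 in [0, 1), e.g. to compare
    // against a flip probability p.
    fn next_f64(&mut self) -> f64 {
        (self.next() >> 11) as f64 / (1u64 << 53) as f64
    }
}

fn main() {
    let mut a = XorShift64 { state: 42 };
    let mut b = XorShift64 { state: 42 };
    let u = a.next_f64();
    assert!((0.0..1.0).contains(&u));
    let _ = b.next_f64();
    assert_eq!(a.next(), b.next()); // identical seeds replay identically
}
```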
- ReLULayer - Stateless ReLU layer.
- ResidualBlock - Residual block: runs input through a sequence of layers, then adds the original input as a skip connection (output = layers(input) + input).
- Resize - Resize an image tensor to a target height and width using bilinear interpolation. Input shape: [H, W, C].
- RnnCell - Vanilla RNN cell: h_t = tanh(x_t @ W_ih + h_{t-1} @ W_hh + b).
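The recurrence above written out directly, with a naive mat-vec over plain slices. Shapes and names are illustrative; this is not the crate's `RnnCell` signature.

```rust
// v @ m, where m is laid out [in, out]
fn matvec(m: &[Vec<f32>], v: &[f32]) -> Vec<f32> {
    let out = m[0].len();
    (0..out)
        .map(|j| v.iter().zip(m).map(|(x, row)| x * row[j]).sum::<f32>())
        .collect()
}

// h_t = tanh(x_t @ W_ih + h_{t-1} @ W_hh + b)
fn rnn_cell_step(
    x: &[f32],
    h: &[f32],
    w_ih: &[Vec<f32>],
    w_hh: &[Vec<f32>],
    b: &[f32],
) -> Vec<f32> {
    let xi = matvec(w_ih, x);
    let hh = matvec(w_hh, h);
    xi.iter()
        .zip(&hh)
        .zip(b)
        .map(|((a, c), bb)| (a + c + bb).tanh())
        .collect()
}

fn main() {
    // 1-dim input/hidden with identity weights: h1 = tanh(x + h0)
    let h1 = rnn_cell_step(&[0.5], &[0.25], &[vec![1.0]], &[vec![1.0]], &[0.0]);
    assert!((h1[0] - 0.75f32.tanh()).abs() < 1e-6);
}
```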
- RnnLayer - RNN layer wrapping rnn_forward_sequence.
- SafeTensorFile - A parsed SafeTensors file backed by an in-memory byte buffer.
- ScaleValues - Scale f32 values by a constant factor.
- ScheduledEpochMetrics - Metrics for one scheduler-driven epoch.
- SchedulerTrainOptions - Scheduler-driven epoch training controls.
- SeparableConv2dLayer - Separable 2D convolution layer (NHWC layout).
- SequentialCheckpoint - Serializable sequential model checkpoint.
- SequentialModel - Ordered stack of layers executed one-by-one.
- SequentialSampler - A sampler that yields indices in sequential order.
- SiLULayer - SiLU (Swish) activation layer.
- SigmoidLayer - Stateless sigmoid activation layer.
- SoftmaxLayer - Softmax layer over the last dimension.
- SqueezeExciteBlock - Squeeze-and-Excitation block (inference-mode).
- StreamingDataLoader - A data loader that lazily reads batches from disk, using a background thread to prefetch the next batch while the current batch is being processed.
- SupervisedCsvConfig - Configuration for parsing/loading supervised CSV datasets.
- SupervisedDataset - Supervised dataset with aligned input/target sample axis at position 0.
- SupervisedImageFolderConfig
- SupervisedImageFolderLoadResult - Result payload for image-folder dataset loading with explicit class mapping.
- SupervisedImageManifestConfig - Configuration for parsing/loading supervised image-manifest CSV datasets.
- SupervisedJsonlConfig - Configuration for parsing/loading supervised JSONL datasets.
- TanhLayer - Stateless tanh activation layer.
- TcpAllReduceAggregator - Wrapper that uses a TcpTransport for gradient aggregation.
- TcpTransport - TCP-based transport for multi-node gradient exchange.
- TensorBoardCallback - Training callback that logs scalar metrics to TensorBoard event files.
- TensorBoardWriter - Writes TensorBoard-compatible event files in TFRecord format.
- TensorInfo - Per-tensor metadata extracted from the SafeTensors JSON header.
- TensorSnapshot - Serializable tensor snapshot used in model checkpoints.
- TopKCompressor - Top-K gradient compressor: keeps only the top ratio fraction of gradients.
- TrainResult - Training result returned after fitting.
- Trainer - High-level trainer that wraps optimizer + loss + callbacks configuration.
- TrainerConfig - High-level training configuration.
- TrainingLog - Records per-epoch training metrics.
- TransformerDecoder - Stack of TransformerDecoderBlock layers.
- TransformerDecoderBlock - Single transformer decoder block: masked self-attention → cross-attention → FFN, each sub-layer wrapped with a residual connection and layer normalization.
- TransformerEncoderBlock - Transformer encoder block: MHA -> Add&Norm -> FFN -> Add&Norm.
- TransformerEncoderLayer - Transformer encoder layer wrapping TransformerEncoderBlock.
- UNetDecoderStage - UNet decoder stage (inference-mode, NHWC).
- UNetEncoderStage - UNet encoder stage (inference-mode, NHWC).
- UpsampleLayer - Upsample layer: nearest or bilinear upsampling.
- VisionTransformer - Vision Transformer (ViT) for image classification (inference-mode).
- WeightedRandomSampler - Weighted random sampler: draws num_samples indices with probability proportional to weights.
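"Probability proportional to weights" is typically implemented with a cumulative-sum table and a uniform draw, sketched below. The uniform source is left abstract here; determinism would come from a seeded PRNG, as with the other samplers. This is an illustration, not the crate's `WeightedRandomSampler` API.

```rust
// Map a uniform sample u in [0, 1) to an index drawn with probability
// proportional to `weights`, via the cumulative distribution.
fn weighted_index(weights: &[f32], u: f32) -> usize {
    let total: f32 = weights.iter().sum();
    let target = u * total;
    let mut acc = 0.0;
    for (i, w) in weights.iter().enumerate() {
        acc += w;
        if target < acc {
            return i;
        }
    }
    weights.len() - 1 // guard against rounding when u is very close to 1.0
}

fn main() {
    let w = [1.0, 3.0]; // index 1 is three times as likely as index 0
    assert_eq!(weighted_index(&w, 0.0), 0);
    assert_eq!(weighted_index(&w, 0.2), 0); // 0.2 * 4 = 0.8 < 1.0
    assert_eq!(weighted_index(&w, 0.25), 1); // 1.0 falls in [1.0, 4.0)
    assert_eq!(weighted_index(&w, 0.9), 1);
}
```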
Enums
- ImageAugmentationOp - Per-sample image augmentations for rank-4 NHWC training tensors.
- ImageFolderTargetMode - Configuration for loading supervised image-folder datasets.
- LayerCheckpoint - Serializable layer checkpoint payload.
- LossKind - Which loss function to use.
- ModelArchitecture - Known model architectures in the zoo.
- ModelError - Errors returned by model-layer assembly, checkpoints, and training helpers.
- ModelLayer
- MonitorMode - Mode for metric monitoring.
- NodeRole - Describes whether this node is the coordinator (rank 0) or a worker.
- OptimizerKind - Which optimizer to use.
- OptimizerType - Multi-epoch CNN training with configurable optimizer type.
- QuantMode - Quantization mode.
- SafeTensorDType - Supported element types in a SafeTensors file.
- SamplingPolicy - Sample-order policy used by BatchIterOptions.
- SupervisedLoss - Configures the supervised loss function used by train-step and train-epoch helpers.
Constants
Traits
- GradientAggregator - Strategy for combining gradients across distributed workers.
- TrainingCallback - Trait for training callbacks invoked after each epoch.
- Transform - Trait for deterministic tensor transforms (preprocessing).
- Transport - Byte-level communication primitive used by aggregation strategies.
Functions
- accumulate_gradients - Adds source gradients into the existing gradients of the given nodes.
- adam_state_from_map - Restore Adam/AdamW state from a string-keyed map.
- adam_state_to_map - Flatten Adam/AdamW state into a string-keyed map for serialization.
- add_bottleneck_block - Adds a MobileNetV2-style inverted bottleneck block to a SequentialModel.
- add_residual_block - Adds a ResNet-style residual block to a SequentialModel (inference-mode).
- apply_pruning_mask - Apply a binary mask to weights (element-wise multiply).
- batched_inference - Splits a large input into batches, runs inference, and reassembles.
- bce_loss - Binary cross-entropy loss for predictions already passed through sigmoid. bce = -mean(target * log(pred) + (1 - target) * log(1 - pred)).
- bilstm_forward_sequence - Bidirectional LSTM: runs forward and backward LSTMs, concatenates outputs.
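The bce_loss formula above, written out directly over f32 slices (a standalone sketch; the crate's function operates on its own tensor type).

```rust
// bce = -mean(target * ln(pred) + (1 - target) * ln(1 - pred)),
// assuming pred has already been squashed through sigmoid into (0, 1).
fn bce_loss(pred: &[f32], target: &[f32]) -> f32 {
    let n = pred.len() as f32;
    let sum: f32 = pred
        .iter()
        .zip(target)
        .map(|(&p, &t)| t * p.ln() + (1.0 - t) * (1.0 - p).ln())
        .sum();
    -sum / n
}

fn main() {
    // confident correct predictions score lower than confident wrong ones
    let good = bce_loss(&[0.9, 0.1], &[1.0, 0.0]);
    let bad = bce_loss(&[0.1, 0.9], &[1.0, 0.0]);
    assert!(good < bad);
    // uninformative pred = 0.5 everywhere gives exactly ln(2)
    let l = bce_loss(&[0.5, 0.5], &[1.0, 0.0]);
    assert!((l - 2f32.ln()).abs() < 1e-6);
}
```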
- build_alexnet - Builds a simple AlexNet-style conv stack.
- build_classifier - Builds a full classifier with a custom number of output classes.
- build_feature_extractor - Builds a backbone (feature extractor) without the final classifier head.
- build_mobilenet_v2 - Builds a MobileNetV2-style model using inverted bottleneck blocks.
- build_resnet - Builds a ResNet-family model: stem + residual stages + global-avg-pool + linear head.
- build_resnet_custom - Builds a ResNet with per-stage block counts (bypasses the single-count helper).
- build_resnet_feature_extractor - Builds a ResNet-like feature extractor (no final classifier).
- build_simple_cnn_classifier - Builds a simple CNN classifier architecture for NHWC input.
- build_vgg - Builds a VGG-style sequential conv network.
- cast_params_for_forward - Convert model parameters from master precision to forward precision.
- cast_to_master - Cast a list of tensors back to the master dtype for gradient accumulation.
- checkpoint_from_json
- checkpoint_to_json
- collect_gradients - Collects the current gradients for a set of nodes as owned tensors.
- compress_gradients - Compress gradients by keeping only the top-k% of elements by magnitude.
- constant - Fill a tensor with a constant value.
- contrastive_loss - Contrastive loss for siamese networks.
- cosine_embedding_loss - Cosine embedding loss.
- cross_entropy_loss - Cross-entropy loss from raw logits. Computes nll_loss(log_softmax(logits), targets).
- ctc_loss - CTC (Connectionist Temporal Classification) loss.
- decompress_gradients - Decompress gradients back to full tensors.
- default_cache_dir - Returns the default cache directory for downloaded model weights.
- dequantize_weights - Dequantize a set of quantized weights back to f32 tensors.
- dice_loss - Dice loss for segmentation.
- distillation_loss - Knowledge distillation loss (Hinton et al., 2015).
- distributed_train_step - Performs a single distributed training step: forward, backward, aggregate, update.
- export_sequential_to_onnx - Exports a SequentialModel to an ONNX protobuf byte vector.
- export_sequential_to_onnx_file - Exports a SequentialModel to an ONNX file.
- focal_loss - Focal loss for imbalanced classification.
- fuse_conv_bn - Fuse Conv2d + BatchNorm2d into a single Conv2d with adjusted weights and bias.
- gather_shards - Reassemble shards (produced by shard_tensor) back into a single tensor.
- generate_causal_mask - Generates a causal (lower-triangular) attention mask. Returns a [seq_len, seq_len] tensor.
- generate_padding_mask - Generates a padding mask for batched sequences with different lengths. lengths: actual length of each sequence in the batch; max_len: maximum sequence length (pad length). Returns a [batch, max_len] tensor.
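The two mask shapes can be sketched with nested vectors. The fill convention below (1.0 = attend, 0.0 = masked) is one common choice and an assumption here; the crate's actual fill values are not stated in this listing.

```rust
// Lower-triangular causal mask: position i may attend to positions j <= i.
fn causal_mask(seq_len: usize) -> Vec<Vec<f32>> {
    (0..seq_len)
        .map(|i| (0..seq_len).map(|j| if j <= i { 1.0 } else { 0.0 }).collect())
        .collect()
}

// Padding mask: slot t of sequence b is valid iff t < lengths[b].
fn padding_mask(lengths: &[usize], max_len: usize) -> Vec<Vec<f32>> {
    lengths
        .iter()
        .map(|&l| (0..max_len).map(|t| if t < l { 1.0 } else { 0.0 }).collect())
        .collect()
}

fn main() {
    let c = causal_mask(3);
    assert_eq!(c[0], vec![1.0, 0.0, 0.0]); // position 0 sees only itself
    assert_eq!(c[2], vec![1.0, 1.0, 1.0]); // last position sees everything
    let p = padding_mask(&[2, 3], 3);
    assert_eq!(p[0], vec![1.0, 1.0, 0.0]); // length-2 sequence, one pad slot
    assert_eq!(p[1], vec![1.0, 1.0, 1.0]); // full-length sequence, no padding
}
```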
- gru_forward_sequence - Runs a GRU cell over a sequence [batch, seq_len, input_size].
- hinge_loss - Mean hinge loss: mean(max(0, margin - prediction * target)).
- huber_loss - Mean Huber loss: mean(0.5 * min(|e|, delta)^2 + delta * max(|e| - delta, 0)), where e = prediction - target.
- infer_batch - Batch inference on a SequentialModel (tensor mode, no autograd graph).
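The hinge_loss and huber_loss formulas above, written out directly over f32 slices (standalone sketches, not the crate's tensor signatures).

```rust
// mean(max(0, margin - prediction * target)), targets in {-1, +1}
fn hinge_loss(pred: &[f32], target: &[f32], margin: f32) -> f32 {
    let n = pred.len() as f32;
    pred.iter()
        .zip(target)
        .map(|(&p, &t)| (margin - p * t).max(0.0))
        .sum::<f32>()
        / n
}

// mean(0.5 * min(|e|, delta)^2 + delta * max(|e| - delta, 0)),
// where e = prediction - target: quadratic inside delta, linear outside.
fn huber_loss(pred: &[f32], target: &[f32], delta: f32) -> f32 {
    let n = pred.len() as f32;
    pred.iter()
        .zip(target)
        .map(|(&p, &t)| {
            let e = (p - t).abs();
            0.5 * e.min(delta).powi(2) + delta * (e - delta).max(0.0)
        })
        .sum::<f32>()
        / n
}

fn main() {
    // hinge: predictions on the correct side of the margin contribute 0
    assert_eq!(hinge_loss(&[2.0], &[1.0], 1.0), 0.0);
    assert_eq!(hinge_loss(&[0.5], &[1.0], 1.0), 0.5);
    // huber: at e = delta both branches meet; beyond it, growth is linear
    assert!((huber_loss(&[1.0], &[0.0], 1.0) - 0.5).abs() < 1e-6);
    assert!((huber_loss(&[3.0], &[0.0], 1.0) - 2.5).abs() < 1e-6); // 0.5 + 1*(3-1)
}
```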
- infer_batch_graph - Runs inference through the autograd graph and returns the output tensor value.
- inspect_weights - Lists tensor names and shapes from a weight file without loading data.
- kaiming_normal - Kaiming (He) normal initialization.
- kaiming_uniform - Kaiming (He) uniform initialization.
- kl_div_loss - KL divergence loss.
- label_smoothing_cross_entropy - Cross-entropy with label smoothing.
- load_state_dict - Load all tensors from a SafeTensors file into a name-to-tensor map.
- load_supervised_dataset_csv_file - Loads supervised training samples from a CSV file.
- load_supervised_dataset_jsonl_file - Loads supervised training samples from a JSONL file.
- load_supervised_image_folder_dataset - Loads supervised training samples from an image-folder classification tree.
- load_supervised_image_folder_dataset_with_classes - Loads supervised training samples from an image-folder classification tree and returns the class mapping.
- load_supervised_image_manifest_csv_file - Loads a supervised training image-manifest CSV from a file.
- load_training_checkpoint - Load a full training checkpoint, splitting model weights from optimizer state.
- load_weights - Loads named tensors from a binary weight file.
- loopback_pair - Create a loopback TCP transport pair for testing.
- lr_range_test - Run an LR range test.
- lstm_forward_sequence - Runs an LSTM cell over a sequence [batch, seq_len, input_size].
- mae_loss - Mean absolute error loss: mean(abs(prediction - target)).
- mixed_precision_train_step - Runs a mixed-precision forward+backward step.
- mse_loss - Mean squared error loss: mean((prediction - target)^2).
- nll_loss - Negative log-likelihood loss from log-probabilities. Expects log_probs shape [batch, classes] and targets shape [batch, 1] where targets contain class indices as f32.
- optimize_sequential - Scan a SequentialModel and fuse Conv2d + BatchNorm2d patterns.
- orthogonal - Orthogonal initialization via QR decomposition (simplified Gram-Schmidt).
- parse_supervised_dataset_csv - Parses supervised training samples from CSV text into a SupervisedDataset.
- parse_supervised_dataset_jsonl - Parses supervised training samples from JSONL text into a SupervisedDataset.
- parse_supervised_image_manifest_csv - Parses a supervised training image-manifest CSV into a SupervisedDataset.
- prune_magnitude - Prune weights by magnitude: zero out the smallest sparsity fraction.
- quantize_per_channel
- quantize_weights - Quantize all weight tensors in a model checkpoint for storage/inference.
- quantized_matmul - Quantized matmul: dequantize -> f32 matmul -> re-quantize.
- remap_state_dict - Remap an entire state dict from timm names to yscv names.
- rnn_forward_sequence - Runs an RNN cell over a sequence [batch, seq_len, input_size].
- save_training_checkpoint - Save a full training checkpoint: model weights + optimizer state.
- save_weights - Saves a named set of tensors to a binary file (safetensors-like format).
- scale_gradients - Scales gradients of the given nodes by a scalar factor.
- scaled_dot_product_attention - Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.
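The softmax(Q @ K^T / sqrt(d_k)) @ V computation above, for a single head with naive loops (illustrative shapes: Q, K, V are [seq, d]; not the crate's signature).

```rust
// Numerically stable softmax over one score row, in place.
fn softmax(row: &mut [f32]) {
    let m = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0;
    for x in row.iter_mut() {
        *x = (*x - m).exp();
        sum += *x;
    }
    for x in row.iter_mut() {
        *x /= sum;
    }
}

fn attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let d_k = q[0].len() as f32;
    q.iter()
        .map(|qi| {
            // scores_j = (qi . k_j) / sqrt(d_k), then softmax over j
            let mut scores: Vec<f32> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() / d_k.sqrt())
                .collect();
            softmax(&mut scores);
            // output row = attention-weighted sum of value rows
            let dim = v[0].len();
            (0..dim)
                .map(|c| scores.iter().zip(v).map(|(s, vr)| s * vr[c]).sum::<f32>())
                .collect()
        })
        .collect()
}

fn main() {
    // identical scores -> uniform weights -> output is the mean of the values
    let q = vec![vec![1.0, 0.0]];
    let k = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
    let v = vec![vec![2.0], vec![4.0]];
    let out = attention(&q, &k, &v);
    assert!((out[0][0] - 3.0).abs() < 1e-6);
}
```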
- sgd_state_from_map - Restore SGD velocity buffers from a string-keyed map.
- sgd_state_to_map - Flatten SGD velocity buffers into a string-keyed map for serialization.
- shard_tensor - Shard a tensor along its first dimension into num_shards roughly equal parts.
- smooth_l1_loss - Smooth L1 loss (detection-style parameterization of Huber loss).
- split_into_stages - Split a model with num_layers layers into num_stages roughly equal stages.
- timm_to_yscv_name - Translate a timm/PyTorch weight name to the corresponding yscv name.
- train_cnn_epoch_adam - One-call CNN training epoch with the Adam optimizer.
- train_cnn_epoch_adamw - One-call CNN training epoch with the AdamW optimizer.
- train_cnn_epoch_sgd - One-call CNN training epoch: register params, forward, loss, backward, update, sync.
- train_cnn_epochs - Runs multiple CNN training epochs, returning per-epoch metrics.
- train_epoch_adam - Deterministic one-epoch Adam train loop over sequential mini-batches.
- train_epoch_adam_with_loss - Deterministic one-epoch Adam train loop with configurable supervised loss.
- train_epoch_adam_with_options - Deterministic one-epoch Adam train loop with configurable batch iterator options.
- train_epoch_adam_with_options_and_loss - Deterministic one-epoch Adam train loop with configurable batch iterator options and loss.
- train_epoch_adamw - Deterministic one-epoch AdamW train loop over sequential mini-batches.
- train_epoch_adamw_with_loss - Deterministic one-epoch AdamW train loop with configurable supervised loss.
- train_epoch_adamw_with_options - Deterministic one-epoch AdamW train loop with configurable batch iterator options.
- train_epoch_adamw_with_options_and_loss - Deterministic one-epoch AdamW train loop with configurable batch iterator options and loss.
- train_epoch_distributed - Train one epoch with distributed gradient synchronization.
- train_epoch_distributed_sgd - Convenience wrapper: train one distributed epoch over a SequentialModel and SupervisedDataset with SGD.
- train_epoch_rmsprop - Deterministic one-epoch RMSProp train loop over sequential mini-batches.
- train_epoch_rmsprop_with_loss - Deterministic one-epoch RMSProp train loop with configurable supervised loss.
- train_epoch_rmsprop_with_options - Deterministic one-epoch RMSProp train loop with configurable batch iterator options.
- train_epoch_rmsprop_with_options_and_loss - Deterministic one-epoch RMSProp train loop with configurable batch iterator options and loss.
- train_epoch_sgd - Deterministic one-epoch train loop over sequential mini-batches.
- train_epoch_sgd_with_loss - Deterministic one-epoch train loop with configurable supervised loss.
- train_epoch_sgd_with_options - Deterministic one-epoch train loop with configurable batch iterator options.
- train_epoch_sgd_with_options_and_loss - Deterministic one-epoch train loop with configurable batch iterator options and loss.
- train_epochs_adam_with_scheduler - Runs multiple Adam epochs and advances the scheduler after each epoch.
- train_epochs_adam_with_scheduler_and_loss - Runs multiple Adam epochs with configurable supervised loss and advances the scheduler after each epoch.
- train_epochs_adamw_with_scheduler - Runs multiple AdamW epochs and advances the scheduler after each epoch.
- train_epochs_adamw_with_scheduler_and_loss - Runs multiple AdamW epochs with configurable supervised loss and advances the scheduler after each epoch.
- train_epochs_rmsprop_with_scheduler - Runs multiple RMSProp epochs and advances the scheduler after each epoch.
- train_epochs_rmsprop_with_scheduler_and_loss - Runs multiple RMSProp epochs with configurable supervised loss and advances the scheduler after each epoch.
- train_epochs_sgd_with_scheduler - Runs multiple SGD epochs and advances the scheduler after each epoch.
- train_epochs_sgd_with_scheduler_and_loss - Runs multiple SGD epochs with configurable supervised loss and advances the scheduler after each epoch.
- train_epochs_with_callbacks - Train for multiple epochs with callbacks.
- train_step_adam - Runs one full train step: loss forward, backward, and Adam updates.
- train_step_adam_with_accumulation - Runs one training step with gradient accumulation across multiple micro-batches using the Adam optimizer.
- train_step_adam_with_loss - Runs one full train step: configured loss forward, backward, and Adam updates.
- train_step_adamw - Runs one full train step: loss forward, backward, and AdamW updates.
- train_step_adamw_with_accumulation - Runs one training step with gradient accumulation across multiple micro-batches using the AdamW optimizer.
- train_step_adamw_with_loss - Runs one full train step: configured loss forward, backward, and AdamW updates.
- train_step_rmsprop - Runs one full train step: loss forward, backward, and RMSProp updates.
- train_step_rmsprop_with_accumulation - Runs one training step with gradient accumulation across multiple micro-batches using the RMSProp optimizer.
- train_step_rmsprop_with_loss - Runs one full train step: configured loss forward, backward, and RMSProp updates.
- train_step_sgd - Runs one full train step: loss forward, backward, and SGD updates.
- train_step_sgd_with_accumulation - Runs one training step with gradient accumulation across multiple micro-batches.
- train_step_sgd_with_loss - Runs one full train step: configured loss forward, backward, and SGD updates.
- triplet_loss - Triplet loss for metric learning.
- xavier_normal - Xavier (Glorot) normal initialization.
- xavier_uniform - Xavier (Glorot) uniform initialization.
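The fan-based scale factors behind these initializers, as commonly defined (Glorot and Bengio 2010; He et al. 2015). Shown as plain formulas; the crate's exact function signatures are not assumed.

```rust
// Xavier/Glorot uniform: weights are sampled from U(-bound, bound)
// with bound = sqrt(6 / (fan_in + fan_out)).
fn xavier_uniform_bound(fan_in: usize, fan_out: usize) -> f32 {
    (6.0 / (fan_in + fan_out) as f32).sqrt()
}

// Kaiming/He normal (ReLU gain): weights are sampled from N(0, std^2)
// with std = sqrt(2 / fan_in).
fn kaiming_normal_std(fan_in: usize) -> f32 {
    (2.0 / fan_in as f32).sqrt()
}

fn main() {
    assert!((xavier_uniform_bound(300, 100) - (6.0f32 / 400.0).sqrt()).abs() < 1e-7);
    assert!((kaiming_normal_std(50) - 0.2).abs() < 1e-6); // sqrt(2/50) = 0.2
}
```

Both rules aim to keep activation variance roughly constant across layers; the Kaiming variant folds in the factor of 2 lost to ReLU zeroing half the inputs.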