// yscv_model/core.rs — crate root module for the yscv model library.
1//! Model definitions, losses, checkpoints, and training helpers for yscv.
2#![forbid(unsafe_code)]
3
4pub const CRATE_ID: &str = "yscv-model";
5
6#[path = "attention.rs"]
7mod attention;
8#[path = "augmentation.rs"]
9mod augmentation;
10#[path = "batch_infer.rs"]
11mod batch_infer;
12#[path = "blocks.rs"]
13mod blocks;
14#[path = "callbacks.rs"]
15mod callbacks;
16#[path = "checkpoint.rs"]
17mod checkpoint;
18#[path = "checkpoint_state.rs"]
19mod checkpoint_state;
20#[path = "data_loader.rs"]
21mod data_loader;
22#[path = "dataset/mod.rs"]
23mod dataset;
24#[path = "distributed.rs"]
25mod distributed;
26#[path = "ema.rs"]
27mod ema;
28#[path = "error.rs"]
29mod error;
30#[path = "fusion.rs"]
31mod fusion;
32#[path = "hub.rs"]
33mod hub;
34#[path = "init.rs"]
35mod init;
36#[path = "layers/mod.rs"]
37mod layers;
38#[path = "lora.rs"]
39mod lora;
40#[path = "loss.rs"]
41mod loss;
42#[path = "lr_finder.rs"]
43mod lr_finder;
44#[path = "mixed_precision.rs"]
45mod mixed_precision;
46#[path = "onnx_export.rs"]
47mod onnx_export;
48#[path = "pipeline.rs"]
49mod pipeline;
50#[path = "quantize.rs"]
51mod quantize;
52#[path = "recurrent.rs"]
53mod recurrent;
54#[path = "safetensors.rs"]
55mod safetensors;
56#[path = "sequential.rs"]
57mod sequential;
58#[path = "tcp_transport.rs"]
59pub mod tcp_transport;
60#[path = "tensorboard.rs"]
61mod tensorboard;
62#[path = "train.rs"]
63mod train;
64#[path = "trainer.rs"]
65mod trainer;
66#[path = "training_log.rs"]
67mod training_log;
68#[path = "transform.rs"]
69mod transform;
70#[path = "transformer_decoder.rs"]
71mod transformer_decoder;
72#[path = "weight_mapping.rs"]
73mod weight_mapping;
74#[path = "weights.rs"]
75mod weights;
76#[path = "zoo.rs"]
77mod zoo;
78
79pub use attention::{
80    FeedForward, MultiHeadAttention, MultiHeadAttentionConfig, TransformerEncoderBlock,
81    generate_causal_mask, generate_padding_mask, scaled_dot_product_attention,
82};
83pub use augmentation::{ImageAugmentationOp, ImageAugmentationPipeline};
84pub use batch_infer::{BatchCollector, DynamicBatchConfig, batched_inference};
85pub use blocks::{
86    AnchorFreeHead, FpnNeck, MbConvBlock, PatchEmbedding, SqueezeExciteBlock, UNetDecoderStage,
87    UNetEncoderStage, VisionTransformer, add_bottleneck_block, add_residual_block,
88    build_resnet_feature_extractor, build_simple_cnn_classifier,
89};
90pub use callbacks::{
91    BestModelCheckpoint, EarlyStopping, MetricsLogger, MonitorMode, TrainingCallback,
92    train_epochs_with_callbacks,
93};
94pub use checkpoint::{
95    LayerCheckpoint, SequentialCheckpoint, TensorSnapshot, checkpoint_from_json, checkpoint_to_json,
96};
97pub use checkpoint_state::{
98    adam_state_from_map, adam_state_to_map, load_training_checkpoint, save_training_checkpoint,
99    sgd_state_from_map, sgd_state_to_map,
100};
101pub use data_loader::{
102    DataLoader, DataLoaderBatch, DataLoaderConfig, DataLoaderIter, RandomSampler,
103    SequentialSampler, StreamingDataLoader, WeightedRandomSampler,
104};
105pub use dataset::{
106    Batch, BatchIterOptions, CutMixConfig, DatasetSplit, ImageFolderTargetMode, MiniBatchIter,
107    MixUpConfig, SamplingPolicy, SupervisedCsvConfig, SupervisedDataset,
108    SupervisedImageFolderConfig, SupervisedImageFolderLoadResult, SupervisedImageManifestConfig,
109    SupervisedJsonlConfig, load_supervised_dataset_csv_file, load_supervised_dataset_jsonl_file,
110    load_supervised_image_folder_dataset, load_supervised_image_folder_dataset_with_classes,
111    load_supervised_image_manifest_csv_file, parse_supervised_dataset_csv,
112    parse_supervised_dataset_jsonl, parse_supervised_image_manifest_csv,
113};
114pub use distributed::{
115    AllReduceAggregator, CompressedGradient, DataParallelConfig, DistributedConfig,
116    GradientAggregator, InProcessTransport, LocalAggregator, ParameterServer,
117    PipelineParallelConfig, PipelineStage, TopKCompressor, Transport, compress_gradients,
118    decompress_gradients, distributed_train_step, gather_shards, shard_tensor, split_into_stages,
119};
120pub use ema::ExponentialMovingAverage;
121pub use error::ModelError;
122pub use fusion::{fuse_conv_bn, optimize_sequential};
123pub use hub::{HubEntry, ModelHub, default_cache_dir};
124pub use init::{
125    constant, kaiming_normal, kaiming_uniform, orthogonal, xavier_normal, xavier_uniform,
126};
127pub use layers::{
128    AdaptiveAvgPool2dLayer, AdaptiveMaxPool2dLayer, AvgPool2dLayer, BatchNorm2dLayer, Conv1dLayer,
129    Conv2dLayer, Conv3dLayer, ConvTranspose2dLayer, DeformableConv2dLayer, DepthwiseConv2dLayer,
130    DropoutLayer, EmbeddingLayer, FeedForwardLayer, FlattenLayer, GELULayer, GlobalAvgPool2dLayer,
131    GroupNormLayer, GruLayer, InstanceNormLayer, LayerNormLayer, LeakyReLULayer, LinearLayer,
132    LstmLayer, MaskHead, MaxPool2dLayer, MishLayer, ModelLayer, MultiHeadAttentionLayer,
133    PReLULayer, PixelShuffleLayer, ReLULayer, ResidualBlock, RnnLayer, SeparableConv2dLayer,
134    SiLULayer, SigmoidLayer, SoftmaxLayer, TanhLayer, TransformerEncoderLayer, UpsampleLayer,
135};
136pub use lora::{LoraConfig, LoraLinear};
137pub use loss::{
138    bce_loss, contrastive_loss, cosine_embedding_loss, cross_entropy_loss, ctc_loss, dice_loss,
139    distillation_loss, focal_loss, hinge_loss, huber_loss, kl_div_loss,
140    label_smoothing_cross_entropy, mae_loss, mse_loss, nll_loss, smooth_l1_loss, triplet_loss,
141};
142pub use lr_finder::{LrFinderConfig, LrFinderResult, lr_range_test};
143pub use mixed_precision::{
144    DynamicLossScaler, MixedPrecisionConfig, cast_params_for_forward, cast_to_master,
145    mixed_precision_train_step,
146};
147pub use onnx_export::{export_sequential_to_onnx, export_sequential_to_onnx_file};
148pub use pipeline::InferencePipeline;
149pub use quantize::{
150    PerChannelQuantResult, PrunedTensor, QuantMode, QuantizedTensor, apply_pruning_mask,
151    dequantize_weights, prune_magnitude, quantize_per_channel, quantize_weights, quantized_matmul,
152};
153pub use recurrent::{
154    GruCell, LstmCell, RnnCell, bilstm_forward_sequence, gru_forward_sequence,
155    lstm_forward_sequence, rnn_forward_sequence,
156};
157pub use safetensors::{SafeTensorDType, SafeTensorFile, TensorInfo, load_state_dict};
158pub use sequential::SequentialModel;
159pub use tcp_transport::{NodeRole, TcpAllReduceAggregator, TcpTransport, loopback_pair};
160pub use tensorboard::{TensorBoardCallback, TensorBoardWriter};
161pub use train::{
162    CnnTrainConfig, EpochMetrics, EpochTrainOptions, OptimizerType, ScheduledEpochMetrics,
163    SchedulerTrainOptions, SupervisedLoss, accumulate_gradients, collect_gradients, infer_batch,
164    infer_batch_graph, scale_gradients, train_cnn_epoch_adam, train_cnn_epoch_adamw,
165    train_cnn_epoch_sgd, train_cnn_epochs, train_epoch_adam, train_epoch_adam_with_loss,
166    train_epoch_adam_with_options, train_epoch_adam_with_options_and_loss, train_epoch_adamw,
167    train_epoch_adamw_with_loss, train_epoch_adamw_with_options,
168    train_epoch_adamw_with_options_and_loss, train_epoch_distributed, train_epoch_distributed_sgd,
169    train_epoch_rmsprop, train_epoch_rmsprop_with_loss, train_epoch_rmsprop_with_options,
170    train_epoch_rmsprop_with_options_and_loss, train_epoch_sgd, train_epoch_sgd_with_loss,
171    train_epoch_sgd_with_options, train_epoch_sgd_with_options_and_loss,
172    train_epochs_adam_with_scheduler, train_epochs_adam_with_scheduler_and_loss,
173    train_epochs_adamw_with_scheduler, train_epochs_adamw_with_scheduler_and_loss,
174    train_epochs_rmsprop_with_scheduler, train_epochs_rmsprop_with_scheduler_and_loss,
175    train_epochs_sgd_with_scheduler, train_epochs_sgd_with_scheduler_and_loss, train_step_adam,
176    train_step_adam_with_accumulation, train_step_adam_with_loss, train_step_adamw,
177    train_step_adamw_with_accumulation, train_step_adamw_with_loss, train_step_rmsprop,
178    train_step_rmsprop_with_accumulation, train_step_rmsprop_with_loss, train_step_sgd,
179    train_step_sgd_with_accumulation, train_step_sgd_with_loss,
180};
181pub use trainer::{LossKind, OptimizerKind, TrainResult, Trainer, TrainerConfig};
182pub use training_log::TrainingLog;
183pub use transform::{
184    CenterCrop, Compose, GaussianBlur, Normalize, PermuteDims, RandomHorizontalFlip, Resize,
185    ScaleValues, Transform,
186};
187pub use transformer_decoder::{CrossAttention, TransformerDecoder, TransformerDecoderBlock};
188pub use weight_mapping::{remap_state_dict, timm_to_yscv_name};
189pub use weights::{inspect_weights, load_weights, save_weights};
190pub use zoo::{
191    ArchitectureConfig, ModelArchitecture, ModelZoo, build_alexnet, build_classifier,
192    build_feature_extractor, build_mobilenet_v2, build_resnet, build_resnet_custom, build_vgg,
193};
194
195#[path = "tests/mod.rs"]
196#[cfg(test)]
197mod tests;