1#![forbid(unsafe_code)]
3
/// Identifier string for this crate: `"yscv-model"`.
///
/// This is a public compile-time constant, so dependents can refer to the crate's
/// identifier without repeating the string literal.
pub const CRATE_ID: &str = "yscv-model";
5
6#[path = "attention.rs"]
7mod attention;
8#[path = "augmentation.rs"]
9mod augmentation;
10#[path = "batch_infer.rs"]
11mod batch_infer;
12#[path = "blocks.rs"]
13mod blocks;
14#[path = "callbacks.rs"]
15mod callbacks;
16#[path = "checkpoint.rs"]
17mod checkpoint;
18#[path = "checkpoint_state.rs"]
19mod checkpoint_state;
20#[path = "data_loader.rs"]
21mod data_loader;
22#[path = "dataset/mod.rs"]
23mod dataset;
24#[path = "distributed.rs"]
25mod distributed;
26#[path = "ema.rs"]
27mod ema;
28#[path = "error.rs"]
29mod error;
30#[path = "fusion.rs"]
31mod fusion;
32#[path = "hub.rs"]
33mod hub;
34#[path = "init.rs"]
35mod init;
36#[path = "layers/mod.rs"]
37mod layers;
38#[path = "lora.rs"]
39mod lora;
40#[path = "loss.rs"]
41mod loss;
42#[path = "lr_finder.rs"]
43mod lr_finder;
44#[path = "mixed_precision.rs"]
45mod mixed_precision;
46#[path = "onnx_export.rs"]
47mod onnx_export;
48#[path = "pipeline.rs"]
49mod pipeline;
50#[path = "quantize.rs"]
51mod quantize;
52#[path = "recurrent.rs"]
53mod recurrent;
54#[path = "safetensors.rs"]
55mod safetensors;
56#[path = "sequential.rs"]
57mod sequential;
58#[path = "tcp_transport.rs"]
59pub mod tcp_transport;
60#[path = "tensorboard.rs"]
61mod tensorboard;
62#[path = "train.rs"]
63mod train;
64#[path = "trainer.rs"]
65mod trainer;
66#[path = "training_log.rs"]
67mod training_log;
68#[path = "transform.rs"]
69mod transform;
70#[path = "transformer_decoder.rs"]
71mod transformer_decoder;
72#[path = "weight_mapping.rs"]
73mod weight_mapping;
74#[path = "weights.rs"]
75mod weights;
76#[path = "zoo.rs"]
77mod zoo;
78
79pub use attention::{
80 FeedForward, MultiHeadAttention, MultiHeadAttentionConfig, TransformerEncoderBlock,
81 generate_causal_mask, generate_padding_mask, scaled_dot_product_attention,
82};
83pub use augmentation::{ImageAugmentationOp, ImageAugmentationPipeline};
84pub use batch_infer::{BatchCollector, DynamicBatchConfig, batched_inference};
85pub use blocks::{
86 AnchorFreeHead, FpnNeck, MbConvBlock, PatchEmbedding, SqueezeExciteBlock, UNetDecoderStage,
87 UNetEncoderStage, VisionTransformer, add_bottleneck_block, add_residual_block,
88 build_resnet_feature_extractor, build_simple_cnn_classifier,
89};
90pub use callbacks::{
91 BestModelCheckpoint, EarlyStopping, MetricsLogger, MonitorMode, TrainingCallback,
92 train_epochs_with_callbacks,
93};
94pub use checkpoint::{
95 LayerCheckpoint, SequentialCheckpoint, TensorSnapshot, checkpoint_from_json, checkpoint_to_json,
96};
97pub use checkpoint_state::{
98 adam_state_from_map, adam_state_to_map, load_training_checkpoint, save_training_checkpoint,
99 sgd_state_from_map, sgd_state_to_map,
100};
101pub use data_loader::{
102 DataLoader, DataLoaderBatch, DataLoaderConfig, DataLoaderIter, RandomSampler,
103 SequentialSampler, StreamingDataLoader, WeightedRandomSampler,
104};
105pub use dataset::{
106 Batch, BatchIterOptions, CutMixConfig, DatasetSplit, ImageFolderTargetMode, MiniBatchIter,
107 MixUpConfig, SamplingPolicy, SupervisedCsvConfig, SupervisedDataset,
108 SupervisedImageFolderConfig, SupervisedImageFolderLoadResult, SupervisedImageManifestConfig,
109 SupervisedJsonlConfig, load_supervised_dataset_csv_file, load_supervised_dataset_jsonl_file,
110 load_supervised_image_folder_dataset, load_supervised_image_folder_dataset_with_classes,
111 load_supervised_image_manifest_csv_file, parse_supervised_dataset_csv,
112 parse_supervised_dataset_jsonl, parse_supervised_image_manifest_csv,
113};
114pub use distributed::{
115 AllReduceAggregator, CompressedGradient, DataParallelConfig, DistributedConfig,
116 GradientAggregator, InProcessTransport, LocalAggregator, ParameterServer,
117 PipelineParallelConfig, PipelineStage, TopKCompressor, Transport, compress_gradients,
118 decompress_gradients, distributed_train_step, gather_shards, shard_tensor, split_into_stages,
119};
120pub use ema::ExponentialMovingAverage;
121pub use error::ModelError;
122pub use fusion::{fuse_conv_bn, optimize_sequential};
123pub use hub::{HubEntry, ModelHub, default_cache_dir};
124pub use init::{
125 constant, kaiming_normal, kaiming_uniform, orthogonal, xavier_normal, xavier_uniform,
126};
127pub use layers::{
128 AdaptiveAvgPool2dLayer, AdaptiveMaxPool2dLayer, AvgPool2dLayer, BatchNorm2dLayer, Conv1dLayer,
129 Conv2dLayer, Conv3dLayer, ConvTranspose2dLayer, DeformableConv2dLayer, DepthwiseConv2dLayer,
130 DropoutLayer, EmbeddingLayer, FeedForwardLayer, FlattenLayer, GELULayer, GlobalAvgPool2dLayer,
131 GroupNormLayer, GruLayer, InstanceNormLayer, LayerNormLayer, LeakyReLULayer, LinearLayer,
132 LstmLayer, MaskHead, MaxPool2dLayer, MishLayer, ModelLayer, MultiHeadAttentionLayer,
133 PReLULayer, PixelShuffleLayer, ReLULayer, ResidualBlock, RnnLayer, SeparableConv2dLayer,
134 SiLULayer, SigmoidLayer, SoftmaxLayer, TanhLayer, TransformerEncoderLayer, UpsampleLayer,
135};
136pub use lora::{LoraConfig, LoraLinear};
137pub use loss::{
138 bce_loss, contrastive_loss, cosine_embedding_loss, cross_entropy_loss, ctc_loss, dice_loss,
139 distillation_loss, focal_loss, hinge_loss, huber_loss, kl_div_loss,
140 label_smoothing_cross_entropy, mae_loss, mse_loss, nll_loss, smooth_l1_loss, triplet_loss,
141};
142pub use lr_finder::{LrFinderConfig, LrFinderResult, lr_range_test};
143pub use mixed_precision::{
144 DynamicLossScaler, MixedPrecisionConfig, cast_params_for_forward, cast_to_master,
145 mixed_precision_train_step,
146};
147pub use onnx_export::{export_sequential_to_onnx, export_sequential_to_onnx_file};
148pub use pipeline::InferencePipeline;
149pub use quantize::{
150 PerChannelQuantResult, PrunedTensor, QuantMode, QuantizedTensor, apply_pruning_mask,
151 dequantize_weights, prune_magnitude, quantize_per_channel, quantize_weights, quantized_matmul,
152};
153pub use recurrent::{
154 GruCell, LstmCell, RnnCell, bilstm_forward_sequence, gru_forward_sequence,
155 lstm_forward_sequence, rnn_forward_sequence,
156};
157pub use safetensors::{SafeTensorDType, SafeTensorFile, TensorInfo, load_state_dict};
158pub use sequential::SequentialModel;
159pub use tcp_transport::{NodeRole, TcpAllReduceAggregator, TcpTransport, loopback_pair};
160pub use tensorboard::{TensorBoardCallback, TensorBoardWriter};
161pub use train::{
162 CnnTrainConfig, EpochMetrics, EpochTrainOptions, OptimizerType, ScheduledEpochMetrics,
163 SchedulerTrainOptions, SupervisedLoss, accumulate_gradients, collect_gradients, infer_batch,
164 infer_batch_graph, scale_gradients, train_cnn_epoch_adam, train_cnn_epoch_adamw,
165 train_cnn_epoch_sgd, train_cnn_epochs, train_epoch_adam, train_epoch_adam_with_loss,
166 train_epoch_adam_with_options, train_epoch_adam_with_options_and_loss, train_epoch_adamw,
167 train_epoch_adamw_with_loss, train_epoch_adamw_with_options,
168 train_epoch_adamw_with_options_and_loss, train_epoch_distributed, train_epoch_distributed_sgd,
169 train_epoch_rmsprop, train_epoch_rmsprop_with_loss, train_epoch_rmsprop_with_options,
170 train_epoch_rmsprop_with_options_and_loss, train_epoch_sgd, train_epoch_sgd_with_loss,
171 train_epoch_sgd_with_options, train_epoch_sgd_with_options_and_loss,
172 train_epochs_adam_with_scheduler, train_epochs_adam_with_scheduler_and_loss,
173 train_epochs_adamw_with_scheduler, train_epochs_adamw_with_scheduler_and_loss,
174 train_epochs_rmsprop_with_scheduler, train_epochs_rmsprop_with_scheduler_and_loss,
175 train_epochs_sgd_with_scheduler, train_epochs_sgd_with_scheduler_and_loss, train_step_adam,
176 train_step_adam_with_accumulation, train_step_adam_with_loss, train_step_adamw,
177 train_step_adamw_with_accumulation, train_step_adamw_with_loss, train_step_rmsprop,
178 train_step_rmsprop_with_accumulation, train_step_rmsprop_with_loss, train_step_sgd,
179 train_step_sgd_with_accumulation, train_step_sgd_with_loss,
180};
181pub use trainer::{LossKind, OptimizerKind, TrainResult, Trainer, TrainerConfig};
182pub use training_log::TrainingLog;
183pub use transform::{
184 CenterCrop, Compose, GaussianBlur, Normalize, PermuteDims, RandomHorizontalFlip, Resize,
185 ScaleValues, Transform,
186};
187pub use transformer_decoder::{CrossAttention, TransformerDecoder, TransformerDecoderBlock};
188pub use weight_mapping::{remap_state_dict, timm_to_yscv_name};
189pub use weights::{inspect_weights, load_weights, save_weights};
190pub use zoo::{
191 ArchitectureConfig, ModelArchitecture, ModelZoo, build_alexnet, build_classifier,
192 build_feature_extractor, build_mobilenet_v2, build_resnet, build_resnet_custom, build_vgg,
193};
194
195#[path = "tests/mod.rs"]
196#[cfg(test)]
197mod tests;