// torsh_models — crate root (lib.rs)
//! Pre-trained models and model zoo for ToRSh deep learning framework
//!
//! This crate provides a comprehensive collection of pre-trained models and utilities
//! for loading, using, and managing deep learning models in ToRSh.

// Framework infrastructure — many components are defined ahead of being referenced
// (planned for future releases), so the dead-code lint is suppressed crate-wide.
#![allow(dead_code)]
9pub mod architectures;
10pub mod audio;
11pub mod benchmark;
12pub mod builder;
13pub mod common;
14pub mod comparison;
15pub mod config;
16pub mod distillation;
17pub mod domain;
18pub mod downloader;
19pub mod ensembling;
20pub mod few_shot;
21pub mod fine_tuning;
22pub mod generative;
23pub mod gnn;
24pub mod lazy_loading;
25pub mod model_merging;
26pub mod model_sharding;
27pub mod multimodal;
28pub mod nlp;
29pub mod optimization;
30// pub mod prelude; // Defined inline below
31pub mod pruning;
32pub mod quantization;
33pub mod registry;
34pub mod rl;
35pub mod surgery;
36pub mod utils;
37pub mod validation;
38pub mod video;
39pub mod vision;
40pub mod vision_3d;
41
42#[cfg(feature = "diffusion_extended")]
43pub mod diffusion;
44
45// Re-exports
46pub use downloader::{DownloadProgress, ModelDownloader};
47pub use lazy_loading::{CacheStats, LazyModelLoader, LazyTensor, StreamingModelLoader};
48pub use model_merging::{LoRAMerger, MergeStrategy, ModelMerger, ModelSoup};
49pub use model_sharding::{DevicePlacement, ModelSharder, ShardingStats, ShardingStrategy};
50pub use registry::{ModelHandle, ModelInfo, ModelRegistry};
51pub use utils::{
52    convert_model_format, convert_pytorch_state_dict, convert_to_pytorch_state_dict,
53    load_model_from_file, load_model_weights, load_pytorch_checkpoint, load_safetensors_weights,
54    load_state_dict, map_parameter_names, save_model_to_file, save_pytorch_checkpoint,
55    save_tensors_to_safetensors, ModelFormat, ModelMetadata,
56};
57
58/// Common error types
59use thiserror::Error;
60
61#[derive(Error, Debug)]
62pub enum ModelError {
63    #[error("Model not found: {name}")]
64    ModelNotFound { name: String },
65
66    #[error("Download failed: {reason}")]
67    DownloadFailed { reason: String },
68
69    #[error("Invalid model format: {format}")]
70    InvalidFormat { format: String },
71
72    #[error("Serialization error: {0}")]
73    Serialization(#[from] safetensors::SafeTensorError),
74
75    #[error("IO error: {0}")]
76    Io(#[from] std::io::Error),
77
78    #[error("Network error: {0}")]
79    #[cfg(feature = "download")]
80    Network(#[from] reqwest::Error),
81
82    #[error("JSON error: {0}")]
83    Json(#[from] serde_json::Error),
84
85    #[error("Model loading error: {reason}")]
86    LoadingError { reason: String },
87
88    #[error("Model validation error: {reason}")]
89    ValidationError { reason: String },
90
91    #[error("ToRSh error: {0}")]
92    TorshError(#[from] torsh_core::error::TorshError),
93}
94
95/// Result type for model operations
96pub type ModelResult<T> = Result<T, ModelError>;
97
/// Macro to define model types and generate implementations.
///
/// Takes a list of `Variant(path::to::Type)` entries, each optionally gated
/// behind a `#[cfg(feature = "...")]`, and expands to:
///   1. a `ModelType` enum with one variant per entry (a concrete enum is used
///      instead of `Box<dyn Module>` to avoid trait-object limitations), and
///   2. an `impl torsh_nn::Module for ModelType` that forwards every trait
///      method to the wrapped model via an exhaustive `match`.
///
/// Feature gates are repeated on both the variant and its match arm so that
/// disabling a feature removes the variant and its dispatch code together.
macro_rules! define_model_type {
    (
        $(
            $(#[cfg(feature = $feature:literal)])?
            $variant:ident($type:ty),
        )*
    ) => {
        /// Concrete model enum to avoid trait object issues
        pub enum ModelType {
            $(
                $(#[cfg(feature = $feature)])?
                $variant($type),
            )*
        }

        impl torsh_nn::Module for ModelType {
            // Forward pass delegated to the wrapped model.
            fn forward(&self, input: &torsh_tensor::Tensor) -> torsh_core::error::Result<torsh_tensor::Tensor> {
                match self {
                    $(
                        $(#[cfg(feature = $feature)])?
                        ModelType::$variant(model) => model.forward(input),
                    )*
                }
            }

            // All learnable parameters of the wrapped model.
            fn parameters(&self) -> std::collections::HashMap<String, torsh_nn::Parameter> {
                match self {
                    $(
                        $(#[cfg(feature = $feature)])?
                        ModelType::$variant(model) => model.parameters(),
                    )*
                }
            }

            // Parameters keyed by their hierarchical names.
            fn named_parameters(&self) -> std::collections::HashMap<String, torsh_nn::Parameter> {
                match self {
                    $(
                        $(#[cfg(feature = $feature)])?
                        ModelType::$variant(model) => model.named_parameters(),
                    )*
                }
            }

            // Whether the wrapped model is currently in training mode.
            fn training(&self) -> bool {
                match self {
                    $(
                        $(#[cfg(feature = $feature)])?
                        ModelType::$variant(model) => model.training(),
                    )*
                }
            }

            // Switch the wrapped model to training mode.
            fn train(&mut self) {
                match self {
                    $(
                        $(#[cfg(feature = $feature)])?
                        ModelType::$variant(model) => model.train(),
                    )*
                }
            }

            // Switch the wrapped model to evaluation (inference) mode.
            fn eval(&mut self) {
                match self {
                    $(
                        $(#[cfg(feature = $feature)])?
                        ModelType::$variant(model) => model.eval(),
                    )*
                }
            }

            // Move the wrapped model's parameters to the given device.
            fn to_device(&mut self, device: torsh_core::DeviceType) -> torsh_core::error::Result<()> {
                match self {
                    $(
                        $(#[cfg(feature = $feature)])?
                        ModelType::$variant(model) => model.to_device(device),
                    )*
                }
            }
        }
    };
}
180
181// Define all model types using the macro
182define_model_type! {
183    #[cfg(feature = "vision")]
184    ResNet(crate::vision::ResNet),
185    #[cfg(feature = "vision")]
186    VisionTransformer(crate::vision::VisionTransformer),
187    // NOTE: Additional vision models exist but require API updates for torsh-nn compatibility:
188    // - EfficientNet, SwinTransformer, ConvNeXt (implemented, needs torsh-nn v0.2 API)
189    // - DETR, MaskRCNN, YOLO (implemented, needs torsh-nn v0.2 API)
190    // - MobileNetV2, DenseNet (implemented, needs torsh-nn v0.2 API)
191    // These will be enabled in a future release once API compatibility is resolved
192    // NOTE: Additional model types planned for v0.2.0:
193    // - NLP: RoBERTa, BART, T5, GPT-2, XLNet, ELECTRA, DeBERTa, Longformer, BigBird
194    // - Audio: Wav2Vec2, Whisper, HuBERT, WavLM (base implementations exist)
195    // - Multimodal: CLIP, ALIGN (base implementations exist), Flamingo, DALL-E, BLIP, LLaVA, InstructBLIP
196    // - GNN: GCN, GraphSAGE, GAT, GIN
197    // These require module completion and/or API compatibility updates
198    // #[cfg(feature = "vision_3d")]
199    // CNN3D(crate::vision_3d::CNN3D),
200    // #[cfg(feature = "vision_3d")]
201    // PointNet(crate::vision_3d::PointNet),
202    // #[cfg(feature = "vision_3d")]
203    // PointNetPlusPlus(crate::vision_3d::PointNetPlusPlus),
204    // #[cfg(feature = "video")]
205    // ResNet3D(crate::video::ResNet3D),
206    // #[cfg(feature = "video")]
207    // SlowFast(crate::video::SlowFast),
208    // #[cfg(feature = "video")]
209    // VideoTransformer(crate::video::VideoTransformer),
210    // #[cfg(feature = "generative")]
211    // VAE(crate::generative::VAE),
212    // #[cfg(feature = "generative")]
213    // GAN(crate::generative::GAN),
214    // #[cfg(feature = "generative")]
215    // DiffusionModel(crate::generative::DiffusionUNet),
216    // #[cfg(feature = "rl")]
217    // DQN(crate::rl::DQN),
218    // #[cfg(feature = "rl")]
219    // PPO(crate::rl::PPO),
220    // #[cfg(feature = "rl")]
221    // A3C(crate::rl::A3C),
222    // #[cfg(feature = "domain")]
223    // UNet(crate::domain::UNet),
224    // #[cfg(feature = "domain")]
225    // UNet3D(crate::domain::UNet3D),
226    // #[cfg(feature = "domain")]
227    // PINN(crate::domain::PINN),
228    // #[cfg(feature = "domain")]
229    // FNO(crate::domain::FNO),
230}
231
232/// Prelude module for convenient imports
233pub mod prelude {
234    pub use crate::{
235        convert_model_format, convert_pytorch_state_dict, convert_to_pytorch_state_dict,
236        load_model_from_file, load_model_weights, load_pytorch_checkpoint,
237        load_safetensors_weights, load_state_dict, map_parameter_names, save_model_to_file,
238        save_pytorch_checkpoint, save_tensors_to_safetensors, DownloadProgress, ModelDownloader,
239        ModelError, ModelFormat, ModelHandle, ModelInfo, ModelMetadata, ModelRegistry, ModelResult,
240        ModelType,
241    };
242
243    pub use crate::comparison::*;
244    pub use crate::distillation::*;
245    pub use crate::ensembling::*;
246    pub use crate::few_shot::*;
247    pub use crate::fine_tuning::*;
248    pub use crate::pruning::*;
249    pub use crate::quantization::*;
250    pub use crate::surgery::*;
251    pub use crate::validation::*;
252
253    #[cfg(feature = "vision")]
254    pub use crate::vision::*;
255
256    #[cfg(feature = "nlp")]
257    pub use crate::nlp::*;
258
259    #[cfg(feature = "audio")]
260    pub use crate::audio::*;
261
262    #[cfg(feature = "multimodal")]
263    pub use crate::multimodal::*;
264
265    #[cfg(feature = "gnn")]
266    pub use crate::gnn::*;
267
268    #[cfg(feature = "vision_3d")]
269    pub use crate::vision_3d::*;
270
271    #[cfg(feature = "video")]
272    pub use crate::video::*;
273
274    #[cfg(feature = "generative")]
275    pub use crate::generative::*;
276
277    #[cfg(feature = "rl")]
278    pub use crate::rl::*;
279
280    #[cfg(feature = "domain")]
281    pub use crate::domain::*;
282}