1#![allow(dead_code)]
8
9pub mod architectures;
10pub mod audio;
11pub mod benchmark;
12pub mod builder;
13pub mod common;
14pub mod comparison;
15pub mod config;
16pub mod distillation;
17pub mod domain;
18pub mod downloader;
19pub mod ensembling;
20pub mod few_shot;
21pub mod fine_tuning;
22pub mod generative;
23pub mod gnn;
24pub mod lazy_loading;
25pub mod model_merging;
26pub mod model_sharding;
27pub mod multimodal;
28pub mod nlp;
29pub mod optimization;
30pub mod pruning;
32pub mod quantization;
33pub mod registry;
34pub mod rl;
35pub mod surgery;
36pub mod utils;
37pub mod validation;
38pub mod video;
39pub mod vision;
40pub mod vision_3d;
41
42#[cfg(feature = "diffusion_extended")]
43pub mod diffusion;
44
45pub use downloader::{DownloadProgress, ModelDownloader};
47pub use lazy_loading::{CacheStats, LazyModelLoader, LazyTensor, StreamingModelLoader};
48pub use model_merging::{LoRAMerger, MergeStrategy, ModelMerger, ModelSoup};
49pub use model_sharding::{DevicePlacement, ModelSharder, ShardingStats, ShardingStrategy};
50pub use registry::{ModelHandle, ModelInfo, ModelRegistry};
51pub use utils::{
52 convert_model_format, convert_pytorch_state_dict, convert_to_pytorch_state_dict,
53 load_model_from_file, load_model_weights, load_pytorch_checkpoint, load_safetensors_weights,
54 load_state_dict, map_parameter_names, save_model_to_file, save_pytorch_checkpoint,
55 save_tensors_to_safetensors, ModelFormat, ModelMetadata,
56};
57
58use thiserror::Error;
60
61#[derive(Error, Debug)]
62pub enum ModelError {
63 #[error("Model not found: {name}")]
64 ModelNotFound { name: String },
65
66 #[error("Download failed: {reason}")]
67 DownloadFailed { reason: String },
68
69 #[error("Invalid model format: {format}")]
70 InvalidFormat { format: String },
71
72 #[error("Serialization error: {0}")]
73 Serialization(#[from] safetensors::SafeTensorError),
74
75 #[error("IO error: {0}")]
76 Io(#[from] std::io::Error),
77
78 #[error("Network error: {0}")]
79 #[cfg(feature = "download")]
80 Network(#[from] reqwest::Error),
81
82 #[error("JSON error: {0}")]
83 Json(#[from] serde_json::Error),
84
85 #[error("Model loading error: {reason}")]
86 LoadingError { reason: String },
87
88 #[error("Model validation error: {reason}")]
89 ValidationError { reason: String },
90
91 #[error("ToRSh error: {0}")]
92 TorshError(#[from] torsh_core::error::TorshError),
93}
94
95pub type ModelResult<T> = Result<T, ModelError>;
97
98macro_rules! define_model_type {
100 (
101 $(
102 $(#[cfg(feature = $feature:literal)])?
103 $variant:ident($type:ty),
104 )*
105 ) => {
106 pub enum ModelType {
108 $(
109 $(#[cfg(feature = $feature)])?
110 $variant($type),
111 )*
112 }
113
114 impl torsh_nn::Module for ModelType {
115 fn forward(&self, input: &torsh_tensor::Tensor) -> torsh_core::error::Result<torsh_tensor::Tensor> {
116 match self {
117 $(
118 $(#[cfg(feature = $feature)])?
119 ModelType::$variant(model) => model.forward(input),
120 )*
121 }
122 }
123
124 fn parameters(&self) -> std::collections::HashMap<String, torsh_nn::Parameter> {
125 match self {
126 $(
127 $(#[cfg(feature = $feature)])?
128 ModelType::$variant(model) => model.parameters(),
129 )*
130 }
131 }
132
133 fn named_parameters(&self) -> std::collections::HashMap<String, torsh_nn::Parameter> {
134 match self {
135 $(
136 $(#[cfg(feature = $feature)])?
137 ModelType::$variant(model) => model.named_parameters(),
138 )*
139 }
140 }
141
142 fn training(&self) -> bool {
143 match self {
144 $(
145 $(#[cfg(feature = $feature)])?
146 ModelType::$variant(model) => model.training(),
147 )*
148 }
149 }
150
151 fn train(&mut self) {
152 match self {
153 $(
154 $(#[cfg(feature = $feature)])?
155 ModelType::$variant(model) => model.train(),
156 )*
157 }
158 }
159
160 fn eval(&mut self) {
161 match self {
162 $(
163 $(#[cfg(feature = $feature)])?
164 ModelType::$variant(model) => model.eval(),
165 )*
166 }
167 }
168
169 fn to_device(&mut self, device: torsh_core::DeviceType) -> torsh_core::error::Result<()> {
170 match self {
171 $(
172 $(#[cfg(feature = $feature)])?
173 ModelType::$variant(model) => model.to_device(device),
174 )*
175 }
176 }
177 }
178 };
179}
180
181define_model_type! {
183 #[cfg(feature = "vision")]
184 ResNet(crate::vision::ResNet),
185 #[cfg(feature = "vision")]
186 VisionTransformer(crate::vision::VisionTransformer),
187 }
231
232pub mod prelude {
234 pub use crate::{
235 convert_model_format, convert_pytorch_state_dict, convert_to_pytorch_state_dict,
236 load_model_from_file, load_model_weights, load_pytorch_checkpoint,
237 load_safetensors_weights, load_state_dict, map_parameter_names, save_model_to_file,
238 save_pytorch_checkpoint, save_tensors_to_safetensors, DownloadProgress, ModelDownloader,
239 ModelError, ModelFormat, ModelHandle, ModelInfo, ModelMetadata, ModelRegistry, ModelResult,
240 ModelType,
241 };
242
243 pub use crate::comparison::*;
244 pub use crate::distillation::*;
245 pub use crate::ensembling::*;
246 pub use crate::few_shot::*;
247 pub use crate::fine_tuning::*;
248 pub use crate::pruning::*;
249 pub use crate::quantization::*;
250 pub use crate::surgery::*;
251 pub use crate::validation::*;
252
253 #[cfg(feature = "vision")]
254 pub use crate::vision::*;
255
256 #[cfg(feature = "nlp")]
257 pub use crate::nlp::*;
258
259 #[cfg(feature = "audio")]
260 pub use crate::audio::*;
261
262 #[cfg(feature = "multimodal")]
263 pub use crate::multimodal::*;
264
265 #[cfg(feature = "gnn")]
266 pub use crate::gnn::*;
267
268 #[cfg(feature = "vision_3d")]
269 pub use crate::vision_3d::*;
270
271 #[cfg(feature = "video")]
272 pub use crate::video::*;
273
274 #[cfg(feature = "generative")]
275 pub use crate::generative::*;
276
277 #[cfg(feature = "rl")]
278 pub use crate::rl::*;
279
280 #[cfg(feature = "domain")]
281 pub use crate::domain::*;
282}