1#![cfg_attr(not(feature = "std"), no_std)]
22#![allow(ambiguous_glob_reexports)]
23
24#[cfg(not(feature = "std"))]
25extern crate alloc;
26
27pub mod compile_time;
29pub mod container;
30#[cfg(feature = "std")]
31pub mod conversion;
32pub mod cuda_kernels;
33#[cfg(feature = "std")]
34pub mod export;
35pub mod functional;
36pub mod gradcheck;
37pub mod hardware_opts;
38pub mod init;
39pub mod layers;
40pub mod lazy;
41pub mod mixed_precision;
42pub mod model_zoo;
43pub mod modules;
44pub mod numerical_stability;
45pub mod optimization;
46pub mod parameter_updates;
47pub mod pruning;
48pub mod quantization;
49pub mod research;
50pub mod scirs2_neural_integration;
51#[cfg(feature = "serialize")]
52pub mod serialization;
53pub mod sparse;
54pub mod summary;
55pub mod visualization;
56
57pub mod core;
63pub use core::Module;
64
65pub mod parameter;
67pub use parameter::{
68 LayerType, Parameter, ParameterCollection, ParameterDiagnostics, ParameterStats,
69};
70
71pub mod hooks;
73pub use hooks::{HookCallback, HookHandle, HookRegistry, HookType};
74
75pub mod base;
77pub use base::ModuleBase;
78
79pub mod composition;
81pub use composition::{
82 ComposedModule, ConditionalModule, ModuleBuilder, ModuleComposition, ParallelModule,
83 ResidualModule,
84};
85
86pub mod construction;
88pub use construction::{ModuleConfig, ModuleConstruct};
89
90pub mod diagnostics;
92pub use diagnostics::{ModuleDiagnostics, ModuleInfo};
93
94pub mod utils;
96pub use utils::{ModuleApply, ModuleParameterStats};
97
98use torsh_tensor::Tensor;
103
104#[cfg(not(feature = "std"))]
107use alloc::sync::Arc;
108
109#[cfg(not(feature = "std"))]
110use hashbrown::HashMap;
111
112pub const VERSION: &str = env!("CARGO_PKG_VERSION");
114pub const VERSION_MAJOR: u32 = 0;
115pub const VERSION_MINOR: u32 = 1;
116pub const VERSION_PATCH: u32 = 0;
117
/// Sparse matrix type exposed at the crate root.
///
/// The type currently has no fields (it is zero-sized); every instance is identical.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SparseMatrix;

impl SparseMatrix {
    /// Creates a new `SparseMatrix`. Equivalent to [`SparseMatrix::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
134
135pub mod prelude {
137 pub use crate::container::*;
138 #[cfg(feature = "std")]
139 pub use crate::conversion::{
140 pytorch_compat, tensorflow_compat, ConversionConfig, FrameworkSource, MigrationHelper,
141 ModelConverter,
142 };
143 pub use crate::cuda_kernels::{
144 global_kernel_registry, CudaKernelRegistry, CudaNeuralOps, CudaOptimizations,
145 CustomActivations,
146 };
147 #[cfg(feature = "std")]
148 pub use crate::export::{
149 DeploymentOptimizer, ExportConfig, ExportFormat, ModelExporter, OptimizationLevel,
150 TargetDevice,
151 };
152 pub use crate::gradcheck::{
153 fast_gradcheck, gradcheck, precise_gradcheck, GradCheckConfig, GradCheckResult, GradChecker,
154 };
155 pub use crate::init::{
156 self,
157 auto_init,
159 coordinate_mlp_init,
160 delta_orthogonal_init,
161 fixup_init,
163 gan_balanced_init,
164 lsuv_init,
165 metainit,
166 recommend_init_method,
167 rezero_alpha_init,
168 rezero_init,
169 zero_centered_variance_init,
170 ActivationHint,
171 ArchitectureHint,
172 FanMode,
174 InitMethod,
175 Initializer,
176 Nonlinearity,
177 };
178 pub use crate::layers::*;
179 pub use crate::lazy::{lazy_linear, lazy_linear_no_bias, LazyLinear, LazyModule, LazyWrapper};
180 pub use crate::mixed_precision::prelude::*;
181 #[allow(unused_imports)]
182 pub use crate::modules::*;
183 pub use crate::numerical_stability::utils::{
184 comprehensive_stability_analysis, quick_stability_check,
185 };
186 pub use crate::numerical_stability::{
187 StabilityConfig, StabilityIssue, StabilityResults, StabilityTester,
188 };
189 pub use crate::optimization::{
190 optimize_for_inference, optimize_module, MemoryProfiler, NetworkOptimizer,
191 OptimizationReport,
192 };
193 pub use crate::parameter_updates::{
195 LayerSpecificOptimizers, ParameterUpdater, UpdateConfig, UpdateStatistics,
196 };
197 pub use crate::pruning::{Pruner, PruningConfig, PruningMask, PruningScope, PruningStrategy};
198 pub use crate::quantization::prelude::*;
199 pub use crate::scirs2_neural_integration::{
200 LayerNorm, MemoryEfficientSequential, Mish, MultiHeadAttention, NeuralConfig,
201 SciRS2NeuralProcessor, Swish, TransformerEncoderLayer, GELU,
202 };
203 pub use crate::summary::profiling::{
204 AnalysisConfig, AnalysisReport, BatchProfiler, BatchProfilingConfig, BatchProfilingResult,
205 FLOPSAnalysis, FLOPSCounter, MemoryAnalysis, ModelAnalyzer,
206 };
207 pub use crate::summary::utils::*;
208 pub use crate::summary::{summarize, LayerInfo, ModelProfiler, ModelSummary, SummaryConfig};
209 pub use crate::visualization::utils::*;
210 pub use crate::visualization::{GraphEdge, GraphNode, NetworkGraph, VisualizationConfig};
211 pub use crate::{ComposedModule, ConditionalModule, ParallelModule, ResidualModule};
212 pub use crate::{
213 HookCallback, HookHandle, HookRegistry, HookType, LayerType, Module, ModuleBase,
214 ModuleConfig, ModuleConstruct, Parameter, ParameterCollection, ParameterDiagnostics,
215 ParameterStats,
216 };
217 pub use crate::{ModuleBuilder, ModuleComposition, ModuleDiagnostics, ModuleInfo};
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::error::Result;

    #[cfg(feature = "std")]
    use std::{boxed::Box, sync::Arc, vec::Vec};

    // Without the `std` feature the crate is `no_std`, so neither the `ToString`
    // trait (needed for `.to_string()`) nor the `vec!` macro is in scope by
    // default; import both from `alloc` together with the container types.
    #[cfg(not(feature = "std"))]
    use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};

    use parking_lot::Mutex;

    /// Minimal [`Module`] whose forward pass returns a clone of its input.
    ///
    /// Shared by the hook tests below as the `module` argument for hook execution.
    struct DummyModule;

    impl Module for DummyModule {
        fn forward(&self, input: &Tensor) -> Result<Tensor> {
            Ok(input.clone())
        }
    }

    #[test]
    fn test_parameter() {
        let tensor = torsh_tensor::creation::ones(&[3, 4]).unwrap();
        let param = Parameter::new(tensor);
        assert!(param.requires_grad());
    }

    /// Registering, executing and removing a hook updates the registry's
    /// bookkeeping and controls whether the hook is actually invoked.
    #[test]
    fn test_hook_registry() -> Result<()> {
        let mut registry = HookRegistry::new();

        let call_count = Arc::new(Mutex::new(0));
        let call_count_clone = call_count.clone();

        let hook = Box::new(
            move |_module: &dyn Module, _input: &Tensor, _output: Option<&Tensor>| {
                *call_count_clone.lock() += 1;
                Ok(())
            },
        );

        let handle = registry.register_hook(HookType::PreForward, hook);

        assert!(registry.has_hooks(HookType::PreForward));
        assert_eq!(registry.hook_count(HookType::PreForward), 1);
        assert!(!registry.has_hooks(HookType::PostForward));

        // The registered hook must fire exactly once per execution of its type.
        let module = DummyModule;
        let input = torsh_tensor::creation::zeros(&[2, 3])?;
        registry.execute_hooks(HookType::PreForward, &module, &input, None)?;
        assert_eq!(*call_count.lock(), 1);

        assert!(registry.remove_hook(HookType::PreForward, handle));
        assert!(!registry.has_hooks(HookType::PreForward));
        assert_eq!(registry.hook_count(HookType::PreForward), 0);

        // After removal, executing the same hook type must not invoke it again.
        registry.execute_hooks(HookType::PreForward, &module, &input, None)?;
        assert_eq!(*call_count.lock(), 1);

        // Removing an already-removed handle reports failure.
        assert!(!registry.remove_hook(HookType::PreForward, handle));

        Ok(())
    }

    #[test]
    fn test_hook_execution() -> Result<()> {
        let mut registry = HookRegistry::new();

        let execution_log = Arc::new(Mutex::new(Vec::new()));
        let log_clone = execution_log.clone();

        let pre_hook = Box::new(
            move |_module: &dyn Module, _input: &Tensor, _output: Option<&Tensor>| {
                log_clone.lock().push("pre_forward".to_string());
                Ok(())
            },
        );

        let log_clone2 = execution_log.clone();
        let post_hook = Box::new(
            move |_module: &dyn Module, _input: &Tensor, output: Option<&Tensor>| {
                // This test passes `Some(output)` when executing post-forward hooks.
                assert!(output.is_some());
                log_clone2.lock().push("post_forward".to_string());
                Ok(())
            },
        );

        registry.register_hook(HookType::PreForward, pre_hook);
        registry.register_hook(HookType::PostForward, post_hook);

        let module = DummyModule;
        let input = torsh_tensor::creation::zeros(&[2, 3])?;
        let output = torsh_tensor::creation::ones(&[2, 3])?;

        registry.execute_hooks(HookType::PreForward, &module, &input, None)?;
        registry.execute_hooks(HookType::PostForward, &module, &input, Some(&output))?;

        let log = execution_log.lock();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "pre_forward");
        assert_eq!(log[1], "post_forward");

        Ok(())
    }

    /// Checks hook registration bookkeeping on `ModuleBase`; hook execution through
    /// `ModuleBase` is not exercised here.
    #[test]
    fn test_module_base_hooks() -> Result<()> {
        let mut base = ModuleBase::new();

        let call_count = Arc::new(Mutex::new(0));
        let call_count_clone = call_count.clone();

        let hook = Box::new(
            move |_module: &dyn Module, _input: &Tensor, _output: Option<&Tensor>| {
                *call_count_clone.lock() += 1;
                Ok(())
            },
        );

        let handle = base.register_hook(HookType::PreForward, hook);
        assert!(base.has_hooks(HookType::PreForward));
        assert_eq!(base.hook_count(HookType::PreForward), 1);

        assert!(base.remove_hook(HookType::PreForward, handle));
        assert!(!base.has_hooks(HookType::PreForward));

        Ok(())
    }

    #[test]
    fn test_hook_error_propagation() -> Result<()> {
        let mut registry = HookRegistry::new();

        let error_hook = Box::new(
            |_module: &dyn Module, _input: &Tensor, _output: Option<&Tensor>| {
                Err(torsh_core::error::TorshError::Other(
                    "Hook error".to_string(),
                ))
            },
        );

        registry.register_hook(HookType::PreForward, error_hook);

        let module = DummyModule;
        let input = torsh_tensor::creation::zeros(&[2, 3])?;

        // A failing hook must surface its error from `execute_hooks`.
        let result = registry.execute_hooks(HookType::PreForward, &module, &input, None);
        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_multiple_hooks_execution_order() -> Result<()> {
        let mut registry = HookRegistry::new();

        let execution_order = Arc::new(Mutex::new(Vec::new()));

        for i in 0..3 {
            let order_clone = execution_order.clone();
            let hook = Box::new(
                move |_module: &dyn Module, _input: &Tensor, _output: Option<&Tensor>| {
                    order_clone.lock().push(i);
                    Ok(())
                },
            );
            registry.register_hook(HookType::PreForward, hook);
        }

        assert_eq!(registry.hook_count(HookType::PreForward), 3);

        let module = DummyModule;
        let input = torsh_tensor::creation::zeros(&[2, 3])?;

        registry.execute_hooks(HookType::PreForward, &module, &input, None)?;

        // Hooks run in registration order.
        let order = execution_order.lock();
        assert_eq!(*order, vec![0, 1, 2]);

        Ok(())
    }

    #[test]
    fn test_hook_clear_operations() {
        let mut registry = HookRegistry::new();

        let dummy_hook = Box::new(|_: &dyn Module, _: &Tensor, _: Option<&Tensor>| Ok(()));
        registry.register_hook(HookType::PreForward, dummy_hook);

        let dummy_hook2 = Box::new(|_: &dyn Module, _: &Tensor, _: Option<&Tensor>| Ok(()));
        registry.register_hook(HookType::PostForward, dummy_hook2);

        assert!(registry.has_hooks(HookType::PreForward));
        assert!(registry.has_hooks(HookType::PostForward));

        registry.clear_hooks(HookType::PreForward);
        assert!(!registry.has_hooks(HookType::PreForward));
        assert!(registry.has_hooks(HookType::PostForward));

        let dummy_hook3 = Box::new(|_: &dyn Module, _: &Tensor, _: Option<&Tensor>| Ok(()));
        registry.register_hook(HookType::PreBackward, dummy_hook3);
        assert!(registry.has_hooks(HookType::PreBackward));

        registry.clear_all_hooks();
        assert!(!registry.has_hooks(HookType::PreForward));
        assert!(!registry.has_hooks(HookType::PostForward));
        assert!(!registry.has_hooks(HookType::PreBackward));
        assert!(!registry.has_hooks(HookType::PostBackward));
    }

    #[test]
    fn test_modular_system_integrity() {
        let tensor = torsh_tensor::creation::randn(&[3, 4]).unwrap();
        let param = Parameter::new(tensor);
        assert!(param.requires_grad());

        let stats = param.stats().unwrap();
        assert_eq!(stats.numel, 12);

        let mut collection = ParameterCollection::new();
        collection.add("test_param".to_string(), param);
        assert_eq!(collection.len(), 1);
        assert!(!collection.is_empty());

        let base = ModuleBase::new();
        assert!(base.training());

        let registry = HookRegistry::new();
        assert!(!registry.has_hooks(HookType::PreForward));

        let config = ModuleConfig::new();
        assert!(config.training);
        assert_eq!(config.dropout, 0.0);
    }

    #[test]
    fn test_backward_compatibility() {
        let tensor = torsh_tensor::creation::ones(&[2, 3]).unwrap();
        let param = Parameter::new(tensor);
        assert!(param.requires_grad());

        let shape = param.shape().unwrap();
        assert_eq!(shape, vec![2, 3]);

        let numel = param.numel().unwrap();
        assert_eq!(numel, 6);

        let mut base = ModuleBase::new();
        base.register_parameter("test".to_string(), param);
        assert_eq!(base.named_parameters().len(), 1);

        let mut registry = HookRegistry::new();
        let hook = Box::new(|_: &dyn Module, _: &Tensor, _: Option<&Tensor>| Ok(()));
        let handle = registry.register_hook(HookType::PreForward, hook);
        assert!(registry.has_hooks(HookType::PreForward));
        assert!(registry.remove_hook(HookType::PreForward, handle));
    }
}