// torsh_nn/lib.rs

1//! Neural network modules for ToRSh
2//!
3//! This crate provides PyTorch-compatible neural network layers and modules,
4//! built on top of scirs2-neural for optimized implementations.
5//!
6//! # Modular Architecture
7//!
8//! The neural network core system is organized into specialized modules for improved maintainability:
9//!
10//! - **core**: Core Module trait system and essential interfaces
11//! - **parameter**: Comprehensive parameter management and initialization
12//! - **hooks**: Hook system infrastructure for module callbacks
13//! - **base**: ModuleBase helper for module implementations
14//! - **composition**: Module composition patterns (sequential, parallel, etc.)
15//! - **construction**: Module construction and configuration patterns
16//! - **diagnostics**: Module and parameter diagnostics and health checking
17//! - **utils**: Module utilities and helper functions
18//!
19//! All components maintain full backward compatibility through comprehensive re-exports.
20
21#![cfg_attr(not(feature = "std"), no_std)]
22#![allow(ambiguous_glob_reexports)]
23
24#[cfg(not(feature = "std"))]
25extern crate alloc;
26
27// pub mod checkpoint; // Temporarily disabled for testing
28pub mod compile_time;
29pub mod container;
30#[cfg(feature = "std")]
31pub mod conversion;
32pub mod cuda_kernels;
33#[cfg(feature = "std")]
34pub mod export;
35pub mod functional;
36pub mod gradcheck;
37pub mod hardware_opts;
38pub mod init;
39pub mod layers;
40pub mod lazy;
41pub mod mixed_precision;
42pub mod model_zoo;
43pub mod modules;
44pub mod numerical_stability;
45pub mod optimization;
46pub mod parameter_updates;
47pub mod pruning;
48pub mod quantization;
49pub mod research;
50pub mod scirs2_neural_integration;
51#[cfg(feature = "serialize")]
52pub mod serialization;
53pub mod sparse;
54pub mod summary;
55pub mod visualization;
56
57// =============================================================================
58// MODULAR ARCHITECTURE IMPORTS
59// =============================================================================
60
61// Core module trait system
62pub mod core;
63pub use core::Module;
64
65// Parameter management system
66pub mod parameter;
67pub use parameter::{
68    LayerType, Parameter, ParameterCollection, ParameterDiagnostics, ParameterStats,
69};
70
71// Hook system infrastructure
72pub mod hooks;
73pub use hooks::{HookCallback, HookHandle, HookRegistry, HookType};
74
75// Module base infrastructure
76pub mod base;
77pub use base::ModuleBase;
78
79// Module composition system
80pub mod composition;
81pub use composition::{
82    ComposedModule, ConditionalModule, ModuleBuilder, ModuleComposition, ParallelModule,
83    ResidualModule,
84};
85
86// Module construction and configuration
87pub mod construction;
88pub use construction::{ModuleConfig, ModuleConstruct};
89
90// Module diagnostics and analysis
91pub mod diagnostics;
92pub use diagnostics::{ModuleDiagnostics, ModuleInfo};
93
94// Module utilities
95pub mod utils;
96pub use utils::{ModuleApply, ModuleParameterStats};
97
98// =============================================================================
99// BACKWARD COMPATIBILITY IMPORTS AND RE-EXPORTS
100// =============================================================================
101
102use torsh_tensor::Tensor;
103
104// Conditional imports for std/no_std compatibility
105
106#[cfg(not(feature = "std"))]
107use alloc::sync::Arc;
108
109#[cfg(not(feature = "std"))]
110use hashbrown::HashMap;
111
112// Version information
113pub const VERSION: &str = env!("CARGO_PKG_VERSION");
114pub const VERSION_MAJOR: u32 = 0;
115pub const VERSION_MINOR: u32 = 1;
116pub const VERSION_PATCH: u32 = 0;
117
118// Note: impl_module_constructor macro is already available via #[macro_export]
119
/// Sparse Matrix placeholder for compatibility.
///
/// Zero-sized stand-in kept so downstream code can name `SparseMatrix` today.
/// NOTE(review): presumably superseded by the real implementation in the
/// `sparse` module — confirm before extending this type.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct SparseMatrix;

impl SparseMatrix {
    /// Creates a new placeholder sparse matrix.
    ///
    /// Equivalent to `SparseMatrix::default()`; kept for API symmetry with
    /// the other module constructors in this crate.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
134
/// Prelude module for convenient imports.
///
/// `use torsh_nn::prelude::*;` brings the commonly used layer types,
/// initialization helpers, hook/parameter infrastructure, and the various
/// analysis/optimization utilities into scope in a single line. Everything
/// here is a re-export; no new items are defined in this module.
pub mod prelude {
    // Container and layer building blocks.
    pub use crate::container::*;
    // Cross-framework conversion — feature-gated the same way as the
    // `conversion` module itself (std only).
    #[cfg(feature = "std")]
    pub use crate::conversion::{
        pytorch_compat, tensorflow_compat, ConversionConfig, FrameworkSource, MigrationHelper,
        ModelConverter,
    };
    pub use crate::cuda_kernels::{
        global_kernel_registry, CudaKernelRegistry, CudaNeuralOps, CudaOptimizations,
        CustomActivations,
    };
    // Model export/deployment — std only, mirroring the `export` module gate.
    #[cfg(feature = "std")]
    pub use crate::export::{
        DeploymentOptimizer, ExportConfig, ExportFormat, ModelExporter, OptimizationLevel,
        TargetDevice,
    };
    // Gradient checking utilities (fast/default/precise presets plus config types).
    pub use crate::gradcheck::{
        fast_gradcheck, gradcheck, precise_gradcheck, GradCheckConfig, GradCheckResult, GradChecker,
    };
    // Weight initialization: `self` re-exports the module itself so callers
    // can also write `init::...` after a glob import.
    pub use crate::init::{
        self,
        // Automatic initialization
        auto_init,
        coordinate_mlp_init,
        delta_orthogonal_init,
        // Modern initialization methods
        fixup_init,
        gan_balanced_init,
        lsuv_init,
        metainit,
        recommend_init_method,
        rezero_alpha_init,
        rezero_init,
        zero_centered_variance_init,
        ActivationHint,
        ArchitectureHint,
        // Core types
        FanMode,
        InitMethod,
        Initializer,
        Nonlinearity,
    };
    pub use crate::layers::*;
    pub use crate::lazy::{lazy_linear, lazy_linear_no_bias, LazyLinear, LazyModule, LazyWrapper};
    pub use crate::mixed_precision::prelude::*;
    // Glob overlap with other re-exports is expected; the crate-level
    // `allow(ambiguous_glob_reexports)` covers name collisions.
    #[allow(unused_imports)]
    pub use crate::modules::*;
    pub use crate::numerical_stability::utils::{
        comprehensive_stability_analysis, quick_stability_check,
    };
    pub use crate::numerical_stability::{
        StabilityConfig, StabilityIssue, StabilityResults, StabilityTester,
    };
    pub use crate::optimization::{
        optimize_for_inference, optimize_module, MemoryProfiler, NetworkOptimizer,
        OptimizationReport,
    };
    // pub use crate::parameter::sharing::{ParameterSharingRegistry, SharingStats}; // Module not yet implemented
    pub use crate::parameter_updates::{
        LayerSpecificOptimizers, ParameterUpdater, UpdateConfig, UpdateStatistics,
    };
    pub use crate::pruning::{Pruner, PruningConfig, PruningMask, PruningScope, PruningStrategy};
    pub use crate::quantization::prelude::*;
    pub use crate::scirs2_neural_integration::{
        LayerNorm, MemoryEfficientSequential, Mish, MultiHeadAttention, NeuralConfig,
        SciRS2NeuralProcessor, Swish, TransformerEncoderLayer, GELU,
    };
    pub use crate::summary::profiling::{
        AnalysisConfig, AnalysisReport, BatchProfiler, BatchProfilingConfig, BatchProfilingResult,
        FLOPSAnalysis, FLOPSCounter, MemoryAnalysis, ModelAnalyzer,
    };
    pub use crate::summary::utils::*;
    pub use crate::summary::{summarize, LayerInfo, ModelProfiler, ModelSummary, SummaryConfig};
    pub use crate::visualization::utils::*;
    pub use crate::visualization::{GraphEdge, GraphNode, NetworkGraph, VisualizationConfig};
    // Core modular-architecture types re-exported from the crate root
    // (composition, hooks, parameters, construction, diagnostics).
    pub use crate::{ComposedModule, ConditionalModule, ParallelModule, ResidualModule};
    pub use crate::{
        HookCallback, HookHandle, HookRegistry, HookType, LayerType, Module, ModuleBase,
        ModuleConfig, ModuleConstruct, Parameter, ParameterCollection, ParameterDiagnostics,
        ParameterStats,
    };
    pub use crate::{ModuleBuilder, ModuleComposition, ModuleDiagnostics, ModuleInfo};
}
219
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::error::Result;

    // Conditional imports for std/no_std compatibility
    #[cfg(feature = "std")]
    use std::{boxed::Box, sync::Arc, vec::Vec};

    #[cfg(not(feature = "std"))]
    use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};

    // Use parking_lot::Mutex for both std and no_std
    use parking_lot::Mutex;

    /// A freshly constructed `Parameter` should default to requiring gradients.
    #[test]
    fn test_parameter() {
        let tensor = torsh_tensor::creation::ones(&[3, 4]).unwrap();
        let param = Parameter::new(tensor);
        assert!(param.requires_grad());
    }

    /// Registering/removing a hook updates the registry's bookkeeping
    /// (`has_hooks` / `hook_count`), and removing an already-removed handle
    /// reports `false` instead of panicking.
    #[test]
    fn test_hook_registry() {
        let mut registry = HookRegistry::new();

        // Test registering hooks
        let call_count = Arc::new(Mutex::new(0));
        let call_count_clone = call_count.clone();

        let hook = Box::new(
            move |_module: &dyn Module, _input: &Tensor, _output: Option<&Tensor>| {
                *call_count_clone.lock() += 1;
                Ok(())
            },
        );

        let handle = registry.register_hook(HookType::PreForward, hook);

        assert!(registry.has_hooks(HookType::PreForward));
        assert_eq!(registry.hook_count(HookType::PreForward), 1);
        assert!(!registry.has_hooks(HookType::PostForward));

        // Test removing hooks
        assert!(registry.remove_hook(HookType::PreForward, handle));
        assert!(!registry.has_hooks(HookType::PreForward));
        assert_eq!(registry.hook_count(HookType::PreForward), 0);

        // Test removing non-existent hook
        assert!(!registry.remove_hook(HookType::PreForward, handle));
    }

    /// Pre- and post-forward hooks both fire when executed, in order, and
    /// the post-forward hook observes a `Some` output tensor.
    #[test]
    fn test_hook_execution() -> Result<()> {
        let mut registry = HookRegistry::new();

        // Track hook execution
        let execution_log = Arc::new(Mutex::new(Vec::new()));
        let log_clone = execution_log.clone();

        let pre_hook = Box::new(
            move |_module: &dyn Module, _input: &Tensor, _output: Option<&Tensor>| {
                log_clone.lock().push("pre_forward".to_string());
                Ok(())
            },
        );

        let log_clone2 = execution_log.clone();
        let post_hook = Box::new(
            move |_module: &dyn Module, _input: &Tensor, output: Option<&Tensor>| {
                assert!(output.is_some()); // Post-forward should have output
                log_clone2.lock().push("post_forward".to_string());
                Ok(())
            },
        );

        registry.register_hook(HookType::PreForward, pre_hook);
        registry.register_hook(HookType::PostForward, post_hook);

        // Create a dummy module and tensor for testing
        struct DummyModule;
        impl Module for DummyModule {
            fn forward(&self, input: &Tensor) -> Result<Tensor> {
                Ok(input.clone())
            }
        }

        let dummy_module = DummyModule;
        let input = torsh_tensor::creation::zeros(&[2, 3])?;
        let output = torsh_tensor::creation::ones(&[2, 3])?;

        // Execute hooks
        registry.execute_hooks(HookType::PreForward, &dummy_module, &input, None)?;
        registry.execute_hooks(HookType::PostForward, &dummy_module, &input, Some(&output))?;

        // Check execution log
        let log = execution_log.lock();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "pre_forward");
        assert_eq!(log[1], "post_forward");

        Ok(())
    }

    /// `ModuleBase` delegates hook registration, counting, and removal to its
    /// internal registry with the same semantics as `HookRegistry` itself.
    #[test]
    fn test_module_base_hooks() -> Result<()> {
        let mut base = ModuleBase::new();

        // Test hook registration
        let call_count = Arc::new(Mutex::new(0));
        let call_count_clone = call_count.clone();

        let hook = Box::new(
            move |_module: &dyn Module, _input: &Tensor, _output: Option<&Tensor>| {
                *call_count_clone.lock() += 1;
                Ok(())
            },
        );

        let handle = base.register_hook(HookType::PreForward, hook);
        assert!(base.has_hooks(HookType::PreForward));
        assert_eq!(base.hook_count(HookType::PreForward), 1);

        // Test hook removal
        assert!(base.remove_hook(HookType::PreForward, handle));
        assert!(!base.has_hooks(HookType::PreForward));

        Ok(())
    }

    /// An `Err` returned from a hook propagates out of `execute_hooks`
    /// instead of being swallowed.
    #[test]
    fn test_hook_error_propagation() -> Result<()> {
        let mut registry = HookRegistry::new();

        // Hook that returns an error
        let error_hook = Box::new(
            |_module: &dyn Module, _input: &Tensor, _output: Option<&Tensor>| {
                Err(torsh_core::error::TorshError::Other(
                    "Hook error".to_string(),
                ))
            },
        );

        registry.register_hook(HookType::PreForward, error_hook);

        struct DummyModule;
        impl Module for DummyModule {
            fn forward(&self, input: &Tensor) -> Result<Tensor> {
                Ok(input.clone())
            }
        }

        let dummy_module = DummyModule;
        let input = torsh_tensor::creation::zeros(&[2, 3])?;

        // Hook execution should propagate the error
        let result = registry.execute_hooks(HookType::PreForward, &dummy_module, &input, None);
        assert!(result.is_err());

        Ok(())
    }

    /// Multiple hooks of the same type execute in registration order (0, 1, 2).
    #[test]
    fn test_multiple_hooks_execution_order() -> Result<()> {
        let mut registry = HookRegistry::new();

        let execution_order = Arc::new(Mutex::new(Vec::new()));

        // Register multiple hooks
        for i in 0..3 {
            let order_clone = execution_order.clone();
            let hook = Box::new(
                move |_module: &dyn Module, _input: &Tensor, _output: Option<&Tensor>| {
                    order_clone.lock().push(i);
                    Ok(())
                },
            );
            registry.register_hook(HookType::PreForward, hook);
        }

        assert_eq!(registry.hook_count(HookType::PreForward), 3);

        struct DummyModule;
        impl Module for DummyModule {
            fn forward(&self, input: &Tensor) -> Result<Tensor> {
                Ok(input.clone())
            }
        }

        let dummy_module = DummyModule;
        let input = torsh_tensor::creation::zeros(&[2, 3])?;

        registry.execute_hooks(HookType::PreForward, &dummy_module, &input, None)?;

        // Hooks should execute in registration order
        let order = execution_order.lock();
        assert_eq!(*order, vec![0, 1, 2]);

        Ok(())
    }

    /// `clear_hooks` removes only the targeted hook type;
    /// `clear_all_hooks` empties every hook type at once.
    #[test]
    fn test_hook_clear_operations() {
        let mut registry = HookRegistry::new();

        // Register hooks for different types
        let dummy_hook = Box::new(|_: &dyn Module, _: &Tensor, _: Option<&Tensor>| Ok(()));
        registry.register_hook(HookType::PreForward, dummy_hook);

        let dummy_hook2 = Box::new(|_: &dyn Module, _: &Tensor, _: Option<&Tensor>| Ok(()));
        registry.register_hook(HookType::PostForward, dummy_hook2);

        assert!(registry.has_hooks(HookType::PreForward));
        assert!(registry.has_hooks(HookType::PostForward));

        // Clear specific hook type
        registry.clear_hooks(HookType::PreForward);
        assert!(!registry.has_hooks(HookType::PreForward));
        assert!(registry.has_hooks(HookType::PostForward));

        // Register another hook
        let dummy_hook3 = Box::new(|_: &dyn Module, _: &Tensor, _: Option<&Tensor>| Ok(()));
        registry.register_hook(HookType::PreBackward, dummy_hook3);
        assert!(registry.has_hooks(HookType::PreBackward));

        // Clear all hooks
        registry.clear_all_hooks();
        assert!(!registry.has_hooks(HookType::PreForward));
        assert!(!registry.has_hooks(HookType::PostForward));
        assert!(!registry.has_hooks(HookType::PreBackward));
        assert!(!registry.has_hooks(HookType::PostBackward));
    }

    /// Smoke test that the modular re-exports (parameters, collections,
    /// `ModuleBase`, `HookRegistry`, `ModuleConfig`) are all reachable from
    /// the crate root and behave sanely on construction.
    #[test]
    fn test_modular_system_integrity() {
        // Test that all modules are properly accessible and modular architecture works

        // Test parameter creation
        let tensor = torsh_tensor::creation::randn(&[3, 4]).unwrap();
        let param = Parameter::new(tensor);
        assert!(param.requires_grad());

        // Test parameter statistics (3 * 4 = 12 elements)
        let stats = param.stats().unwrap();
        assert_eq!(stats.numel, 12);

        // Test parameter collection
        let mut collection = ParameterCollection::new();
        collection.add("test_param".to_string(), param);
        assert_eq!(collection.len(), 1);
        assert!(!collection.is_empty());

        // Test module base — new modules start in training mode
        let base = ModuleBase::new();
        assert!(base.training());

        // Test hook registry — starts with no hooks registered
        let registry = HookRegistry::new();
        assert!(!registry.has_hooks(HookType::PreForward));

        // Test module config defaults (training on, no dropout)
        let config = ModuleConfig::new();
        assert!(config.training);
        assert_eq!(config.dropout, 0.0);
    }

    /// The original (pre-modularization) APIs keep working unchanged:
    /// parameter accessors, `ModuleBase` parameter registration, and the
    /// hook register/remove cycle.
    #[test]
    fn test_backward_compatibility() {
        // Test that the modular system maintains full backward compatibility
        // All original APIs should work exactly as before

        // Test parameter creation (original API)
        let tensor = torsh_tensor::creation::ones(&[2, 3]).unwrap();
        let param = Parameter::new(tensor);
        assert!(param.requires_grad());

        // Test parameter access methods (original API)
        let shape = param.shape().unwrap();
        assert_eq!(shape, vec![2, 3]);

        let numel = param.numel().unwrap();
        assert_eq!(numel, 6);

        // Test module base functionality (original API)
        let mut base = ModuleBase::new();
        base.register_parameter("test".to_string(), param);
        assert_eq!(base.named_parameters().len(), 1);

        // Test hook system (original API)
        let mut registry = HookRegistry::new();
        let hook = Box::new(|_: &dyn Module, _: &Tensor, _: Option<&Tensor>| Ok(()));
        let handle = registry.register_hook(HookType::PreForward, hook);
        assert!(registry.has_hooks(HookType::PreForward));
        assert!(registry.remove_hook(HookType::PreForward, handle));
    }
}