// torsh_autograd/lib.rs
1//! Automatic differentiation engine for ToRSh
2//!
3//! This crate provides a PyTorch-compatible autograd API that fully leverages
4//! scirs2-autograd's powerful automatic differentiation capabilities.
5//!
6//! # Quick Start
7//!
8//! ```rust,ignore
9//! use torsh_autograd::prelude::*;
10//!
11//! // Enable gradient computation
12//! let x = tensor::ones(&[2, 3]).requires_grad_(true);
13//! let y = x.pow(2).sum();
14//!
15//! // Compute gradients
16//! y.backward();
17//! let grad = x.grad();
18//! ```
19//!
20//! # Architecture
21//!
22//! The autograd system is built around several key components:
23//!
24//! - **Gradient computation**: Automatic computation of gradients through computation graphs
25//! - **Tensor operations**: Differentiable tensor operations with gradient tracking
26//! - **Variable management**: Thread-local variable environments for gradient storage
27//! - **Guard system**: RAII guards for gradient mode management
28//! - **Anomaly detection**: Detection and recovery from numerical anomalies
29//! - **SciRS2 integration**: Deep integration with the SciRS2 autograd system
30//! - **Hardware acceleration**: Multi-platform support (CUDA, Metal, WebGPU)
31//!
32//! # API Stability
33//!
34//! The crate follows semantic versioning with clearly defined stability levels:
35//!
36//! - **Stable APIs** ([`stable_api::stable`]): Core functionality with backward compatibility
37//! - **Beta APIs** ([`stable_api::beta`]): Feature-complete but may evolve
38//! - **Experimental APIs** ([`stable_api::experimental`]): May change significantly
39//!
40//! See [`stable_api`] for details on stability guarantees.
41//!
42//! # Examples
43//!
44//! The [`examples`] module provides comprehensive usage examples:
45//!
46//! - Basic gradient computation
47//! - Inference with `no_grad`
48//! - Gradient accumulation
49//! - Custom differentiable functions
50//! - Higher-order gradients
51//! - Hardware acceleration
52//! - Distributed training
53//!
54//! Run all examples: `examples::run_all_examples()`
55//!
56//! # Key Modules
57//!
58//! ## Core Autograd
59//! - [`autograd_traits`]: Core traits for differentiable tensors
60//! - [`gradient_storage`]: Thread-safe gradient storage management
61//! - [`grad_mode`]: Global gradient computation mode management
62//! - [`guards`]: RAII guards for automatic gradient mode restoration
63//! - [`variable_env`]: Thread-local variable environment management
64//!
65//! ## Advanced Features
66//! - [`complex_ops`]: Complex number operations with Wirtinger derivatives
67//! - [`anomaly_detection`]: Numerical anomaly detection and automatic recovery
68//! - [`gradient_clipping`]: Gradient clipping strategies
69//! - [`checkpoint_scheduler`]: Memory-efficient gradient checkpointing
70//! - [`higher_order_gradients`]: Higher-order derivative computation
71//!
72//! ## Hardware & Performance
73//! - [`hardware_acceleration`]: Multi-platform hardware acceleration
74//! - [`profiler`]: Performance profiling and analysis
75//! - [`simd_ops`]: SIMD-optimized gradient operations
76//! - [`distributed`]: Distributed gradient computation
77//!
78//! ## Integration & Compatibility
79//! - [`pytorch_compat`]: PyTorch compatibility layer
80//! - [`jax_transformations`]: JAX-style transformations
81//! - [`tensorflow_compat`]: TensorFlow compatibility
82//! - [`mlx_compat`]: Apple MLX compatibility
83//!
84//! # Feature Flags
85//!
86//! - `default`: Enables autograd, SIMD, and parallel features
87//! - `autograd`: SciRS2 autograd integration
88//! - `simd`: SIMD optimizations
89//! - `parallel`: Parallel gradient computation
90//! - `gpu`: GPU acceleration support
91//! - `webgpu`: WebGPU for browser deployment
92//! - `profiling`: Performance profiling tools
93//! - `scirs2-full`: All SciRS2 features
94
95// Core extracted modules for autograd functionality
96pub mod anomaly_alerts;
97pub mod anomaly_detection;
98pub mod autograd_traits;
99pub mod complex_ops;
100pub mod global_adapter;
101pub mod grad_mode;
102pub mod gradient_storage;
103pub mod guards;
104pub mod variable_env;
105
106// Existing specialized modules
107pub mod checkpoint_scheduler;
108pub mod common_utils;
109pub mod communication_efficient;
110pub mod compression;
111pub mod context;
112pub mod differentiable_programming;
113pub mod discrete_ops;
114pub mod distributed;
115pub mod error_diagnostics;
116pub mod error_handling;
117pub mod external_ad_integration;
118pub mod federated_learning;
119pub mod flamegraph;
120pub mod function;
121pub mod function_optimization;
122pub mod garbage_collection;
123pub mod gradient_checking;
124pub mod gradient_filtering;
125pub mod gradient_flow_analysis;
126pub mod gradient_scaling;
127pub mod gradient_scheduler;
128pub mod gradient_tracer;
129pub mod gradient_validation;
130pub mod graph_opt;
131pub mod graph_visualization;
132pub mod hyperparameter_optimization;
133pub mod iterative_solvers;
134pub mod jax_transformations;
135pub mod matrix_calculus;
136pub mod memory;
137pub mod meta_gradient;
138pub mod metrics_collection;
139pub mod mlx_compat;
140pub mod onnx_integration;
141pub mod operation_introspection;
142pub mod operation_replay;
143pub mod optimization_diff;
144pub mod parameter_server;
145pub mod profiler;
146pub mod property_testing;
147pub mod pytorch_compat;
148pub mod scirs2_integration;
149pub mod simd_gradient;
150pub mod simd_ops;
151pub mod staleness_handling;
152pub mod stochastic_graphs;
153pub mod structured_logging;
154pub mod symbolic;
155pub mod tensorflow_compat;
156pub mod visualization;
157pub mod vjp_optimization;
158
159// New modules for enhanced functionality
160pub mod auto_tuning;
161pub mod automatic_error_recovery;
162pub mod blas_integration;
163pub mod buffer_optimization;
164pub mod cross_framework_verification;
165pub mod custom_backends;
166pub mod edge_case_handling;
167pub mod exception_safety;
168pub mod gpu_gradient;
169pub mod graceful_degradation;
170pub mod gradient_clipping;
171pub mod gradient_hooks;
172pub mod hardware_acceleration;
173pub mod health_diagnostics;
174pub mod higher_order_gradients;
175pub mod integration_patterns;
176pub mod intelligent_chunking;
177pub mod interactive_debugger;
178pub mod neural_architecture_search;
179pub mod neural_ode;
180pub mod parallel_gradient;
181pub mod performance_regression;
182pub mod profiling_debugging_integration;
183pub mod progress_reporting;
184pub mod quantum_autograd;
185pub mod raii_resources;
186pub mod regression_testing;
187pub mod scirs2_integration_testing;
188pub mod specialized_gradient_libs;
189pub mod stress_testing;
190
191// Additional framework integration modules
192pub mod ad_framework_compatibility;
193
194// API stability and versioning
195pub mod stable_api;
196
197// Comprehensive examples and documentation
198pub mod examples;
199
200// Production monitoring and observability modules
201pub mod audit_logging;
202pub mod capacity_planning;
203pub mod error_rate_monitoring;
204pub mod flamegraph_generation;
205pub mod gradient_tracing;
206pub mod operation_cost_analysis;
207pub mod performance_dashboard;
208
209// Re-exports for convenience
210pub use crate::error_handling::{AutogradError, AutogradResult};
211
212pub use crate::autograd_traits::{
213    AutogradTensor, AutogradTensorFactory, BackwardTensor, GradientAccumulation,
214};
215
216pub use crate::global_adapter::{
217    backward_global, create_gradient_tensor, get_global_adapter, get_gradient_global,
218};
219
220pub use crate::grad_mode::{
221    is_grad_enabled, pop_grad_enabled, push_grad_enabled, set_grad_enabled, with_grad_mode,
222};
223
224pub use crate::guards::{enable_grad, no_grad, EnableGradGuard, GradModeGuard, NoGradGuard};
225
226pub use crate::gradient_storage::{
227    get_gradient_storage, GlobalGradientStorage, GradientStorage, HashMapGradientStorage,
228};
229
230pub use crate::variable_env::{
231    clear_variable_env, get_or_create_variable_env, handle_inplace_operation,
232    is_variable_env_initialized, validate_inplace_operation, with_variable_env, InplaceConfig,
233    InplaceStrategy,
234};
235
236pub use crate::complex_ops::backward_complex;
237pub use crate::pytorch_compat::backward;
238
239pub use crate::anomaly_detection::{
240    detect_complex_anomalies,
241    recovery::{
242        AnomalyRecoverySystem, RecoveryConfig, RecoveryResult, RecoveryStats, RecoveryStrategy,
243    },
244};
245
246pub use crate::scirs2_integration::{GradientTensor, SciRS2AutogradAdapter};
247
248pub use crate::auto_tuning::{
249    AppliedOptimization, AutoTuningController, OptimizationType, ParameterValue,
250    PerformanceSnapshot, TuningConfig, TuningRecommendation, TuningStatistics,
251};
252
253pub use crate::error_diagnostics::{
254    DiagnosticRecommendation, DiagnosticReport, DiagnosticStatus, DiagnosticsConfig,
255    ErrorCorrelation, ErrorDiagnosticsSystem, ErrorPattern, LabeledErrorEvent, MLAnalysisResult,
256    MLPatternPrediction, MLPatternRecognitionSystem, MLSystemConfig, PatternLabel, SeverityLevel,
257    TemporalContext,
258};
259
// Version information
/// Full crate version string, taken from Cargo metadata at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Major version component.
// NOTE(review): hard-coded — presumably mirrors Cargo.toml (0.1.0); keep in sync manually.
pub const VERSION_MAJOR: u32 = 0;
/// Minor version component.
pub const VERSION_MINOR: u32 = 1;
/// Patch version component.
pub const VERSION_PATCH: u32 = 0;
265
266// Common imports and utilities
267use torsh_core::error::{Result, TorshError};
268
/// Version tracking for tensor operations
///
/// Tracks a per-tensor modification counter so that in-place operations
/// which might invalidate the computation graph can be detected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorVersion {
    /// Version number that increments with each modification
    pub version: usize,
    /// Unique tensor identifier
    pub tensor_id: usize,
}

impl TensorVersion {
    /// Create a new tensor version, starting at version 0
    pub fn new(tensor_id: usize) -> Self {
        TensorVersion {
            tensor_id,
            version: 0,
        }
    }

    /// Increment the version (for in-place operations), returning a copy of the new state
    pub fn increment(&mut self) -> Self {
        self.version += 1;
        *self
    }

    /// Check if this version is compatible with another version
    ///
    /// Two versions are compatible when they refer to the same tensor and
    /// carry the same version number — i.e. exactly the derived equality.
    pub fn is_compatible_with(&self, other: &TensorVersion) -> bool {
        self == other
    }
}
301
/// Monotonic counter backing [`new_tensor_id`]
static NEXT_TENSOR_ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

/// Generate a unique tensor ID
///
/// Each call yields a distinct, monotonically increasing identifier and is
/// safe to invoke concurrently from multiple threads.
pub fn new_tensor_id() -> usize {
    NEXT_TENSOR_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}
309
310/// In-place operation handling with gradient safety
311///
312/// This module provides utilities for safely handling in-place tensor operations
313/// while preserving gradient computation capabilities.
314pub mod inplace_versioning {
315    use super::*;
316
317    /// Strategy for handling version conflicts in in-place operations
318    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
319    pub enum VersionConflictStrategy {
320        /// Error on version conflicts
321        Error,
322        /// Warn about version conflicts but allow operation
323        Warn,
324        /// Create a copy before in-place operation
325        CopyOnWrite,
326        /// Silently allow version conflicts
327        Allow,
328    }
329
330    impl Default for VersionConflictStrategy {
331        fn default() -> Self {
332            Self::Warn
333        }
334    }
335
336    /// Check for version conflicts before in-place operations
337    pub fn check_version_compatibility(
338        current: &TensorVersion,
339        expected: &TensorVersion,
340        strategy: VersionConflictStrategy,
341        operation_name: &str,
342    ) -> Result<()> {
343        if !current.is_compatible_with(expected) {
344            let message = format!(
345                "Version conflict in in-place operation '{}': expected version {} but found {}",
346                operation_name, expected.version, current.version
347            );
348
349            match strategy {
350                VersionConflictStrategy::Error => {
351                    return Err(TorshError::AutogradError(message));
352                }
353                VersionConflictStrategy::Warn => {
354                    tracing::warn!("{}", message);
355                }
356                VersionConflictStrategy::CopyOnWrite => {
357                    tracing::info!("Triggering copy-on-write for: {}", message);
358                }
359                VersionConflictStrategy::Allow => {
360                    tracing::debug!("Allowing version conflict: {}", message);
361                }
362            }
363        }
364        Ok(())
365    }
366
367    /// Update tensor version after in-place operation
368    pub fn update_version_after_inplace(
369        version: &mut TensorVersion,
370        operation_name: &str,
371    ) -> TensorVersion {
372        let old_version = *version;
373        let new_version = version.increment();
374        tracing::debug!(
375            "Updated tensor {} version from {} to {} after in-place operation '{}'",
376            version.tensor_id,
377            old_version.version,
378            new_version.version,
379            operation_name
380        );
381        new_version
382    }
383}
384
/// Public prelude for convenient importing
///
/// `use torsh_autograd::prelude::*;` brings the most commonly used autograd
/// types and free functions into scope, grouped below by feature area.
pub mod prelude {
    pub use crate::autograd_traits::{AutogradTensor, BackwardTensor, GradientAccumulation};
    pub use crate::global_adapter::{
        backward_global, create_gradient_tensor, get_global_adapter, get_gradient_global,
    };
    pub use crate::grad_mode::{is_grad_enabled, set_grad_enabled, with_grad_mode};
    pub use crate::gradient_storage::{get_gradient_storage, GradientStorage};
    pub use crate::guards::{enable_grad, no_grad, EnableGradGuard, NoGradGuard};
    pub use crate::variable_env::{with_variable_env, InplaceStrategy};
    pub use crate::{new_tensor_id, TensorVersion};

    // Neural Architecture Search
    pub use crate::neural_architecture_search::{
        ConvOperation, DARTSArchitecture, DARTSPruningResult, IdentityOperation, MixedOperation,
        ProgressiveDARTS, ProgressiveStageInfo, SamplingStrategy, SearchableOperation,
        ZeroOperation, DARTS,
    };

    // Neural ODE
    pub use crate::neural_ode::{
        AdjointMethod, AdjointSolution, IntegrationMethod, NeuralODE, NeuralODELayer, ODESolver,
        ODESolverConfig, ODESystem,
    };

    // Parallel Gradient Computation (SciRS2-Core Integration - Phase 1)
    pub use crate::parallel_gradient::{
        configure_global_parallel, get_global_parallel_computer, get_global_parallel_computer_mut,
        ParallelConfig, ParallelGradientComputer, ParallelStats,
    };

    // GPU Gradient Computation (SciRS2-Core Integration - Phase 2)
    pub use crate::gpu_gradient::{
        get_global_gpu_computer, initialize_global_gpu, is_global_gpu_available, ActivationType,
        GpuBackend, GpuConfig, GpuGradientComputer, GpuStats,
    };

    // SIMD Gradient Computation (SciRS2-Core Integration - Phase 3)
    pub use crate::simd_gradient::{
        configure_global_simd, get_global_simd_computer, get_global_simd_computer_mut,
        SimdCapability, SimdConfig, SimdGradientComputer, SimdStats,
    };

    // Intelligent Chunking (SciRS2-Core Integration - Phase 4)
    pub use crate::intelligent_chunking::{
        configure_global_chunker, get_global_chunker, get_global_chunker_mut, ChunkingConfig,
        ChunkingStats, ChunkingStrategy, IntelligentChunker,
    };

    // Advanced Gradient Clipping
    pub use crate::gradient_clipping::{
        clip_gradients_global, configure_global_clipper, get_global_clipper,
        get_global_clipper_mut, ClippingStats, ClippingStrategy, GradientClipper,
    };

    // Higher-Order Gradients (Hessian, Jacobian)
    pub use crate::higher_order_gradients::{
        configure_global_higher_order, get_global_higher_order, get_global_higher_order_mut,
        ComputationMode, GradientOrder, HigherOrderConfig, HigherOrderGradient, HigherOrderStats,
    };

    // Gradient Hooks System
    pub use crate::gradient_hooks::{
        get_global_hook_manager, get_global_hook_manager_mut, register_global_hook, GradientHook,
        GradientHookManager, HookContext, HookPriority, HookStats, HookType,
    };

    // Automatic Error Recovery
    pub use crate::automatic_error_recovery::{
        get_global_recovery, recover_from_error, AutomaticErrorRecovery, CorrectiveTransform,
        RecoveryAction, RecoveryStrategy, TransientFailureType,
    };
    pub use crate::with_error_recovery;

    // Edge Case Handling
    pub use crate::edge_case_handling::{
        get_global_edge_case_handler, handle_tensor_edge_cases, validate_tensor_shapes,
        EdgeCaseHandler, EdgeCaseStrategy, EdgeCaseTransformation, EdgeCaseType, TensorInfo,
    };

    // Quantum Computing Autograd
    pub use crate::quantum_autograd::{
        Complex, Observable, PauliX, QuantumCircuit, QuantumExpectationValue, QuantumGate,
        QuantumState, QuantumStateGradient, Qubit, RotationY, VQEResult, CNOT, VQE,
    };

    // Cross-Framework Gradient Verification
    pub use crate::cross_framework_verification::{
        get_global_verifier, initialize_verification_frameworks, ComparisonTolerance,
        CrossFrameworkVerifier, FrameworkAdapter, GradientComparisonResult, GradientData,
        MockPyTorchAdapter, SupportedFramework, TorshFrameworkAdapter, VerificationReport,
    };

    // Regression Testing
    pub use crate::regression_testing::{
        get_global_regression_tester, GradientRegressionTester, GradientTestCase,
        RegressionTestResult, RegressionTestStatistics,
    };

    // Exception Safety
    pub use crate::exception_safety::{
        get_global_executor, AutogradTransaction, ComputationGraphGuard, ExceptionSafeExecutor,
        ExceptionSafetyAnalyzer, ExceptionSafetyLevel, GradientStorageGuard, ResourceGuard,
        SafetyViolation, SafetyViolationReport, TransactionOperation,
    };

    // Graceful Degradation
    pub use crate::graceful_degradation::{
        get_global_degradation_manager, DegradationEvent, DegradationStatistics,
        DegradationStrategy, FallbackFunction, GracefulDegradationManager, MatrixMultiplyFallback,
        OperationCategory, UnsupportedOperationInfo, UnsupportedReason,
    };

    // SciRS2 Integration Testing
    pub use crate::scirs2_integration_testing::{
        get_global_integration_tester, run_scirs2_integration_tests, CompatibilitySummary,
        PerformanceSummary, SciRS2IntegrationTestCase, SciRS2IntegrationTestSuite,
        SciRS2IntegrationTester, SciRS2TestResult, SciRS2Version, TestCategory,
        TestCategoryResults,
    };

    // Integration Patterns and Documentation
    pub use crate::integration_patterns::{
        IntegrationDocumentation, IntegrationPatterns, MigrationGuide, MigrationScenario, Pattern,
        PatternCategory, PatternDocumentation, TroubleshootingGuide, TroubleshootingIssue,
    };

    // BLAS Integration
    pub use crate::blas_integration::{
        blas_dot, blas_gemm, blas_gemv, get_global_blas_manager, BlasConfig, BlasImplementation,
        BlasManager, BlasOperation, BlasPerformanceReport, BlasProvider, PureRustBlasProvider,
    };

    // Specialized Gradient Libraries
    pub use crate::specialized_gradient_libs::{
        get_global_specialized_manager, BenchmarkReport, CasADiLibrary, ComputationResult,
        Function, GradientComputationType, LibraryUsageReport, QuadraticFunction, SparseGradient,
        SpecializedGradientLibrary, SpecializedLibConfig, SpecializedLibrary,
        SpecializedLibraryManager,
    };

    // Custom Autograd Backends
    pub use crate::custom_backends::{
        get_active_backend, get_global_backend_registry, AutogradBackend, BackendCapability,
        BackendConfig, BackendInfo, BackendRegistry, BackendTensor, CustomOperation, DataType,
        DeviceConfig, DeviceType, GradFunction, OperationContext, OptimizationLevel,
        PerformanceStats, ReferenceBackend,
    };

    // Hardware Acceleration
    // Note: `OptimizationLevel` is re-exported under an alias to avoid clashing
    // with the identically named item from `custom_backends` above.
    pub use crate::hardware_acceleration::{
        get_global_acceleration_manager, AccelerationConfig, AcceleratorBenchmarkReport,
        AcceleratorType, AcceleratorUsageReport, Conv2DParams, CudaAccelerator, DeviceStats,
        HardwareAccelerationManager, HardwareAccelerator, HardwareCapability, HardwareDevice,
        HardwareMemoryHandle, MetalAccelerator, OptimizationLevel as HardwareOptimizationLevel,
        PrecisionPreference,
    };

    // Profiling and Debugging Integration
    pub use crate::profiling_debugging_integration::{
        get_global_profiling_debugging_manager, AnalysisCapability, CPUProfile, DebuggingConfig,
        DebuggingReport, DebuggingSession, DebuggingTool, ExternalDebugger, ExternalProfiler,
        GPUProfile, GdbDebugger, Hotspot, IntegrationConfig, IntegrationReport, MemoryError,
        MemoryProfile, PerfProfiler, ProfilingConfig, ProfilingDebuggingManager, ProfilingReport,
        ProfilingSession, ProfilingTool, StackTrace, ThreadError,
    };

    // AD Framework Compatibility
    // Note: `FrameworkAdapter` is aliased because `cross_framework_verification`
    // already re-exports a type with that name above.
    pub use crate::ad_framework_compatibility::{
        check_framework_compatibility, convert_tensor, get_global_compatibility_manager,
        migrate_model, ADFramework, ADFrameworkCompatibilityManager, AutomationLevel,
        CompatibilityLevel, CompatibilityReport, CustomOperationDefinition, EffortLevel,
        FrameworkAdapter as ADFrameworkAdapter, FrameworkCapabilities, FrameworkTensor,
        MigrationCapability, MigrationData, MigrationOperation, MigrationPlan, MigrationResult,
        MigrationStep, PerformanceComparison, PerformanceMetrics, PyTorchAdapter, PyTorchTensor,
        RequiredTransformation, UniversalDataType, UniversalOperation, UniversalTensor,
        ValidationResult,
    };
}
564
565/// Accumulate gradients with overflow protection
566pub mod accumulate {
567    use super::*;
568    use crate::autograd_traits::AutogradTensor;
569    use num_traits::Float;
570
571    /// Accumulate gradients safely with overflow detection
572    pub fn accumulate_gradient_safe<T>(
573        existing: &mut dyn AutogradTensor<T>,
574        new_grad: &dyn AutogradTensor<T>,
575        overflow_threshold: Option<T>,
576    ) -> Result<()>
577    where
578        T: torsh_core::dtype::TensorElement
579            + Float
580            + Clone
581            + std::fmt::Debug
582            + std::fmt::Display
583            + Send
584            + Sync,
585    {
586        // Check shapes match
587        if existing.shape() != new_grad.shape() {
588            return Err(TorshError::AutogradError(format!(
589                "Shape mismatch in gradient accumulation: {:?} vs {:?}",
590                existing.shape(),
591                new_grad.shape()
592            )));
593        }
594
595        // Get data for accumulation
596        let existing_data = existing.data();
597        let new_data = new_grad.data();
598
599        // Check for overflow if threshold provided
600        if let Some(threshold) = overflow_threshold {
601            for (existing_val, new_val) in existing_data.iter().zip(new_data.iter()) {
602                let sum = *existing_val + *new_val;
603                if sum.abs() > threshold {
604                    return Err(TorshError::AutogradError(format!(
605                        "Gradient accumulation overflow detected: {} + {} = {} > {}",
606                        existing_val, new_val, sum, threshold
607                    )));
608                }
609            }
610        }
611
612        // Since we can't modify trait objects directly, this is a limitation
613        // Real implementation would be in concrete tensor types
614        tracing::debug!(
615            "Gradient accumulation requested for tensor with shape {:?}",
616            existing.shape()
617        );
618
619        Ok(())
620    }
621
622    /// Check if gradient accumulation would cause overflow
623    pub fn check_accumulation_overflow<T>(val1: T, val2: T, threshold: T) -> bool
624    where
625        T: Float + PartialOrd,
626    {
627        let sum = val1 + val2;
628        sum.abs() > threshold
629    }
630}
631
632/// Gradient clipping utilities
633pub mod clip {
634    use super::*;
635    use crate::autograd_traits::AutogradTensor;
636    use num_traits::Float;
637
638    /// Clip gradients by global norm
639    pub fn clip_grad_norm<T>(
640        gradients: &[&dyn AutogradTensor<T>],
641        max_norm: T,
642        norm_type: f32,
643    ) -> Result<T>
644    where
645        T: torsh_core::dtype::TensorElement
646            + Float
647            + Clone
648            + std::fmt::Debug
649            + Send
650            + Sync
651            + std::fmt::Display,
652        f32: From<T>,
653    {
654        if gradients.is_empty() {
655            return Ok(<T as num_traits::Zero>::zero());
656        }
657
658        // Calculate total norm
659        let mut total_norm = <T as num_traits::Zero>::zero();
660
661        for grad in gradients {
662            let data = grad.data();
663            for &val in data.iter() {
664                if norm_type == 2.0 {
665                    total_norm = total_norm + val * val;
666                } else if norm_type == 1.0 {
667                    total_norm = total_norm + val.abs();
668                } else {
669                    let abs_val = val.abs();
670                    total_norm = total_norm
671                        + abs_val.powf(T::from(norm_type).unwrap_or(<T as num_traits::One>::one()));
672                }
673            }
674        }
675
676        if norm_type == 2.0 {
677            total_norm = total_norm.sqrt();
678        } else if norm_type != 1.0 {
679            total_norm =
680                total_norm.powf(T::from(1.0 / norm_type).unwrap_or(<T as num_traits::One>::one()));
681        }
682
683        tracing::debug!("Calculated gradient norm: {:?}", total_norm);
684
685        // Calculate clipping ratio
686        let clip_coef =
687            max_norm / (total_norm + T::from(1e-6).unwrap_or(<T as num_traits::One>::one()));
688        let clip_coef = clip_coef.min(<T as num_traits::One>::one());
689
690        tracing::debug!("Gradient clipping coefficient: {:?}", clip_coef);
691
692        Ok(total_norm)
693    }
694
695    /// Clip gradients by value
696    pub fn clip_grad_value<T>(
697        gradient: &dyn AutogradTensor<T>,
698        min_value: T,
699        max_value: T,
700    ) -> Result<Vec<T>>
701    where
702        T: torsh_core::dtype::TensorElement + Float + Clone + std::fmt::Debug + Send + Sync,
703    {
704        let data = gradient.data();
705        let clipped: Vec<T> = data
706            .iter()
707            .map(|&val| val.max(min_value).min(max_value))
708            .collect();
709
710        Ok(clipped)
711    }
712}
713
714/// Forward-mode automatic differentiation
715pub mod forward_mode {
716
717    use num_traits::Float;
718
719    /// Dual number for forward-mode AD
720    #[derive(Debug, Clone, Copy, PartialEq)]
721    pub struct Dual<T> {
722        /// Value
723        pub value: T,
724        /// Derivative
725        pub derivative: T,
726    }
727
728    impl<T: Float> Dual<T> {
729        /// Create a new dual number
730        pub fn new(value: T, derivative: T) -> Self {
731            Self { value, derivative }
732        }
733
734        /// Create a variable (derivative = 1)
735        pub fn variable(value: T) -> Self {
736            Self::new(value, <T as num_traits::One>::one())
737        }
738
739        /// Create a constant (derivative = 0)
740        pub fn constant(value: T) -> Self {
741            Self::new(value, T::zero())
742        }
743    }
744
745    impl<T: Float> std::ops::Add for Dual<T> {
746        type Output = Self;
747
748        fn add(self, rhs: Self) -> Self::Output {
749            Self::new(self.value + rhs.value, self.derivative + rhs.derivative)
750        }
751    }
752
753    impl<T: Float> std::ops::Mul for Dual<T> {
754        type Output = Self;
755
756        fn mul(self, rhs: Self) -> Self::Output {
757            Self::new(
758                self.value * rhs.value,
759                self.derivative * rhs.value + self.value * rhs.derivative,
760            )
761        }
762    }
763
764    /// Compute forward-mode derivative
765    pub fn forward_derivative<T, F>(input: T, f: F) -> (T, T)
766    where
767        T: Float + Clone,
768        F: Fn(Dual<T>) -> Dual<T>,
769    {
770        let dual_input = Dual::variable(input);
771        let dual_output = f(dual_input);
772        (dual_output.value, dual_output.derivative)
773    }
774}
775
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_version() {
        let mut version = TensorVersion::new(42);
        assert_eq!(version.tensor_id, 42);
        assert_eq!(version.version, 0);

        // increment mutates in place and hands back a copy of the new state.
        let snapshot = version.increment();
        assert_eq!(snapshot.version, 1);
        assert_eq!(version.version, 1);
    }

    #[test]
    fn test_tensor_version_compatibility() {
        let a = TensorVersion::new(1);
        let b = TensorVersion::new(1);
        let c = TensorVersion::new(2);

        // Same id + same version => compatible; differing ids are not.
        assert!(a.is_compatible_with(&b));
        assert!(!a.is_compatible_with(&c));
    }

    #[test]
    fn test_unique_tensor_ids() {
        let ids = [new_tensor_id(), new_tensor_id(), new_tensor_id()];

        assert_ne!(ids[0], ids[1]);
        assert_ne!(ids[1], ids[2]);
        assert_ne!(ids[0], ids[2]);
    }

    #[test]
    fn test_dual_number_arithmetic() {
        use forward_mode::Dual;

        let x = Dual::new(2.0, 1.0);
        let y = Dual::new(3.0, 0.0);

        let total = x + y;
        assert_eq!((total.value, total.derivative), (5.0, 1.0));

        let product = x * y;
        assert_eq!((product.value, product.derivative), (6.0, 3.0));
    }

    #[test]
    fn test_forward_derivative() {
        use forward_mode::forward_derivative;

        // f(x) = x^2 has derivative 2x, so f(3) = 9 and f'(3) = 6.
        let (value, derivative) = forward_derivative(3.0, |x| x * x);
        assert_eq!(value, 9.0);
        assert_eq!(derivative, 6.0);
    }
}