// train_station/tensor/ops/log.rs

1//! Natural logarithm operations for tensors
2//!
3//! Provides element-wise natural logarithm function following PyTorch conventions with
4//! comprehensive automatic differentiation support and optimized scalar computation.
5//!
6//! # Key Features
7//!
8//! - **Natural Logarithm**: `log()` - Computes ln(x) for each element (PyTorch `log()` equivalent)
9//! - **Automatic Differentiation**: Full gradtrack support with gradient d/dx log(x) = 1/x
10//! - **Optimized Scalar Math**: Uses optimized scalar logarithm for accuracy and simplicity
11//! - **Domain Validation**: Automatic validation of positive input values
12//! - **Cache Optimization**: Memory access patterns optimized for modern CPUs
13//! - **Mathematical Accuracy**: High-precision logarithm computation
14//!
15//! # Mathematical Properties
16//!
17//! The natural logarithm function ln(x) has the following properties:
18//! - ln(1) = 0
19//! - ln(e) = 1 (where e ≈ 2.71828 is Euler's number)
20//! - ln(x*y) = ln(x) + ln(y)
21//! - ln(x^n) = n * ln(x)
22//! - Domain: x > 0 (positive real numbers only)
23//! - Gradient: d/dx ln(x) = 1/x
24//!
25//! # Performance Characteristics
26//!
27//! - **Scalar Optimization**: Optimized scalar logarithm computation
28//! - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
29//! - **Cache-friendly Access**: Linear memory access patterns
30//! - **Mathematical Accuracy**: High-precision floating-point logarithm
31//! - **Domain Validation**: Efficient positive value checking
32//! - **Gradient Optimization**: Efficient gradtrack with NoGradTrack support
33
34use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
35use crate::tensor::core::Tensor;
36
37impl Tensor {
38    /// Internal optimized logarithm operation
39    ///
40    /// Performs element-wise natural logarithm computation using optimized scalar math
41    /// for maximum accuracy and performance. This is the core implementation
42    /// used by `log()`.
43    ///
44    /// # Returns
45    ///
46    /// A new tensor containing the natural logarithm of each element
47    ///
48    /// # Performance Characteristics
49    ///
50    /// - **Scalar Optimization**: Uses optimized scalar logarithm for accuracy
51    /// - **Unrolled Loops**: 4x unrolling for optimal instruction throughput
52    /// - **Cache-friendly**: Linear memory access patterns
53    /// - **Mathematical Accuracy**: High-precision floating-point logarithm
54    /// - **Zero-sized Handling**: Fast return for empty tensors
55    /// - **Domain Validation**: Efficient positive value checking
56    ///
57    /// # Implementation Details
58    ///
59    /// Uses scalar logarithm computation for maximum mathematical accuracy.
60    /// Implements 4x unrolled loops for optimal instruction-level parallelism.
61    /// Validates that all input values are positive to ensure mathematical correctness.
62    ///
63    /// # Panics
64    ///
65    /// Panics if any element is non-positive (x <= 0) as logarithm is undefined
66    /// for non-positive real numbers.
67    #[inline]
68    pub(crate) fn log_optimized(&self) -> Tensor {
69        let mut output = Tensor::new(self.shape().dims.clone());
70
71        if self.size() == 0 {
72            return output;
73        }
74
75        unsafe {
76            let src = self.as_ptr();
77            let dst = output.as_mut_ptr();
78            let size = self.size();
79            let mut i = 0;
80            // Unrolled scalar loop
81            while i + 4 <= size {
82                let x0 = *src.add(i);
83                let x1 = *src.add(i + 1);
84                let x2 = *src.add(i + 2);
85                let x3 = *src.add(i + 3);
86                assert!(
87                    x0 > 0.0 && x1 > 0.0 && x2 > 0.0 && x3 > 0.0,
88                    "log domain error: x <= 0"
89                );
90                *dst.add(i) = x0.ln();
91                *dst.add(i + 1) = x1.ln();
92                *dst.add(i + 2) = x2.ln();
93                *dst.add(i + 3) = x3.ln();
94                i += 4;
95            }
96            while i < size {
97                let x = *src.add(i);
98                assert!(x > 0.0, "log domain error: x <= 0");
99                *dst.add(i) = x.ln();
100                i += 1;
101            }
102        }
103
104        output
105    }
106
107    /// Element-wise natural logarithm.
108    ///
109    /// Computes the natural logarithm for each element: `output[i] = ln(self[i])`
110    ///
111    /// # Returns
112    /// A new tensor with the natural logarithm of each element
113    ///
114    /// # Examples
115    ///
116    /// ## Basic Natural Logarithm
117    ///
118    /// ```
119    /// use train_station::Tensor;
120    ///
121    /// let a = Tensor::from_slice(&[1.0, 2.71828, 7.38906], vec![3]).unwrap();
122    /// let b = a.log();
123    /// assert_eq!(b.shape().dims, vec![3]);
124    /// assert_eq!(b.get(&[0]), 0.0); // ln(1) = 0
125    /// assert!((b.get(&[1]) - 1.0).abs() < 1e-5); // ln(e) ≈ 1
126    /// assert!((b.get(&[2]) - 2.0).abs() < 1e-5); // ln(e^2) ≈ 2
127    /// ```
128    ///
129    /// ## Mathematical Properties
130    ///
131    /// ```
132    /// use train_station::Tensor;
133    ///
134    /// let a = Tensor::from_slice(&[4.0, 8.0, 16.0], vec![3]).unwrap();
135    /// let b = a.log();
136    /// assert_eq!(b.shape().dims, vec![3]);
137    /// assert!((b.get(&[0]) - 1.38629).abs() < 1e-5); // ln(4) ≈ 1.38629
138    /// assert!((b.get(&[1]) - 2.07944).abs() < 1e-5); // ln(8) ≈ 2.07944
139    /// assert!((b.get(&[2]) - 2.77259).abs() < 1e-5); // ln(16) ≈ 2.77259
140    /// ```
141    ///
142    /// # Panics
143    /// Panics if any element is non-positive (x <= 0)
144    #[inline]
145    pub fn log(&self) -> Tensor {
146        let mut result = self.log_optimized();
147        if self.requires_grad() && is_grad_enabled() {
148            result.set_requires_grad_internal(true);
149            let grad_fn = GradFn::Log {
150                saved_input: Box::new(self.clone()),
151            };
152            result.set_grad_fn(grad_fn.clone());
153            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
154        }
155        result
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward pass matches `f32::ln` element-wise on positive inputs.
    #[test]
    fn test_log_basic() {
        let values = [1.0, 2.0, 3.0, 4.0];
        let input = Tensor::from_slice(&values, vec![2, 2]).unwrap();
        let output = input.log_optimized();
        let results = unsafe { std::slice::from_raw_parts(output.as_ptr(), output.size()) };
        for (&got, &v) in results.iter().zip(values.iter()) {
            assert!((got - v.ln()).abs() < 1e-6);
        }
    }

    /// Non-positive inputs must trigger the domain panic.
    #[test]
    #[should_panic]
    fn test_log_domain_panic() {
        let bad = Tensor::from_slice(&[1.0, 0.0], vec![2]).unwrap();
        let _ = bad.log_optimized();
    }

    /// Backward pass produces the analytic gradient d/dx ln(x) = 1/x.
    #[test]
    fn test_log_gradtrack() {
        let input = Tensor::from_slice(&[1.0, 2.0, 4.0], vec![3])
            .unwrap()
            .with_requires_grad();
        let mut output = input.log();
        output.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        let expected = [1.0, 0.5, 0.25];
        for (idx, &e) in expected.iter().enumerate() {
            assert!((grad.get(&[idx]) - e).abs() < 1e-6);
        }
    }
}
197}