// slsl/unary/relu.rs

1use anyhow::Result;
2
3use crate::{global_backend, DType, OpsTrait, StorageTrait, Tensor, TensorBase, UninitVec};
4
5impl<S: StorageTrait> TensorBase<S> {
6    #[inline(always)]
7    pub fn relu(&self) -> Result<Tensor> {
8        if self.is_contiguous() {
9            self.relu_contiguous()
10        } else {
11            self.relu_non_contiguous()
12        }
13    }
14
15    #[inline(always)]
16    fn relu_contiguous(&self) -> Result<Tensor> {
17        let numel = self.shape.numel();
18        let backend = global_backend();
19
20        match self.dtype {
21            DType::Bool | DType::Uint8 | DType::Uint16 | DType::Uint32 | DType::Uint64 => {
22                self.clone_or_copy()
23            }
24            DType::Fp32 => {
25                let input_data = self.as_slice::<f32>()?;
26                let out = UninitVec::<f32>::new(numel).init_with(|dst| {
27                    backend.relu_f32(input_data, dst);
28                });
29                Tensor::from_vec(out, self.shape)
30            }
31            DType::Fp64 => {
32                let input_data = self.as_slice::<f64>()?;
33                let out = UninitVec::<f64>::new(numel).init_with(|dst| {
34                    backend.relu_f64(input_data, dst);
35                });
36                Tensor::from_vec(out, self.shape)
37            }
38            DType::Int8 => {
39                let input_data = self.as_slice::<i8>()?;
40                let out = UninitVec::<i8>::new(numel).init_with(|dst| {
41                    backend.relu_i8(input_data, dst);
42                });
43                Tensor::from_vec(out, self.shape)
44            }
45            DType::Int16 => {
46                let input_data = self.as_slice::<i16>()?;
47                let out = UninitVec::<i16>::new(numel).init_with(|dst| {
48                    backend.relu_i16(input_data, dst);
49                });
50                Tensor::from_vec(out, self.shape)
51            }
52            DType::Int32 => {
53                let input_data = self.as_slice::<i32>()?;
54                let out = UninitVec::<i32>::new(numel).init_with(|dst| {
55                    backend.relu_i32(input_data, dst);
56                });
57                Tensor::from_vec(out, self.shape)
58            }
59            DType::Int64 => {
60                let input_data = self.as_slice::<i64>()?;
61                let out = UninitVec::<i64>::new(numel).init_with(|dst| {
62                    backend.relu_i64(input_data, dst);
63                });
64                Tensor::from_vec(out, self.shape)
65            }
66            DType::Fp16 => {
67                let input_data = self.as_slice::<half::f16>()?;
68                let out = UninitVec::<half::f16>::new(numel).init_with(|dst| {
69                    backend.relu_f16(input_data, dst);
70                });
71                Tensor::from_vec(out, self.shape)
72            }
73            DType::Bf16 => {
74                let input_data = self.as_slice::<half::bf16>()?;
75                let out = UninitVec::<half::bf16>::new(numel).init_with(|dst| {
76                    backend.relu_bf16(input_data, dst);
77                });
78                Tensor::from_vec(out, self.shape)
79            }
80            _ => anyhow::bail!("Unsupported dtype for relu operation: {:?}", self.dtype),
81        }
82    }
83
84    #[inline(always)]
85    fn relu_non_contiguous(&self) -> Result<Tensor> {
86        match self.dtype {
87            DType::Bool | DType::Uint8 | DType::Uint16 | DType::Uint32 | DType::Uint64 => {
88                self.clone_or_copy()
89            }
90            DType::Fp32 => self.map_non_contiguous::<f32>(|x| x.max(0.0)),
91            DType::Fp64 => self.map_non_contiguous::<f64>(|x| x.max(0.0)),
92            DType::Int8 => self.map_non_contiguous::<i8>(|x| *x.max(&0)),
93            DType::Int16 => self.map_non_contiguous::<i16>(|x| *x.max(&0)),
94            DType::Int32 => self.map_non_contiguous::<i32>(|x| *x.max(&0)),
95            DType::Int64 => self.map_non_contiguous::<i64>(|x| *x.max(&0)),
96            DType::Fp16 => {
97                self.map_non_contiguous::<half::f16>(|x| half::f16::from_f32(x.to_f32().max(0.0)))
98            }
99            DType::Bf16 => {
100                self.map_non_contiguous::<half::bf16>(|x| half::bf16::from_f32(x.to_f32().max(0.0)))
101            }
102            _ => anyhow::bail!("Unsupported dtype for relu operation: {:?}", self.dtype),
103        }
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::s;

    /// Contiguous f32: negatives clamp to zero, positives pass through.
    #[test]
    fn test_relu_f32() -> Result<()> {
        let input = Tensor::from_vec(vec![1.0f32, -2.0, 3.0, -4.0], [2, 2])?;
        let output = input.relu()?;

        let expected = [[1.0f32, 0.0], [3.0, 0.0]];
        for (r, row) in expected.iter().enumerate() {
            for (c, &want) in row.iter().enumerate() {
                assert_eq!(output.at::<f32>([r, c]), want);
            }
        }

        Ok(())
    }

    /// Contiguous f64, rank-1 tensor.
    #[test]
    fn test_relu_f64() -> Result<()> {
        let input = Tensor::from_vec(vec![1.0f64, -2.0, 3.0], [3])?;
        let output = input.relu()?;

        for (i, &want) in [1.0f64, 0.0, 3.0].iter().enumerate() {
            assert_eq!(output.at::<f64>([i]), want);
        }

        Ok(())
    }

    /// Signed integers: negatives become zero.
    #[test]
    fn test_relu_i32() -> Result<()> {
        let input = Tensor::from_vec(vec![1i32, -2, 3, -4], [2, 2])?;
        let output = input.relu()?;

        let expected = [[1i32, 0], [3, 0]];
        for (r, row) in expected.iter().enumerate() {
            for (c, &want) in row.iter().enumerate() {
                assert_eq!(output.at::<i32>([r, c]), want);
            }
        }

        Ok(())
    }

    /// Unsigned dtypes pass through untouched (ReLU is the identity there).
    #[test]
    fn test_relu_u8() -> Result<()> {
        let input = Tensor::from_vec(vec![1u8, 2, 3, 4], [2, 2])?;
        let output = input.relu()?;

        let expected = [[1u8, 2], [3, 4]];
        for (r, row) in expected.iter().enumerate() {
            for (c, &want) in row.iter().enumerate() {
                assert_eq!(output.at::<u8>([r, c]), want);
            }
        }

        Ok(())
    }

    /// Half-precision f16 round-trips through the ReLU correctly.
    #[test]
    fn test_relu_f16() -> Result<()> {
        let values = vec![half::f16::from_f32(1.0), half::f16::from_f32(-2.0)];
        let output = Tensor::from_vec(values, [2])?.relu()?;

        assert_eq!(output.at::<half::f16>([0]), half::f16::from_f32(1.0));
        assert_eq!(output.at::<half::f16>([1]), half::f16::from_f32(0.0));

        Ok(())
    }

    /// bfloat16 round-trips through the ReLU correctly.
    #[test]
    fn test_relu_bf16() -> Result<()> {
        let values = vec![half::bf16::from_f32(1.0), half::bf16::from_f32(-2.0)];
        let output = Tensor::from_vec(values, [2])?.relu()?;

        assert_eq!(output.at::<half::bf16>([0]), half::bf16::from_f32(1.0));
        assert_eq!(output.at::<half::bf16>([1]), half::bf16::from_f32(0.0));

        Ok(())
    }

    /// A sliced row exercises the strided (non-contiguous) code path.
    #[test]
    fn test_relu_non_contiguous() -> Result<()> {
        let base = Tensor::from_vec(vec![1.0f32, -2.0, 3.0, -4.0, 5.0, 6.0], [2, 3])?;
        let row = base.slice(s![1]); // view of the second row: [-4.0, 5.0, 6.0]
        let output = row.relu()?;

        for (i, &want) in [0.0f32, 5.0, 6.0].iter().enumerate() {
            assert_eq!(output.at::<f32>([i]), want);
        }

        Ok(())
    }

    /// Zero-element tensors must survive ReLU with shape intact.
    #[test]
    fn test_relu_empty() -> Result<()> {
        let empty = Tensor::from_vec(Vec::<u8>::new(), [0])?.relu()?;

        assert_eq!(empty.numel(), 0);
        assert_eq!(empty.dims(), [0]);

        Ok(())
    }
}