1use anyhow::Result;
2
3use crate::{global_backend, DType, OpsTrait, StorageTrait, Tensor, TensorBase, UninitVec};
4
5impl<S: StorageTrait> TensorBase<S> {
6 #[inline(always)]
7 pub fn relu(&self) -> Result<Tensor> {
8 if self.is_contiguous() {
9 self.relu_contiguous()
10 } else {
11 self.relu_non_contiguous()
12 }
13 }
14
15 #[inline(always)]
16 fn relu_contiguous(&self) -> Result<Tensor> {
17 let numel = self.shape.numel();
18 let backend = global_backend();
19
20 match self.dtype {
21 DType::Bool | DType::Uint8 | DType::Uint16 | DType::Uint32 | DType::Uint64 => {
22 self.clone_or_copy()
23 }
24 DType::Fp32 => {
25 let input_data = self.as_slice::<f32>()?;
26 let out = UninitVec::<f32>::new(numel).init_with(|dst| {
27 backend.relu_f32(input_data, dst);
28 });
29 Tensor::from_vec(out, self.shape)
30 }
31 DType::Fp64 => {
32 let input_data = self.as_slice::<f64>()?;
33 let out = UninitVec::<f64>::new(numel).init_with(|dst| {
34 backend.relu_f64(input_data, dst);
35 });
36 Tensor::from_vec(out, self.shape)
37 }
38 DType::Int8 => {
39 let input_data = self.as_slice::<i8>()?;
40 let out = UninitVec::<i8>::new(numel).init_with(|dst| {
41 backend.relu_i8(input_data, dst);
42 });
43 Tensor::from_vec(out, self.shape)
44 }
45 DType::Int16 => {
46 let input_data = self.as_slice::<i16>()?;
47 let out = UninitVec::<i16>::new(numel).init_with(|dst| {
48 backend.relu_i16(input_data, dst);
49 });
50 Tensor::from_vec(out, self.shape)
51 }
52 DType::Int32 => {
53 let input_data = self.as_slice::<i32>()?;
54 let out = UninitVec::<i32>::new(numel).init_with(|dst| {
55 backend.relu_i32(input_data, dst);
56 });
57 Tensor::from_vec(out, self.shape)
58 }
59 DType::Int64 => {
60 let input_data = self.as_slice::<i64>()?;
61 let out = UninitVec::<i64>::new(numel).init_with(|dst| {
62 backend.relu_i64(input_data, dst);
63 });
64 Tensor::from_vec(out, self.shape)
65 }
66 DType::Fp16 => {
67 let input_data = self.as_slice::<half::f16>()?;
68 let out = UninitVec::<half::f16>::new(numel).init_with(|dst| {
69 backend.relu_f16(input_data, dst);
70 });
71 Tensor::from_vec(out, self.shape)
72 }
73 DType::Bf16 => {
74 let input_data = self.as_slice::<half::bf16>()?;
75 let out = UninitVec::<half::bf16>::new(numel).init_with(|dst| {
76 backend.relu_bf16(input_data, dst);
77 });
78 Tensor::from_vec(out, self.shape)
79 }
80 _ => anyhow::bail!("Unsupported dtype for relu operation: {:?}", self.dtype),
81 }
82 }
83
84 #[inline(always)]
85 fn relu_non_contiguous(&self) -> Result<Tensor> {
86 match self.dtype {
87 DType::Bool | DType::Uint8 | DType::Uint16 | DType::Uint32 | DType::Uint64 => {
88 self.clone_or_copy()
89 }
90 DType::Fp32 => self.map_non_contiguous::<f32>(|x| x.max(0.0)),
91 DType::Fp64 => self.map_non_contiguous::<f64>(|x| x.max(0.0)),
92 DType::Int8 => self.map_non_contiguous::<i8>(|x| *x.max(&0)),
93 DType::Int16 => self.map_non_contiguous::<i16>(|x| *x.max(&0)),
94 DType::Int32 => self.map_non_contiguous::<i32>(|x| *x.max(&0)),
95 DType::Int64 => self.map_non_contiguous::<i64>(|x| *x.max(&0)),
96 DType::Fp16 => {
97 self.map_non_contiguous::<half::f16>(|x| half::f16::from_f32(x.to_f32().max(0.0)))
98 }
99 DType::Bf16 => {
100 self.map_non_contiguous::<half::bf16>(|x| half::bf16::from_f32(x.to_f32().max(0.0)))
101 }
102 _ => anyhow::bail!("Unsupported dtype for relu operation: {:?}", self.dtype),
103 }
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::s;

    /// f32 path: negatives clamp to 0.0, positives pass through.
    #[test]
    fn test_relu_f32() -> Result<()> {
        let tensor = Tensor::from_vec(vec![1.0f32, -2.0, 3.0, -4.0], [2, 2])?;
        let relu_output = tensor.relu()?;

        assert_eq!(relu_output.at::<f32>([0, 0]), 1.0);
        assert_eq!(relu_output.at::<f32>([0, 1]), 0.0);
        assert_eq!(relu_output.at::<f32>([1, 0]), 3.0);
        assert_eq!(relu_output.at::<f32>([1, 1]), 0.0);

        Ok(())
    }

    /// f64 path on a 1-D tensor.
    #[test]
    fn test_relu_f64() -> Result<()> {
        let tensor = Tensor::from_vec(vec![1.0f64, -2.0, 3.0], [3])?;
        let relu_output = tensor.relu()?;

        assert_eq!(relu_output.at::<f64>([0]), 1.0);
        assert_eq!(relu_output.at::<f64>([1]), 0.0);
        assert_eq!(relu_output.at::<f64>([2]), 3.0);

        Ok(())
    }

    /// Signed-integer path: negative ints clamp to 0.
    #[test]
    fn test_relu_i32() -> Result<()> {
        let tensor = Tensor::from_vec(vec![1i32, -2, 3, -4], [2, 2])?;
        let relu_output = tensor.relu()?;

        assert_eq!(relu_output.at::<i32>([0, 0]), 1);
        assert_eq!(relu_output.at::<i32>([0, 1]), 0);
        assert_eq!(relu_output.at::<i32>([1, 0]), 3);
        assert_eq!(relu_output.at::<i32>([1, 1]), 0);

        Ok(())
    }

    /// Unsigned path: relu is the identity, values are unchanged.
    #[test]
    fn test_relu_u8() -> Result<()> {
        let tensor = Tensor::from_vec(vec![1u8, 2, 3, 4], [2, 2])?;
        let relu_output = tensor.relu()?;

        assert_eq!(relu_output.at::<u8>([0, 0]), 1);
        assert_eq!(relu_output.at::<u8>([0, 1]), 2);
        assert_eq!(relu_output.at::<u8>([1, 0]), 3);
        assert_eq!(relu_output.at::<u8>([1, 1]), 4);

        Ok(())
    }

    /// Half-precision (f16) path.
    #[test]
    fn test_relu_f16() -> Result<()> {
        let tensor = Tensor::from_vec(
            vec![half::f16::from_f32(1.0), half::f16::from_f32(-2.0)],
            [2],
        )?;
        let relu_output = tensor.relu()?;

        assert_eq!(relu_output.at::<half::f16>([0]), half::f16::from_f32(1.0));
        assert_eq!(relu_output.at::<half::f16>([1]), half::f16::from_f32(0.0));

        Ok(())
    }

    /// Brain-float (bf16) path.
    #[test]
    fn test_relu_bf16() -> Result<()> {
        let tensor = Tensor::from_vec(
            vec![half::bf16::from_f32(1.0), half::bf16::from_f32(-2.0)],
            [2],
        )?;
        let relu_output = tensor.relu()?;

        assert_eq!(relu_output.at::<half::bf16>([0]), half::bf16::from_f32(1.0));
        assert_eq!(relu_output.at::<half::bf16>([1]), half::bf16::from_f32(0.0));

        Ok(())
    }

    /// Slicing row 1 of a [2, 3] tensor exercises the strided
    /// (non-contiguous) fallback rather than the backend kernel.
    #[test]
    fn test_relu_non_contiguous() -> Result<()> {
        let tensor = Tensor::from_vec(vec![1.0f32, -2.0, 3.0, -4.0, 5.0, 6.0], [2, 3])?;
        let sliced = tensor.slice(s![1]);
        let relu_output = sliced.relu()?;

        assert_eq!(relu_output.at::<f32>([0]), 0.0);
        assert_eq!(relu_output.at::<f32>([1]), 5.0);
        assert_eq!(relu_output.at::<f32>([2]), 6.0);

        Ok(())
    }

    /// Zero-element tensors must round-trip without error, preserving shape.
    #[test]
    fn test_relu_empty() -> Result<()> {
        let tensor = Tensor::from_vec(vec![0u8; 0], [0])?;
        let relu_output = tensor.relu()?;

        assert_eq!(relu_output.numel(), 0);
        assert_eq!(relu_output.dims(), [0]);

        Ok(())
    }
}