use crate::autograd::BackwardOp;
use crate::storage::Storage;
use crate::Tensor;
use rayon::prelude::*;
use std::sync::Arc;

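/// Elementwise logistic sigmoid: `σ(x) = 1 / (1 + e^(-x))`.
///
/// Runs on the GPU when the `wgpu_backend` feature is enabled and the input
/// lives in a wgpu buffer; otherwise falls back to a parallel CPU
/// implementation via rayon. Non-contiguous inputs are made contiguous first.
/// If the input requires gradients, the output is wired to a
/// [`SigmoidBackward`] node.
///
/// ```ignore
/// // Sketch, assuming the `Tensor::new(&[f32], &[usize])` constructor used
/// // elsewhere in this crate.
/// let x = Tensor::new(&[0.0, 2.0, -2.0], &[3]);
/// let y = sigmoid(&x); // ≈ [0.5, 0.881, 0.119]
/// ```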
pub fn sigmoid(input: &Tensor) -> Tensor {
    #[cfg(feature = "wgpu_backend")]
    {
        if let Some(input_buf) = input.storage().wgpu_buffer() {
            if !input.is_contiguous() {
                return sigmoid(&input.contiguous());
            }

            use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
            let size: usize = input.shape().iter().product();
            let output_buf = elementwise_wgpu_buffer(
                input_buf,
                input.shape(),
                input.strides(),
                None,
                input.shape(),
                ElementwiseOp::Sigmoid,
                None,
            );
            let storage = Storage::new_wgpu(output_buf, size, 0);
            let mut tensor = Tensor::new_with_storage(storage, input.shape());

            if input.requires_grad() {
                tensor.set_requires_grad_mut(true);
                tensor.set_op(Arc::new(SigmoidBackward {
                    input: input.clone(),
                }));
            }
            return tensor;
        }
    }

    if !input.is_contiguous() {
        return sigmoid(&input.contiguous());
    }

    let input_guard = input.data();
    let input_data = &*input_guard;

    let result_data: Vec<f32> = input_data
        .par_iter()
        .map(|&x| 1.0 / (1.0 + (-x).exp()))
        .collect();

    let storage = Storage::new(result_data);
    let mut tensor = Tensor::new_with_storage(storage, input.shape());

    if input.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(SigmoidBackward {
            input: input.clone(),
        }));
    }

    tensor
}

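/// Backward node for [`sigmoid`].
///
/// Propagates `dL/dx = dL/dy · σ(x) · (1 - σ(x))`, recomputing `σ(x)` from
/// the saved input rather than caching the forward output.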
#[derive(Debug)]
pub struct SigmoidBackward {
    pub input: Tensor,
}

impl BackwardOp for SigmoidBackward {
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            #[cfg(feature = "wgpu_backend")]
            {
                if self.input.storage().wgpu_buffer().is_some() {
                    let s = sigmoid(&self.input);
                    let s_buf = s
                        .storage()
                        .wgpu_buffer()
                        .expect("Sigmoid output should be on GPU");

                    let grad_contig = if !grad.is_contiguous() {
                        grad.contiguous()
                    } else {
                        grad.clone()
                    };
                    let grad_buf = grad_contig
                        .storage()
                        .wgpu_buffer()
                        .expect("Grad should be on GPU");

                    use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
                    let size: usize = grad_contig.shape().iter().product();
                    let output_buf = elementwise_wgpu_buffer(
                        s_buf,
                        s.shape(),
                        s.strides(),
                        // Use the contiguous copy's shape and strides: `grad_buf`
                        // belongs to `grad_contig`, not to the original `grad`.
                        Some((grad_buf, grad_contig.shape(), grad_contig.strides())),
                        grad_contig.shape(),
                        ElementwiseOp::SigmoidBackward,
                        None,
                    );
                    let storage = Storage::new_wgpu(output_buf, size, 0);
                    let grad_input = Tensor::new_with_storage(storage, grad_contig.shape());

                    self.input.accumulate_grad(&grad_input);
                    self.input.backward_step();
                    return;
                }
            }

            #[cfg(feature = "wgpu_backend")]
            let (input, grad) = {
                let i = if self.input.storage().device().is_wgpu() {
                    self.input.to_cpu()
                } else {
                    self.input.clone()
                };
                let g = if grad.storage().device().is_wgpu() {
                    grad.to_cpu()
                } else {
                    grad.clone()
                };
                (i, g)
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let (input, grad) = (self.input.clone(), grad.clone());

            let s = sigmoid(&input);

            let s_guard = s.data();
            let grad_guard = grad.data();
            let s_data = &*s_guard;
            let grad_data = &*grad_guard;

            // dL/dx = dL/dy · s · (1 - s), where s = σ(x).
            let grad_input_data: Vec<f32> = s_data
                .par_iter()
                .zip(grad_data.par_iter())
                .map(|(s_val, g_val)| g_val * s_val * (1.0 - s_val))
                .collect();

            let grad_input = Tensor::new_with_storage(Storage::new(grad_input_data), grad.shape());

            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}

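/// Elementwise hyperbolic tangent: `tanh(x)`.
///
/// Mirrors [`sigmoid`]: GPU kernel when the `wgpu_backend` feature is enabled
/// and the input is a wgpu buffer, parallel CPU fallback otherwise, with a
/// [`TanhBackward`] node attached when gradients are required.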
pub fn tanh(input: &Tensor) -> Tensor {
    #[cfg(feature = "wgpu_backend")]
    {
        if let Some(input_buf) = input.storage().wgpu_buffer() {
            if !input.is_contiguous() {
                return tanh(&input.contiguous());
            }

            use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
            let size: usize = input.shape().iter().product();
            let output_buf = elementwise_wgpu_buffer(
                input_buf,
                input.shape(),
                input.strides(),
                None,
                input.shape(),
                ElementwiseOp::Tanh,
                None,
            );
            let storage = Storage::new_wgpu(output_buf, size, 0);
            let mut tensor = Tensor::new_with_storage(storage, input.shape());

            if input.requires_grad() {
                tensor.set_requires_grad_mut(true);
                tensor.set_op(Arc::new(TanhBackward {
                    input: input.clone(),
                }));
            }
            return tensor;
        }
    }

    if !input.is_contiguous() {
        return tanh(&input.contiguous());
    }

    let input_guard = input.data();
    let input_data = &*input_guard;

    let result_data: Vec<f32> = input_data.par_iter().map(|&x| x.tanh()).collect();

    let storage = Storage::new(result_data);
    let mut tensor = Tensor::new_with_storage(storage, input.shape());

    if input.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(TanhBackward {
            input: input.clone(),
        }));
    }

    tensor
}

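/// Backward node for [`tanh`].
///
/// Propagates `dL/dx = dL/dy · (1 - tanh(x)²)`, recomputing `tanh(x)` from
/// the saved input.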
#[derive(Debug)]
pub struct TanhBackward {
    pub input: Tensor,
}

impl BackwardOp for TanhBackward {
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            #[cfg(feature = "wgpu_backend")]
            {
                if self.input.storage().wgpu_buffer().is_some() {
                    let t = tanh(&self.input);
                    let t_buf = t
                        .storage()
                        .wgpu_buffer()
                        .expect("Tanh output should be on GPU");

                    let grad_contig = if !grad.is_contiguous() {
                        grad.contiguous()
                    } else {
                        grad.clone()
                    };
                    let grad_buf = grad_contig
                        .storage()
                        .wgpu_buffer()
                        .expect("Grad should be on GPU");

                    use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
                    let size: usize = grad_contig.shape().iter().product();
                    let output_buf = elementwise_wgpu_buffer(
                        t_buf,
                        t.shape(),
                        t.strides(),
                        // As in `SigmoidBackward`: pair `grad_buf` with the
                        // contiguous copy's shape and strides.
                        Some((grad_buf, grad_contig.shape(), grad_contig.strides())),
                        grad_contig.shape(),
                        ElementwiseOp::TanhBackward,
                        None,
                    );

                    let storage = Storage::new_wgpu(output_buf, size, 0);
                    let grad_input = Tensor::new_with_storage(storage, grad_contig.shape());

                    self.input.accumulate_grad(&grad_input);
                    self.input.backward_step();
                    return;
                }
            }

            #[cfg(feature = "wgpu_backend")]
            let (input, grad) = {
                let i = if self.input.storage().device().is_wgpu() {
                    self.input.to_cpu()
                } else {
                    self.input.clone()
                };
                let g = if grad.storage().device().is_wgpu() {
                    grad.to_cpu()
                } else {
                    grad.clone()
                };
                (i, g)
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let (input, grad) = (self.input.clone(), grad.clone());

            let t = tanh(&input);

            let t_guard = t.data();
            let grad_guard = grad.data();
            let t_data = &*t_guard;
            let grad_data = &*grad_guard;

            // dL/dx = dL/dy · (1 - t²), where t = tanh(x).
            let grad_input_data: Vec<f32> = t_data
                .par_iter()
                .zip(grad_data.par_iter())
                .map(|(t_val, g_val)| g_val * (1.0 - t_val * t_val))
                .collect();

            let grad_input = Tensor::new_with_storage(Storage::new(grad_input_data), grad.shape());

            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}

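/// Softmax along `dim`, which may be negative (counted from the end).
///
/// Currently only the last dimension is supported; any other `dim` panics.
/// Rows are shifted by their maximum before exponentiating so the result is
/// numerically stable. Computation always happens on the CPU: with the
/// `wgpu_backend` feature enabled, GPU inputs are copied to the host first.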
pub fn softmax(input: &Tensor, dim: i64) -> Tensor {
    let ndim = input.shape().len() as i64;
    let dim = if dim < 0 { ndim + dim } else { dim } as usize;

    if dim != input.shape().len() - 1 {
        panic!("Softmax currently only supports the last dimension (dim = -1)");
    }

    if !input.is_contiguous() {
        return softmax(&input.contiguous(), dim as i64);
    }

    let shape = input.shape();
    let last_dim_size = shape[shape.len() - 1];

    #[cfg(feature = "wgpu_backend")]
    let input = if input.storage().device().is_wgpu() {
        input.to_cpu()
    } else {
        input.clone()
    };

    let input_guard = input.data();
    let input_data = &*input_guard;

    let mut output_data = vec![0.0; input_data.len()];

    output_data
        .par_chunks_mut(last_dim_size)
        .enumerate()
        .for_each(|(i, out_row)| {
            let offset = i * last_dim_size;
            let in_row = &input_data[offset..offset + last_dim_size];

            // Shift by the row max before exponentiating for numerical stability.
            let max_val = in_row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

            let mut sum_exp = 0.0;
            for (j, &val) in in_row.iter().enumerate() {
                let exp_val = (val - max_val).exp();
                out_row[j] = exp_val;
                sum_exp += exp_val;
            }

            for val in out_row.iter_mut() {
                *val /= sum_exp;
            }
        });

    let storage = Storage::new(output_data);
    let mut tensor = Tensor::new_with_storage(storage, shape);

    if input.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(SoftmaxBackward {
            input: input.clone(),
            output: tensor.clone(),
            dim,
        }));
    }

    tensor
}

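/// Backward node for [`softmax`].
///
/// For each row with output `s` and upstream gradient `g`, propagates the
/// Jacobian-vector product `dL/dx_j = s_j · (g_j - Σ_k s_k · g_k)`.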
#[derive(Debug)]
pub struct SoftmaxBackward {
    pub input: Tensor,
    pub output: Tensor,
    pub dim: usize,
}

impl BackwardOp for SoftmaxBackward {
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            // Reuse the forward output saved on this node instead of
            // recomputing the softmax.
            let s = self.output.clone();

            #[cfg(feature = "wgpu_backend")]
            let (s, grad) = {
                let s = if s.storage().device().is_wgpu() {
                    s.to_cpu()
                } else {
                    s
                };
                let g = if grad.storage().device().is_wgpu() {
                    grad.to_cpu()
                } else {
                    grad.clone()
                };
                (s, g)
            };

            let s_guard = s.data();
            let s_data = &*s_guard;

            let grad_guard = grad.data();
            let grad_data = &*grad_guard;

            let shape = s.shape();
            let last_dim = shape[shape.len() - 1];

            let mut grad_input_data = vec![0.0; s_data.len()];

            grad_input_data
                .par_chunks_mut(last_dim)
                .enumerate()
                .for_each(|(i, out_row)| {
                    let offset = i * last_dim;
                    let s_row = &s_data[offset..offset + last_dim];
                    let g_row = &grad_data[offset..offset + last_dim];

                    // Jacobian-vector product for one row:
                    // dL/dx_j = s_j * (g_j - Σ_k s_k * g_k).
                    let mut dot = 0.0;
                    for j in 0..last_dim {
                        dot += s_row[j] * g_row[j];
                    }

                    for j in 0..last_dim {
                        out_row[j] = s_row[j] * (g_row[j] - dot);
                    }
                });

            let grad_input = Tensor::new_with_storage(Storage::new(grad_input_data), shape);
            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}