// rustorch_core/ops/activations.rs

use crate::autograd::BackwardOp;
use crate::storage::Storage;
use crate::Tensor;
use rayon::prelude::*;
use std::sync::Arc;

// --- Sigmoid ---
pub fn sigmoid(input: &Tensor) -> Tensor {
    #[cfg(feature = "wgpu_backend")]
    {
        if let Some(input_buf) = input.storage().wgpu_buffer() {
            if !input.is_contiguous() {
                return sigmoid(&input.contiguous());
            }

            use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
            let size: usize = input.shape().iter().product();
            let output_buf = elementwise_wgpu_buffer(
                input_buf,
                input.shape(),
                input.strides(),
                None,
                input.shape(),
                ElementwiseOp::Sigmoid,
                None,
            );
            let storage = Storage::new_wgpu(output_buf, size, 0);
            let mut tensor = Tensor::new_with_storage(storage, input.shape());

            if input.requires_grad() {
                tensor.set_requires_grad_mut(true);
                tensor.set_op(Arc::new(SigmoidBackward {
                    input: input.clone(),
                }));
            }
            return tensor;
        }
    }

    if !input.is_contiguous() {
        return sigmoid(&input.contiguous());
    }

    let input_guard = input.data();
    let input_data = &*input_guard;

    let result_data: Vec<f32> = input_data
        .par_iter()
        .map(|&x| 1.0 / (1.0 + (-x).exp()))
        .collect();

    let storage = Storage::new(result_data);
    let mut tensor = Tensor::new_with_storage(storage, input.shape());

    if input.requires_grad() {
        tensor.set_requires_grad_mut(true);
        // Store `input` so backward can route the gradient to it. Caching the
        // output here would avoid recomputing sigmoid(x) in backward, but
        // `tensor` owns this op, so an op holding `tensor` would form an Arc
        // reference cycle and leak (see the `cycle_demo` test below). Breaking
        // the cycle would need `Weak` references or a tape-based graph, so we
        // recompute for now.

        tensor.set_op(Arc::new(SigmoidBackward {
            input: input.clone(),
        }));
    }

    tensor
}
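
// A self-contained toy illustrating the reference cycle described in `sigmoid`
// above, using hypothetical `Node`/`Op` types rather than the real
// `Tensor`/`BackwardOp`: if an op held a strong `Arc` back to its output,
// neither would ever be dropped; a `std::sync::Weak` reference breaks the cycle.
#[cfg(test)]
mod cycle_demo {
    use std::sync::{Arc, Mutex, Weak};

    struct Op {
        // A strong `Arc<Node>` here would keep the node alive forever.
        output: Weak<Node>,
    }

    struct Node {
        op: Mutex<Option<Arc<Op>>>,
    }

    #[test]
    fn weak_ref_breaks_cycle() {
        let node = Arc::new(Node {
            op: Mutex::new(None),
        });
        let op = Arc::new(Op {
            output: Arc::downgrade(&node),
        });
        *node.op.lock().unwrap() = Some(op);
        // Only the local binding holds a strong count, so `node` is freed
        // normally when it goes out of scope.
        assert_eq!(Arc::strong_count(&node), 1);
        // The weak reference still reaches the node while it is alive.
        assert!(node.op.lock().unwrap().as_ref().unwrap().output.upgrade().is_some());
    }
}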

#[derive(Debug)]
pub struct SigmoidBackward {
    pub input: Tensor,
}

impl BackwardOp for SigmoidBackward {
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            // GPU path
            #[cfg(feature = "wgpu_backend")]
            {
                if self.input.storage().wgpu_buffer().is_some() {
                    // We need s = sigmoid(input); recompute it (see the cycle
                    // note in `sigmoid`).
                    let s = sigmoid(&self.input);
                    let s_buf = s
                        .storage()
                        .wgpu_buffer()
                        .expect("Sigmoid output should be on GPU");

                    // grad may arrive non-contiguous; normalize it before
                    // reading its buffer.
                    let grad_contig = if !grad.is_contiguous() {
                        grad.contiguous()
                    } else {
                        grad.clone()
                    };
                    let grad_buf = grad_contig
                        .storage()
                        .wgpu_buffer()
                        .expect("Grad should be on GPU");

                    use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
                    let size: usize = grad.shape().iter().product();
                    // SigmoidBackward takes the output (s) and the incoming
                    // grad. Pass grad_contig's own shape/strides so they
                    // describe the buffer we actually hand over.
                    let output_buf = elementwise_wgpu_buffer(
                        s_buf,
                        s.shape(),
                        s.strides(),
                        Some((grad_buf, grad_contig.shape(), grad_contig.strides())),
                        grad.shape(),
                        ElementwiseOp::SigmoidBackward,
                        None,
                    );
                    let storage = Storage::new_wgpu(output_buf, size, 0);
                    let grad_input = Tensor::new_with_storage(storage, grad.shape());

                    self.input.accumulate_grad(&grad_input);
                    self.input.backward_step();
                    return;
                }
            }

            // grad_input = grad * sigmoid(input) * (1 - sigmoid(input))

            // CPU fallback: move any GPU-resident tensors to the CPU first.
            #[cfg(feature = "wgpu_backend")]
            let (input, grad) = {
                let i = if self.input.storage().device().is_wgpu() {
                    self.input.to_cpu()
                } else {
                    self.input.clone()
                };
                let g = if grad.storage().device().is_wgpu() {
                    grad.to_cpu()
                } else {
                    grad.clone()
                };
                (i, g)
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let (input, grad) = (self.input.clone(), grad.clone());

            // The flat element-wise iteration below assumes a contiguous layout.
            let grad = if grad.is_contiguous() {
                grad
            } else {
                grad.contiguous()
            };

            let s = sigmoid(&input);

            // Fused grad * s * (1 - s): one pass, no intermediate tensors.
            let s_guard = s.data();
            let grad_guard = grad.data();
            let s_data = &*s_guard;
            let grad_data = &*grad_guard;

            let grad_input_data: Vec<f32> = s_data
                .par_iter()
                .zip(grad_data.par_iter())
                .map(|(s_val, g_val)| g_val * s_val * (1.0 - s_val))
                .collect();

            let grad_input = Tensor::new_with_storage(Storage::new(grad_input_data), grad.shape());

            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}
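
// Numerical sanity checks for the sigmoid math used above. These are pure-f32
// sketches with no Tensor API assumptions: the forward formula 1 / (1 + e^-x)
// and the derivative identity d/dx sigmoid(x) = s * (1 - s) are compared
// against known values and a central finite difference.
#[cfg(test)]
mod sigmoid_math_checks {
    fn sig(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    #[test]
    fn forward_matches_known_values() {
        assert!((sig(0.0) - 0.5).abs() < 1e-6);
        assert!((sig(2.0) - 0.880797).abs() < 1e-5);
        // Symmetry: sigmoid(-x) = 1 - sigmoid(x).
        assert!((sig(2.0) + sig(-2.0) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn derivative_matches_finite_difference() {
        for &x in &[-2.0f32, -0.5, 0.0, 0.5, 2.0] {
            let s = sig(x);
            let analytic = s * (1.0 - s);
            let h = 1e-3;
            let numeric = (sig(x + h) - sig(x - h)) / (2.0 * h);
            assert!((analytic - numeric).abs() < 1e-3);
        }
    }
}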

// --- Tanh ---
pub fn tanh(input: &Tensor) -> Tensor {
    #[cfg(feature = "wgpu_backend")]
    {
        if let Some(input_buf) = input.storage().wgpu_buffer() {
            if !input.is_contiguous() {
                return tanh(&input.contiguous());
            }

            use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
            let size: usize = input.shape().iter().product();
            let output_buf = elementwise_wgpu_buffer(
                input_buf,
                input.shape(),
                input.strides(),
                None,
                input.shape(),
                ElementwiseOp::Tanh,
                None,
            );
            let storage = Storage::new_wgpu(output_buf, size, 0);
            let mut tensor = Tensor::new_with_storage(storage, input.shape());

            if input.requires_grad() {
                tensor.set_requires_grad_mut(true);
                tensor.set_op(Arc::new(TanhBackward {
                    input: input.clone(),
                }));
            }
            return tensor;
        }
    }

    if !input.is_contiguous() {
        return tanh(&input.contiguous());
    }

    let input_guard = input.data();
    let input_data = &*input_guard;

    let result_data: Vec<f32> = input_data.par_iter().map(|&x| x.tanh()).collect();

    let storage = Storage::new(result_data);
    let mut tensor = Tensor::new_with_storage(storage, input.shape());

    if input.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(TanhBackward {
            input: input.clone(),
        }));
    }

    tensor
}

#[derive(Debug)]
pub struct TanhBackward {
    pub input: Tensor,
}

impl BackwardOp for TanhBackward {
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            #[cfg(feature = "wgpu_backend")]
            {
                if self.input.storage().wgpu_buffer().is_some() {
                    // Recompute t = tanh(input) (see the cycle note in `sigmoid`).
                    let t = tanh(&self.input);
                    let t_buf = t
                        .storage()
                        .wgpu_buffer()
                        .expect("Tanh output should be on GPU");

                    let grad_contig = if !grad.is_contiguous() {
                        grad.contiguous()
                    } else {
                        grad.clone()
                    };
                    let grad_buf = grad_contig
                        .storage()
                        .wgpu_buffer()
                        .expect("Grad should be on GPU");

                    use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
                    let size: usize = grad.shape().iter().product();
                    // TanhBackward: (1 - t^2) * grad. Pass grad_contig's own
                    // shape/strides so they describe the buffer we hand over.
                    let output_buf = elementwise_wgpu_buffer(
                        t_buf,
                        t.shape(),
                        t.strides(),
                        Some((grad_buf, grad_contig.shape(), grad_contig.strides())),
                        grad.shape(),
                        ElementwiseOp::TanhBackward,
                        None,
                    );

                    let storage = Storage::new_wgpu(output_buf, size, 0);
                    let grad_input = Tensor::new_with_storage(storage, grad.shape());

                    self.input.accumulate_grad(&grad_input);
                    self.input.backward_step();
                    return;
                }
            }

            // CPU fallback: move any GPU-resident tensors to the CPU first.
            #[cfg(feature = "wgpu_backend")]
            let (input, grad) = {
                let i = if self.input.storage().device().is_wgpu() {
                    self.input.to_cpu()
                } else {
                    self.input.clone()
                };
                let g = if grad.storage().device().is_wgpu() {
                    grad.to_cpu()
                } else {
                    grad.clone()
                };
                (i, g)
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let (input, grad) = (self.input.clone(), grad.clone());

            // The flat element-wise iteration below assumes a contiguous layout.
            let grad = if grad.is_contiguous() {
                grad
            } else {
                grad.contiguous()
            };

            let t = tanh(&input);

            // Fused grad * (1 - t^2): one pass, no intermediate tensors.
            let t_guard = t.data();
            let grad_guard = grad.data();
            let t_data = &*t_guard;
            let grad_data = &*grad_guard;

            let grad_input_data: Vec<f32> = t_data
                .par_iter()
                .zip(grad_data.par_iter())
                .map(|(t_val, g_val)| g_val * (1.0 - t_val * t_val))
                .collect();

            let grad_input = Tensor::new_with_storage(Storage::new(grad_input_data), grad.shape());

            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}
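
// Numerical sanity check for the tanh derivative used above:
// d/dx tanh(x) = 1 - tanh(x)^2. A pure-f32 sketch with no Tensor API,
// compared against a central finite difference.
#[cfg(test)]
mod tanh_math_checks {
    #[test]
    fn derivative_matches_finite_difference() {
        for &x in &[-2.0f32, -0.5, 0.0, 0.5, 2.0] {
            let t = x.tanh();
            let analytic = 1.0 - t * t;
            let h = 1e-3;
            let numeric = ((x + h).tanh() - (x - h).tanh()) / (2.0 * h);
            assert!((analytic - numeric).abs() < 1e-3);
        }
    }
}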

// --- Softmax ---
// Naive implementation along the last dimension.
pub fn softmax(input: &Tensor, dim: i64) -> Tensor {
    // Handle negative dim
    let ndim = input.shape().len() as i64;
    let dim = if dim < 0 { ndim + dim } else { dim } as usize;

    if dim != input.shape().len() - 1 {
        // For now only the last dim is supported, to keep the parallel
        // iteration simple.
        panic!("Softmax currently only supports last dimension (dim=-1)");
    }

    let shape = input.shape();
    let last_dim_size = shape[shape.len() - 1];

    if !input.is_contiguous() {
        return softmax(&input.contiguous(), dim as i64);
    }

    // Data access happens on the CPU. Keep the original `input` binding intact
    // so the backward op below is attached to the original tensor (a `to_cpu()`
    // copy may not share its autograd state).
    #[cfg(feature = "wgpu_backend")]
    let input_cpu = if input.storage().device().is_wgpu() {
        input.to_cpu()
    } else {
        input.clone()
    };
    #[cfg(not(feature = "wgpu_backend"))]
    let input_cpu = input.clone();

    let input_guard = input_cpu.data();
    let input_data = &*input_guard;

    let mut output_data = vec![0.0; input_data.len()];

    // Parallel over outer dimensions
    output_data
        .par_chunks_mut(last_dim_size)
        .enumerate()
        .for_each(|(i, out_row)| {
            let offset = i * last_dim_size;
            let in_row = &input_data[offset..offset + last_dim_size];

            // Max for numerical stability
            let max_val = in_row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

            let mut sum_exp = 0.0;
            for (j, &val) in in_row.iter().enumerate() {
                let exp_val = (val - max_val).exp();
                out_row[j] = exp_val;
                sum_exp += exp_val;
            }

            for val in out_row.iter_mut() {
                *val /= sum_exp;
            }
        });

    let storage = Storage::new(output_data);
    let mut tensor = Tensor::new_with_storage(storage, shape);

    if input.requires_grad() {
        tensor.set_requires_grad_mut(true);
        // Softmax backward:
        //   dS_i/dx_j = S_i * (delta_ij - S_j)
        //   grad_input_j = sum_i grad_i * dS_i/dx_j
        //                = sum_i grad_i * S_i * (delta_ij - S_j)
        //                = S_j * (grad_j - sum_k grad_k * S_k)
        //                = S_j * (grad_j - (grad . S))
        //
        // Backward needs the output S, but storing `tensor` in its own op
        // would create an Arc reference cycle (Tensor -> Op -> Tensor), so we
        // store the input and recompute S in backward instead.
        tensor.set_op(Arc::new(SoftmaxBackward {
            input: input.clone(),
            dim,
        }));
    }

    tensor
}
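
// A minimal smoke test for `softmax`, assuming `Tensor::new(&[f32], &[usize])`
// constructs a CPU tensor (matching its use in `SoftmaxBackward::backward`
// below): every row should sum to 1, and the max-subtraction trick should keep
// large inputs from overflowing to NaN.
#[cfg(test)]
mod softmax_smoke {
    use super::*;

    #[test]
    fn rows_sum_to_one_and_large_inputs_are_stable() {
        let x = Tensor::new(&[1.0, 2.0, 3.0, 1000.0, 1001.0, 1002.0], &[2, 3]);
        let y = softmax(&x, -1);
        let guard = y.data();
        let data = &*guard;
        for row in data.chunks(3) {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
            assert!(row.iter().all(|v| v.is_finite()));
        }
    }
}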

#[derive(Debug)]
pub struct SoftmaxBackward {
    // The output is deliberately not stored here: holding the result tensor
    // inside its own op would form an Arc reference cycle (see `softmax`).
    // Backward recomputes softmax from `input` instead; a `Weak` reference or
    // a tape-based graph would allow caching the output.
    pub input: Tensor,
    pub dim: usize,
}

impl BackwardOp for SoftmaxBackward {
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            // Recompute softmax (the output is not cached; see the struct docs).
            let s = softmax(&self.input, self.dim as i64);

            // grad_input = S * (grad - sum(grad * S, dim, keepdim)), computed
            // manually for the last dim.

            // CPU fallback: move any GPU-resident tensors to the CPU first.
            #[cfg(feature = "wgpu_backend")]
            let (s, grad) = {
                let s = if s.storage().device().is_wgpu() {
                    s.to_cpu()
                } else {
                    s
                };
                let g = if grad.storage().device().is_wgpu() {
                    grad.to_cpu()
                } else {
                    grad.clone()
                };
                (s, g)
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let grad = grad.clone();

            // The flat row-chunking below assumes a contiguous layout.
            let grad = if grad.is_contiguous() {
                grad
            } else {
                grad.contiguous()
            };

            let s_guard = s.data();
            let s_data = &*s_guard;

            let grad_guard = grad.data();
            let grad_data = &*grad_guard;

            let shape = s.shape();
            let last_dim = shape[shape.len() - 1];

            let mut grad_input_data = vec![0.0; s_data.len()];

            grad_input_data
                .par_chunks_mut(last_dim)
                .enumerate()
                .for_each(|(i, out_row)| {
                    let offset = i * last_dim;
                    let s_row = &s_data[offset..offset + last_dim];
                    let g_row = &grad_data[offset..offset + last_dim];

                    // dot = grad . S for this row
                    let mut dot = 0.0;
                    for j in 0..last_dim {
                        dot += s_row[j] * g_row[j];
                    }

                    // grad_input_j = S_j * (grad_j - dot)
                    for j in 0..last_dim {
                        out_row[j] = s_row[j] * (g_row[j] - dot);
                    }
                });

            let grad_input = Tensor::new(&grad_input_data, shape);
            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}
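
// Numerical sanity check for the backward identity used above:
// grad_input_j = S_j * (grad_j - sum_k grad_k * S_k). A pure-f32 sketch with
// no Tensor API; the identity is compared against a central finite difference
// of f(x) = sum_i grad_i * softmax_i(x).
#[cfg(test)]
mod softmax_grad_math {
    fn softmax3(x: [f32; 3]) -> [f32; 3] {
        let m = x.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let e = [(x[0] - m).exp(), (x[1] - m).exp(), (x[2] - m).exp()];
        let sum: f32 = e.iter().sum();
        [e[0] / sum, e[1] / sum, e[2] / sum]
    }

    #[test]
    fn identity_matches_finite_difference() {
        let x = [0.3f32, -1.2, 0.7];
        let g = [0.5f32, -0.25, 1.0]; // arbitrary upstream gradient
        let s = softmax3(x);
        let dot: f32 = (0..3).map(|k| g[k] * s[k]).sum();
        for j in 0..3 {
            let analytic = s[j] * (g[j] - dot);
            // Numeric derivative of f(x) = sum_i g_i * softmax_i(x) w.r.t. x_j.
            let h = 1e-3;
            let (mut xp, mut xm) = (x, x);
            xp[j] += h;
            xm[j] -= h;
            let (sp, sm) = (softmax3(xp), softmax3(xm));
            let numeric = (0..3).map(|i| g[i] * (sp[i] - sm[i])).sum::<f32>() / (2.0 * h);
            assert!((analytic - numeric).abs() < 1e-2);
        }
    }
}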