Skip to main content

yscv_autograd/
graph.rs

1use std::num::NonZeroUsize;
2
3use yscv_kernels::{
4    BackwardOps, BatchNorm2dParams, ThreadedCpuBackend, add as kernel_add, avg_pool2d_nhwc,
5    batch_norm2d_nhwc, conv2d_nhwc, conv3d, depthwise_conv2d_nhwc, gelu as kernel_gelu, matmul_2d,
6    mish as kernel_mish, mul as kernel_mul, relu, sigmoid as kernel_sigmoid, silu as kernel_silu,
7    sub as kernel_sub,
8};
9use yscv_tensor::Tensor;
10
11use super::error::AutogradError;
12use super::node::{AuxData, Node, NodeId, Op};
13
/// Eager autograd graph with explicit backward pass.
///
/// Each op method computes its forward value immediately and records the
/// result as a new node.
pub struct Graph {
    /// Recorded nodes; the inner index of a `NodeId` is a position in this vector.
    pub(crate) nodes: Vec<Node>,
    /// Optional compute backend. When `None`, ops call the CPU kernels directly.
    pub(crate) backend: Option<Box<dyn BackwardOps>>,
}
19
20impl Default for Graph {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl Graph {
27    /// Creates an empty graph with automatic parallel backend.
28    ///
29    /// Uses all available CPU threads for parallel matmul, conv2d, softmax, etc.
30    pub fn new() -> Self {
31        let threads = std::thread::available_parallelism()
32            .unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
33        let backend = ThreadedCpuBackend::new(threads)
34            .ok()
35            .map(|b| Box::new(b) as Box<dyn BackwardOps>);
36        Self {
37            nodes: Vec::new(),
38            backend,
39        }
40    }
41
42    /// Creates an empty graph without a parallel backend (single-threaded).
43    pub fn new_single_threaded() -> Self {
44        Self {
45            nodes: Vec::new(),
46            backend: None,
47        }
48    }
49
50    /// Set a compute backend for GPU-accelerated operations.
51    /// When set, supported ops will dispatch through this backend.
52    /// When None (default), ops use direct CPU kernel calls.
53    pub fn set_backend(&mut self, backend: Box<dyn BackwardOps>) {
54        self.backend = Some(backend);
55    }
56
57    /// Remove the backend, reverting to CPU kernel calls.
58    pub fn clear_backend(&mut self) {
59        self.backend = None;
60    }
61
62    /// Adds a trainable leaf node.
63    pub fn variable(&mut self, value: Tensor) -> NodeId {
64        self.push_node(value, true, Op::Leaf)
65    }
66
67    /// Adds a non-trainable leaf node.
68    pub fn constant(&mut self, value: Tensor) -> NodeId {
69        self.push_node(value, false, Op::Leaf)
70    }
71
72    /// Returns immutable node value.
73    pub fn value(&self, node: NodeId) -> Result<&Tensor, AutogradError> {
74        Ok(&self.node(node)?.value)
75    }
76
77    /// Returns mutable node value.
78    pub fn value_mut(&mut self, node: NodeId) -> Result<&mut Tensor, AutogradError> {
79        Ok(&mut self.node_mut(node)?.value)
80    }
81
82    /// Returns whether node is trainable.
83    pub fn requires_grad(&self, node: NodeId) -> Result<bool, AutogradError> {
84        Ok(self.node(node)?.requires_grad)
85    }
86
87    /// Returns immutable gradient if already computed.
88    pub fn grad(&self, node: NodeId) -> Result<Option<&Tensor>, AutogradError> {
89        Ok(self.node(node)?.grad.as_ref())
90    }
91
92    /// Returns mutable gradient if already computed.
93    pub fn grad_mut(&mut self, node: NodeId) -> Result<Option<&mut Tensor>, AutogradError> {
94        Ok(self.node_mut(node)?.grad.as_mut())
95    }
96
97    /// Sets the gradient for a node, replacing any existing gradient.
98    pub fn set_grad(&mut self, node: NodeId, grad: Tensor) -> Result<(), AutogradError> {
99        self.node_mut(node)?.grad = Some(grad);
100        Ok(())
101    }
102
103    /// Returns current node count in the graph.
104    pub fn node_count(&self) -> usize {
105        self.nodes.len()
106    }
107
108    /// Clears gradients for all nodes.
109    pub fn zero_grads(&mut self) {
110        for node in &mut self.nodes {
111            node.grad = None;
112        }
113    }
114
115    /// Truncates graph to a given node count.
116    pub fn truncate(&mut self, keep_nodes: usize) -> Result<(), AutogradError> {
117        if keep_nodes > self.nodes.len() {
118            return Err(AutogradError::InvalidTruncate {
119                requested: keep_nodes,
120                available: self.nodes.len(),
121            });
122        }
123        self.nodes.truncate(keep_nodes);
124        Ok(())
125    }
126
127    /// Adds two nodes with broadcasting support.
128    pub fn add(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, AutogradError> {
129        let (value, requires_grad) = {
130            let lv = &self.nodes[left.0].value;
131            let rv = &self.nodes[right.0].value;
132            let result = if let Some(ref backend) = self.backend {
133                backend.add(lv, rv)?
134            } else {
135                kernel_add(lv, rv)?
136            };
137            (
138                result,
139                self.nodes[left.0].requires_grad || self.nodes[right.0].requires_grad,
140            )
141        };
142        Ok(self.push_node(value, requires_grad, Op::Add(left, right)))
143    }
144
145    /// Subtracts two nodes with broadcasting support.
146    pub fn sub(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, AutogradError> {
147        let (value, requires_grad) = {
148            let lv = &self.nodes[left.0].value;
149            let rv = &self.nodes[right.0].value;
150            let result = if let Some(ref backend) = self.backend {
151                backend.sub(lv, rv)?
152            } else {
153                kernel_sub(lv, rv)?
154            };
155            (
156                result,
157                self.nodes[left.0].requires_grad || self.nodes[right.0].requires_grad,
158            )
159        };
160        Ok(self.push_node(value, requires_grad, Op::Sub(left, right)))
161    }
162
163    /// Multiplies two nodes elementwise with broadcasting support.
164    pub fn mul(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, AutogradError> {
165        let (value, requires_grad) = {
166            let lv = &self.nodes[left.0].value;
167            let rv = &self.nodes[right.0].value;
168            let result = if let Some(ref backend) = self.backend {
169                backend.mul(lv, rv)?
170            } else {
171                kernel_mul(lv, rv)?
172            };
173            (
174                result,
175                self.nodes[left.0].requires_grad || self.nodes[right.0].requires_grad,
176            )
177        };
178        Ok(self.push_node(value, requires_grad, Op::Mul(left, right)))
179    }
180
181    /// Applies ReLU activation to one node.
182    pub fn relu(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
183        let (value, requires_grad) = {
184            let v = &self.nodes[input.0].value;
185            let result = if let Some(ref backend) = self.backend {
186                backend.relu(v)
187            } else {
188                relu(v)
189            };
190            (result, self.nodes[input.0].requires_grad)
191        };
192        Ok(self.push_node(value, requires_grad, Op::Relu(input)))
193    }
194
195    /// Performs rank-2 matrix multiplication.
196    pub fn matmul_2d(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, AutogradError> {
197        let (value, requires_grad) = {
198            let lv = &self.nodes[left.0].value;
199            let rv = &self.nodes[right.0].value;
200            let result = if let Some(ref backend) = self.backend {
201                backend.matmul_2d(lv, rv)?
202            } else {
203                matmul_2d(lv, rv)?
204            };
205            (
206                result,
207                self.nodes[left.0].requires_grad || self.nodes[right.0].requires_grad,
208            )
209        };
210        Ok(self.push_node(value, requires_grad, Op::MatMul2D(left, right)))
211    }
212
213    /// Divides two nodes elementwise with broadcasting support.
214    pub fn div(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, AutogradError> {
215        let (value, requires_grad) = {
216            let lv = &self.nodes[left.0].value;
217            let rv = &self.nodes[right.0].value;
218            (
219                lv.div(rv)?,
220                self.nodes[left.0].requires_grad || self.nodes[right.0].requires_grad,
221            )
222        };
223        Ok(self.push_node(value, requires_grad, Op::Div(left, right)))
224    }
225
226    /// Applies element-wise negation.
227    pub fn neg(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
228        let (value, requires_grad) = {
229            let v = &self.nodes[input.0].value;
230            (v.neg(), self.nodes[input.0].requires_grad)
231        };
232        Ok(self.push_node(value, requires_grad, Op::Neg(input)))
233    }
234
235    /// Applies element-wise natural exponential.
236    pub fn exp(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
237        let (value, requires_grad) = {
238            let v = &self.nodes[input.0].value;
239            (v.exp(), self.nodes[input.0].requires_grad)
240        };
241        Ok(self.push_node(value, requires_grad, Op::Exp(input)))
242    }
243
244    /// Applies element-wise natural logarithm.
245    pub fn log(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
246        let (value, requires_grad) = {
247            let v = &self.nodes[input.0].value;
248            (v.ln(), self.nodes[input.0].requires_grad)
249        };
250        Ok(self.push_node(value, requires_grad, Op::Log(input)))
251    }
252
253    /// Applies element-wise square root.
254    pub fn sqrt(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
255        let (value, requires_grad) = {
256            let v = &self.nodes[input.0].value;
257            (v.sqrt(), self.nodes[input.0].requires_grad)
258        };
259        Ok(self.push_node(value, requires_grad, Op::Sqrt(input)))
260    }
261
262    /// Applies element-wise sigmoid activation: `1 / (1 + exp(-x))`.
263    pub fn sigmoid(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
264        let (value, requires_grad) = {
265            let v = &self.nodes[input.0].value;
266            let result = if let Some(ref backend) = self.backend {
267                backend.sigmoid(v)
268            } else {
269                kernel_sigmoid(v)
270            };
271            (result, self.nodes[input.0].requires_grad)
272        };
273        Ok(self.push_node(value, requires_grad, Op::Sigmoid(input)))
274    }
275
276    /// Applies element-wise GELU activation (fast approximation): `x * sigmoid(1.702 * x)`.
277    pub fn gelu(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
278        let (value, requires_grad) = {
279            let v = &self.nodes[input.0].value;
280            (kernel_gelu(v), self.nodes[input.0].requires_grad)
281        };
282        Ok(self.push_node(value, requires_grad, Op::Gelu(input)))
283    }
284
285    /// Applies element-wise SiLU (Swish) activation: `x * sigmoid(x)`.
286    pub fn silu(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
287        let (value, requires_grad) = {
288            let v = &self.nodes[input.0].value;
289            (kernel_silu(v), self.nodes[input.0].requires_grad)
290        };
291        Ok(self.push_node(value, requires_grad, Op::Silu(input)))
292    }
293
294    /// Applies element-wise Mish activation: `x * tanh(softplus(x))`.
295    pub fn mish(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
296        let (value, requires_grad) = {
297            let v = &self.nodes[input.0].value;
298            (kernel_mish(v), self.nodes[input.0].requires_grad)
299        };
300        Ok(self.push_node(value, requires_grad, Op::Mish(input)))
301    }
302
303    /// Applies element-wise hyperbolic tangent.
304    pub fn tanh(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
305        let (value, requires_grad) = {
306            let v = &self.nodes[input.0].value;
307            let data: Vec<f32> = v.data().iter().map(|&x| x.tanh()).collect();
308            (
309                Tensor::from_vec(v.shape().to_vec(), data)?,
310                self.nodes[input.0].requires_grad,
311            )
312        };
313        Ok(self.push_node(value, requires_grad, Op::Tanh(input)))
314    }
315
316    /// Applies element-wise absolute value.
317    pub fn abs(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
318        let (value, requires_grad) = {
319            let v = &self.nodes[input.0].value;
320            let data: Vec<f32> = v.data().iter().map(|&x| x.abs()).collect();
321            (
322                Tensor::from_vec(v.shape().to_vec(), data)?,
323                self.nodes[input.0].requires_grad,
324            )
325        };
326        Ok(self.push_node(value, requires_grad, Op::Abs(input)))
327    }
328
329    /// Applies element-wise power: `base ^ exponent`.
330    pub fn pow(&mut self, base: NodeId, exponent: NodeId) -> Result<NodeId, AutogradError> {
331        let (value, requires_grad) = {
332            let bv = &self.nodes[base.0].value;
333            let ev = &self.nodes[exponent.0].value;
334            (
335                bv.pow(ev)?,
336                self.nodes[base.0].requires_grad || self.nodes[exponent.0].requires_grad,
337            )
338        };
339        Ok(self.push_node(value, requires_grad, Op::Pow(base, exponent)))
340    }
341
342    /// Applies element-wise clamping to `[min_val, max_val]`.
343    pub fn clamp(
344        &mut self,
345        input: NodeId,
346        min_val: f32,
347        max_val: f32,
348    ) -> Result<NodeId, AutogradError> {
349        let (value, requires_grad) = {
350            let v = &self.nodes[input.0].value;
351            (v.clamp(min_val, max_val), self.nodes[input.0].requires_grad)
352        };
353        Ok(self.push_node(
354            value,
355            requires_grad,
356            Op::Clamp {
357                input,
358                min_bits: min_val.to_bits(),
359                max_bits: max_val.to_bits(),
360            },
361        ))
362    }
363
364    /// Applies element-wise leaky ReLU: `max(0, x) + negative_slope * min(0, x)`.
365    pub fn leaky_relu(
366        &mut self,
367        input: NodeId,
368        negative_slope: f32,
369    ) -> Result<NodeId, AutogradError> {
370        let (value, requires_grad) = {
371            let v = &self.nodes[input.0].value;
372            let data: Vec<f32> = v
373                .data()
374                .iter()
375                .map(|&x| if x >= 0.0 { x } else { negative_slope * x })
376                .collect();
377            (
378                Tensor::from_vec(v.shape().to_vec(), data)?,
379                self.nodes[input.0].requires_grad,
380            )
381        };
382        Ok(self.push_node(
383            value,
384            requires_grad,
385            Op::LeakyRelu {
386                input,
387                negative_slope: negative_slope.to_bits(),
388            },
389        ))
390    }
391
392    /// Applies softmax along the last dimension.
393    pub fn softmax(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
394        let (value, requires_grad) = {
395            let v = &self.nodes[input.0].value;
396            let result = if let Some(ref backend) = self.backend {
397                backend.softmax_last_dim(v)?
398            } else {
399                softmax_last_dim(v)
400            };
401            (result, self.nodes[input.0].requires_grad)
402        };
403        Ok(self.push_node(value, requires_grad, Op::Softmax(input)))
404    }
405
406    /// Applies log-softmax along the last dimension.
407    pub fn log_softmax(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
408        let (value, requires_grad) = {
409            let v = &self.nodes[input.0].value;
410            let sm = softmax_last_dim(v);
411            let data: Vec<f32> = sm.data().iter().map(|&x| x.max(1e-12).ln()).collect();
412            (
413                Tensor::from_vec(sm.shape().to_vec(), data)?,
414                self.nodes[input.0].requires_grad,
415            )
416        };
417        Ok(self.push_node(value, requires_grad, Op::LogSoftmax(input)))
418    }
419
420    /// 2D matrix transpose in graph (for backward).
421    pub fn transpose_2d(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
422        let (value, requires_grad) = {
423            let v = &self.nodes[input.0].value;
424            (v.transpose_2d()?, self.nodes[input.0].requires_grad)
425        };
426        Ok(self.push_node(value, requires_grad, Op::Transpose2D(input)))
427    }
428
429    /// Reshape in graph (preserves backward path).
430    pub fn reshape(
431        &mut self,
432        input: NodeId,
433        new_shape: Vec<usize>,
434    ) -> Result<NodeId, AutogradError> {
435        let (value, requires_grad) = {
436            let v = &self.nodes[input.0].value;
437            (v.reshape(new_shape)?, self.nodes[input.0].requires_grad)
438        };
439        Ok(self.push_node(value, requires_grad, Op::ReshapeView { input }))
440    }
441
442    /// Unsqueeze in graph (preserves backward path).
443    pub fn unsqueeze(&mut self, input: NodeId, axis: usize) -> Result<NodeId, AutogradError> {
444        let (value, requires_grad) = {
445            let v = &self.nodes[input.0].value;
446            (v.unsqueeze(axis)?, self.nodes[input.0].requires_grad)
447        };
448        Ok(self.push_node(
449            value,
450            requires_grad,
451            Op::UnsqueezeView {
452                input,
453                axis: axis as u16,
454            },
455        ))
456    }
457
458    /// Squeeze in graph (preserves backward path).
459    pub fn squeeze(&mut self, input: NodeId, axis: usize) -> Result<NodeId, AutogradError> {
460        let (value, requires_grad) = {
461            let v = &self.nodes[input.0].value;
462            (v.squeeze(axis)?, self.nodes[input.0].requires_grad)
463        };
464        Ok(self.push_node(
465            value,
466            requires_grad,
467            Op::SqueezeView {
468                input,
469                axis: axis as u16,
470            },
471        ))
472    }
473
474    /// Concatenates multiple nodes along `axis`.
475    pub fn cat(&mut self, inputs: &[NodeId], axis: usize) -> Result<NodeId, AutogradError> {
476        if inputs.is_empty() {
477            return Err(AutogradError::InvalidRankForOperation {
478                op: "cat",
479                expected: 1,
480                got: 0,
481            });
482        }
483        let tensors: Vec<&Tensor> = inputs.iter().map(|&id| &self.nodes[id.0].value).collect();
484        let value = Tensor::cat(&tensors, axis)?;
485        let requires_grad = inputs.iter().any(|&id| self.nodes[id.0].requires_grad);
486        Ok(self.push_node(
487            value,
488            requires_grad,
489            Op::Cat {
490                inputs: inputs.to_vec(),
491                axis: axis as u16,
492            },
493        ))
494    }
495
496    /// Selects a single index along `axis`, reducing that dimension.
497    pub fn select(
498        &mut self,
499        input: NodeId,
500        axis: usize,
501        index: usize,
502    ) -> Result<NodeId, AutogradError> {
503        let (value, requires_grad) = {
504            let v = &self.nodes[input.0].value;
505            (v.select(axis, index)?, self.nodes[input.0].requires_grad)
506        };
507        Ok(self.push_node(
508            value,
509            requires_grad,
510            Op::Select {
511                input,
512                axis: axis as u16,
513                index: index as u32,
514            },
515        ))
516    }
517
518    /// Narrows (slices) a node along `axis` from `start` for `length` elements.
519    pub fn narrow(
520        &mut self,
521        input: NodeId,
522        axis: usize,
523        start: usize,
524        length: usize,
525    ) -> Result<NodeId, AutogradError> {
526        let (value, requires_grad) = {
527            let v = &self.nodes[input.0].value;
528            (
529                v.narrow(axis, start, length)?,
530                self.nodes[input.0].requires_grad,
531            )
532        };
533        Ok(self.push_node(
534            value,
535            requires_grad,
536            Op::Narrow {
537                input,
538                axis: axis as u16,
539                start: start as u32,
540                len: length as u32,
541            },
542        ))
543    }
544
545    /// Gathers elements along `axis` using an index tensor (from another node).
546    ///
547    /// For each position in the index tensor, retrieves the value from `input` at the index along `axis`.
548    pub fn gather(
549        &mut self,
550        input: NodeId,
551        axis: usize,
552        index: NodeId,
553    ) -> Result<NodeId, AutogradError> {
554        let (value, requires_grad) = {
555            let iv = &self.nodes[input.0].value;
556            let idx = &self.nodes[index.0].value;
557            (iv.gather(axis, idx)?, self.nodes[input.0].requires_grad)
558        };
559        Ok(self.push_node(
560            value,
561            requires_grad,
562            Op::Gather {
563                input,
564                axis: axis as u16,
565                index,
566            },
567        ))
568    }
569
570    /// Scatter-add operation: scatters `src` values into `input` at `index` positions along `axis`.
571    ///
572    /// Forward: `output = input.scatter_add(axis, index, src)`
573    pub fn scatter_add(
574        &mut self,
575        input: NodeId,
576        index: NodeId,
577        src: NodeId,
578        axis: usize,
579    ) -> Result<NodeId, AutogradError> {
580        let (value, requires_grad) = {
581            let iv = &self.nodes[input.0].value;
582            let idx = &self.nodes[index.0].value;
583            let sv = &self.nodes[src.0].value;
584            (
585                iv.scatter_add(axis, idx, sv)?,
586                self.nodes[input.0].requires_grad || self.nodes[src.0].requires_grad,
587            )
588        };
589        Ok(self.push_node(
590            value,
591            requires_grad,
592            Op::ScatterAdd {
593                input,
594                axis: axis as u16,
595                index,
596                src,
597            },
598        ))
599    }
600
601    /// Pads the tensor with a constant value along each dimension.
602    ///
603    /// `padding` is a flat array of `[before_0, after_0, before_1, after_1, ...]` pairs per dim.
604    pub fn pad(
605        &mut self,
606        input: NodeId,
607        padding: &[usize],
608        value: f32,
609    ) -> Result<NodeId, AutogradError> {
610        let (result, requires_grad, pad_before, pad_after) = {
611            let iv = &self.nodes[input.0].value;
612            let shape = iv.shape();
613            let rank = shape.len();
614            if padding.len() != rank * 2 {
615                return Err(AutogradError::InvalidRankForOperation {
616                    op: "pad",
617                    expected: rank * 2,
618                    got: padding.len(),
619                });
620            }
621            let mut new_shape = Vec::with_capacity(rank);
622            let mut pad_before = Vec::with_capacity(rank);
623            let mut pad_after = Vec::with_capacity(rank);
624            for d in 0..rank {
625                let pb = padding[d * 2];
626                let pa = padding[d * 2 + 1];
627                pad_before.push(pb as u32);
628                pad_after.push(pa as u32);
629                new_shape.push(shape[d] + pb + pa);
630            }
631            let total: usize = new_shape.iter().product();
632            let mut out_data = vec![value; total];
633            let data = iv.data();
634            copy_region_nd(data, shape, &mut out_data, &new_shape, &pad_before);
635
636            let result = Tensor::from_vec(new_shape, out_data)?;
637            (
638                result,
639                self.nodes[input.0].requires_grad,
640                pad_before,
641                pad_after,
642            )
643        };
644        Ok(self.push_node(
645            result,
646            requires_grad,
647            Op::Pad {
648                input,
649                pad_before,
650                pad_after,
651            },
652        ))
653    }
654
655    /// Repeats the tensor along each dimension.
656    pub fn repeat(&mut self, input: NodeId, repeats: &[usize]) -> Result<NodeId, AutogradError> {
657        let (result, requires_grad) = {
658            let v = &self.nodes[input.0].value;
659            (v.repeat(repeats)?, self.nodes[input.0].requires_grad)
660        };
661        let reps: Vec<u32> = repeats.iter().map(|&r| r as u32).collect();
662        Ok(self.push_node(
663            result,
664            requires_grad,
665            Op::Repeat {
666                input,
667                repeats: reps,
668            },
669        ))
670    }
671
672    /// Reduces one node to scalar sum.
673    pub fn sum(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
674        let (value, requires_grad) = {
675            let v = &self.nodes[input.0].value;
676            (Tensor::scalar(v.sum()), self.nodes[input.0].requires_grad)
677        };
678        Ok(self.push_node(value, requires_grad, Op::Sum(input)))
679    }
680
681    /// Reduces one node to scalar mean.
682    pub fn mean(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
683        let (value, requires_grad) = {
684            let v = &self.nodes[input.0].value;
685            (Tensor::scalar(v.mean()), self.nodes[input.0].requires_grad)
686        };
687        Ok(self.push_node(value, requires_grad, Op::Mean(input)))
688    }
689
690    /// NHWC 2-D convolution forward.
691    /// `input` shape \[N,H,W,C_in\], `weight` shape \[KH,KW,C_in,C_out\],
692    /// optional `bias` shape \[C_out\].
693    pub fn conv2d_nhwc(
694        &mut self,
695        input: NodeId,
696        weight: NodeId,
697        bias: Option<NodeId>,
698        stride_h: usize,
699        stride_w: usize,
700    ) -> Result<NodeId, AutogradError> {
701        let (value, requires_grad) = {
702            let iv = &self.nodes[input.0].value;
703            let wv = &self.nodes[weight.0].value;
704            let bv: Option<&Tensor> = bias.map(|b| &self.nodes[b.0].value);
705            let rg = self.nodes[input.0].requires_grad
706                || self.nodes[weight.0].requires_grad
707                || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
708            let result = if let Some(ref backend) = self.backend {
709                backend.conv2d_nhwc(iv, wv, bv, stride_h, stride_w)?
710            } else {
711                conv2d_nhwc(iv, wv, bv, stride_h, stride_w)?
712            };
713            (result, rg)
714        };
715        Ok(self.push_node(
716            value,
717            requires_grad,
718            Op::Conv2dNhwc {
719                input,
720                weight,
721                bias,
722                stride_h: stride_h as u16,
723                stride_w: stride_w as u16,
724            },
725        ))
726    }
727
728    /// NHWC max-pooling forward with argmax tracking for backward.
729    pub fn max_pool2d_nhwc(
730        &mut self,
731        input: NodeId,
732        kernel_h: usize,
733        kernel_w: usize,
734        stride_h: usize,
735        stride_w: usize,
736    ) -> Result<NodeId, AutogradError> {
737        let (value, requires_grad, indices) = {
738            let iv = &self.nodes[input.0].value;
739            let shape = iv.shape();
740            if shape.len() != 4 {
741                return Err(AutogradError::InvalidRankForOperation {
742                    op: "max_pool2d_nhwc",
743                    expected: 4,
744                    got: shape.len(),
745                });
746            }
747            let (n, ih, iw, c) = (shape[0], shape[1], shape[2], shape[3]);
748            let oh = (ih - kernel_h) / stride_h + 1;
749            let ow = (iw - kernel_w) / stride_w + 1;
750
751            // Always compute argmax indices for backward support.
752            let mut indices = vec![0usize; n * oh * ow * c];
753            let in_data = iv.data();
754
755            for batch in 0..n {
756                for row in 0..oh {
757                    for col in 0..ow {
758                        for ch in 0..c {
759                            let out_idx = ((batch * oh + row) * ow + col) * c + ch;
760                            let mut best_val = f32::NEG_INFINITY;
761                            let mut best_offset = 0usize;
762                            for kh in 0..kernel_h {
763                                for kw in 0..kernel_w {
764                                    let ih_pos = row * stride_h + kh;
765                                    let iw_pos = col * stride_w + kw;
766                                    let in_idx = ((batch * ih + ih_pos) * iw + iw_pos) * c + ch;
767                                    let v = in_data[in_idx];
768                                    if v > best_val {
769                                        best_val = v;
770                                        best_offset = in_idx;
771                                    }
772                                }
773                            }
774                            indices[out_idx] = best_offset;
775                        }
776                    }
777                }
778            }
779
780            let value = if let Some(ref backend) = self.backend {
781                backend.max_pool2d_nhwc(iv, kernel_h, kernel_w, stride_h, stride_w)?
782            } else {
783                let mut out_data = vec![f32::NEG_INFINITY; n * oh * ow * c];
784                for batch in 0..n {
785                    for row in 0..oh {
786                        for col in 0..ow {
787                            for ch in 0..c {
788                                let out_idx = ((batch * oh + row) * ow + col) * c + ch;
789                                out_data[out_idx] = in_data[indices[out_idx]];
790                            }
791                        }
792                    }
793                }
794                Tensor::from_vec(vec![n, oh, ow, c], out_data)?
795            };
796            (value, self.nodes[input.0].requires_grad, indices)
797        };
798        Ok(self.push_node_with_aux(
799            value,
800            requires_grad,
801            Op::MaxPool2dNhwc {
802                input,
803                kernel_h: kernel_h as u16,
804                kernel_w: kernel_w as u16,
805                stride_h: stride_h as u16,
806                stride_w: stride_w as u16,
807            },
808            AuxData::MaxPoolIndices(indices),
809        ))
810    }
811
812    /// NHWC average-pooling forward.
813    pub fn avg_pool2d_nhwc(
814        &mut self,
815        input: NodeId,
816        kernel_h: usize,
817        kernel_w: usize,
818        stride_h: usize,
819        stride_w: usize,
820    ) -> Result<NodeId, AutogradError> {
821        let (value, requires_grad) = {
822            let v = &self.nodes[input.0].value;
823            let result = if let Some(ref backend) = self.backend {
824                backend.avg_pool2d_nhwc(v, kernel_h, kernel_w, stride_h, stride_w)?
825            } else {
826                avg_pool2d_nhwc(v, kernel_h, kernel_w, stride_h, stride_w)?
827            };
828            (result, self.nodes[input.0].requires_grad)
829        };
830        Ok(self.push_node(
831            value,
832            requires_grad,
833            Op::AvgPool2dNhwc {
834                input,
835                kernel_h: kernel_h as u16,
836                kernel_w: kernel_w as u16,
837                stride_h: stride_h as u16,
838                stride_w: stride_w as u16,
839            },
840        ))
841    }
842
843    /// NHWC batch-normalization forward (inference mode: uses running stats).
844    /// `gamma`/`beta`/`running_mean`/`running_var` must be rank-1 of size `C`.
845    pub fn batch_norm2d_nhwc(
846        &mut self,
847        input: NodeId,
848        gamma: NodeId,
849        beta: NodeId,
850        running_mean: NodeId,
851        running_var: NodeId,
852        epsilon: f32,
853    ) -> Result<NodeId, AutogradError> {
854        let eps_bits = epsilon.to_bits();
855        let (value, requires_grad, norm_tensor) = {
856            let iv = &self.nodes[input.0].value;
857            let gv = &self.nodes[gamma.0].value;
858            let bv = &self.nodes[beta.0].value;
859            let mv = &self.nodes[running_mean.0].value;
860            let vv = &self.nodes[running_var.0].value;
861
862            let shape = iv.shape();
863            if shape.len() != 4 {
864                return Err(AutogradError::InvalidRankForOperation {
865                    op: "batch_norm2d_nhwc",
866                    expected: 4,
867                    got: shape.len(),
868                });
869            }
870            // Always compute normalized tensor for backward support.
871            let c = shape[3];
872            let total = iv.len();
873            let in_data = iv.data();
874            let mean_data = mv.data();
875            let var_data = vv.data();
876            let mut normalized = vec![0.0f32; total];
877            for i in 0..total {
878                let ch = i % c;
879                let inv_std = 1.0 / (var_data[ch] + epsilon).sqrt();
880                normalized[i] = (in_data[i] - mean_data[ch]) * inv_std;
881            }
882            let norm_tensor = Tensor::from_vec(shape.to_vec(), normalized)?;
883
884            let params = BatchNorm2dParams {
885                gamma: gv,
886                beta: bv,
887                mean: mv,
888                variance: vv,
889                epsilon,
890            };
891            let value = if let Some(ref backend) = self.backend {
892                backend.batch_norm2d_nhwc(iv, params)?
893            } else {
894                batch_norm2d_nhwc(iv, params)?
895            };
896            let rg = self.nodes[input.0].requires_grad
897                || self.nodes[gamma.0].requires_grad
898                || self.nodes[beta.0].requires_grad;
899            (value, rg, norm_tensor)
900        };
901        Ok(self.push_node_with_aux(
902            value,
903            requires_grad,
904            Op::BatchNorm2dNhwc {
905                input,
906                gamma,
907                beta,
908                running_mean,
909                running_var,
910                epsilon: eps_bits,
911            },
912            AuxData::BatchNormNormalized(norm_tensor),
913        ))
914    }
915
916    /// Layer normalization over the last dimension.
917    ///
918    /// Input can be any rank; normalization is applied over the last axis.
919    /// `gamma` and `beta` must have shape `[last_dim]`.
920    pub fn layer_norm(
921        &mut self,
922        input: NodeId,
923        gamma: NodeId,
924        beta: NodeId,
925        epsilon: f32,
926    ) -> Result<NodeId, AutogradError> {
927        let eps_bits = epsilon.to_bits();
928        let (value, requires_grad, norm_tensor) = {
929            let iv = &self.nodes[input.0].value;
930            let gv = &self.nodes[gamma.0].value;
931            let bv = &self.nodes[beta.0].value;
932            let shape = iv.shape();
933            let last_dim = *shape.last().ok_or(AutogradError::InvalidRankForOperation {
934                op: "layer_norm",
935                expected: 1,
936                got: 0,
937            })?;
938            let data = iv.data();
939            let gamma_data = gv.data();
940            let beta_data = bv.data();
941            let num_groups = data.len() / last_dim;
942            let mut out = vec![0.0f32; data.len()];
943            let mut normalized = vec![0.0f32; data.len()];
944            for g in 0..num_groups {
945                let base = g * last_dim;
946                let slice = &data[base..base + last_dim];
947                let mean = slice.iter().sum::<f32>() / last_dim as f32;
948                let var =
949                    slice.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / last_dim as f32;
950                let inv_std = 1.0 / (var + epsilon).sqrt();
951                for i in 0..last_dim {
952                    let x_hat = (slice[i] - mean) * inv_std;
953                    normalized[base + i] = x_hat;
954                    out[base + i] = x_hat * gamma_data[i] + beta_data[i];
955                }
956            }
957            let value = Tensor::from_vec(shape.to_vec(), out)?;
958            let norm_tensor = Tensor::from_vec(shape.to_vec(), normalized)?;
959            let rg = self.nodes[input.0].requires_grad
960                || self.nodes[gamma.0].requires_grad
961                || self.nodes[beta.0].requires_grad;
962            (value, rg, norm_tensor)
963        };
964        Ok(self.push_node_with_aux(
965            value,
966            requires_grad,
967            Op::LayerNorm {
968                input,
969                gamma,
970                beta,
971                eps_bits,
972            },
973            AuxData::NormNormalized(norm_tensor),
974        ))
975    }
976
977    /// Group normalization on NHWC input `[N, H, W, C]`.
978    ///
979    /// `gamma` and `beta` must have shape `[C]`.
980    /// `num_groups` must divide `C`.
981    pub fn group_norm(
982        &mut self,
983        input: NodeId,
984        gamma: NodeId,
985        beta: NodeId,
986        num_groups: usize,
987        epsilon: f32,
988    ) -> Result<NodeId, AutogradError> {
989        let eps_bits = epsilon.to_bits();
990        let (value, requires_grad, norm_tensor) = {
991            let iv = &self.nodes[input.0].value;
992            let gv = &self.nodes[gamma.0].value;
993            let bv = &self.nodes[beta.0].value;
994            let shape = iv.shape();
995            if shape.len() != 4 {
996                return Err(AutogradError::InvalidRankForOperation {
997                    op: "group_norm",
998                    expected: 4,
999                    got: shape.len(),
1000                });
1001            }
1002            let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1003            let channels_per_group = c / num_groups;
1004            let spatial = h * w;
1005            let data = iv.data();
1006            let gamma_data = gv.data();
1007            let beta_data = bv.data();
1008            let mut out = vec![0.0f32; data.len()];
1009            let mut normalized = vec![0.0f32; data.len()];
1010
1011            for ni in 0..n {
1012                for gi in 0..num_groups {
1013                    let c_start = gi * channels_per_group;
1014                    let c_end = c_start + channels_per_group;
1015                    let group_size = spatial * channels_per_group;
1016                    let mut sum = 0.0f32;
1017                    for hi in 0..h {
1018                        for wi in 0..w {
1019                            let base = ((ni * h + hi) * w + wi) * c;
1020                            for ci in c_start..c_end {
1021                                sum += data[base + ci];
1022                            }
1023                        }
1024                    }
1025                    let mean = sum / group_size as f32;
1026                    let mut var_sum = 0.0f32;
1027                    for hi in 0..h {
1028                        for wi in 0..w {
1029                            let base = ((ni * h + hi) * w + wi) * c;
1030                            for ci in c_start..c_end {
1031                                let d = data[base + ci] - mean;
1032                                var_sum += d * d;
1033                            }
1034                        }
1035                    }
1036                    let inv_std = 1.0 / (var_sum / group_size as f32 + epsilon).sqrt();
1037                    for hi in 0..h {
1038                        for wi in 0..w {
1039                            let base = ((ni * h + hi) * w + wi) * c;
1040                            for ci in c_start..c_end {
1041                                let x_hat = (data[base + ci] - mean) * inv_std;
1042                                normalized[base + ci] = x_hat;
1043                                out[base + ci] = x_hat * gamma_data[ci] + beta_data[ci];
1044                            }
1045                        }
1046                    }
1047                }
1048            }
1049            let value = Tensor::from_vec(shape.to_vec(), out)?;
1050            let norm_tensor = Tensor::from_vec(shape.to_vec(), normalized)?;
1051            let rg = self.nodes[input.0].requires_grad
1052                || self.nodes[gamma.0].requires_grad
1053                || self.nodes[beta.0].requires_grad;
1054            (value, rg, norm_tensor)
1055        };
1056        Ok(self.push_node_with_aux(
1057            value,
1058            requires_grad,
1059            Op::GroupNorm {
1060                input,
1061                gamma,
1062                beta,
1063                num_groups: num_groups as u16,
1064                eps_bits,
1065            },
1066            AuxData::NormNormalized(norm_tensor),
1067        ))
1068    }
1069
1070    /// Flatten rank-4 NHWC tensor `[N,H,W,C]` to rank-2 `[N, H*W*C]`.
1071    pub fn flatten(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
1072        let (value, requires_grad) = {
1073            let v = &self.nodes[input.0].value;
1074            let shape = v.shape();
1075            if shape.len() < 2 {
1076                return Err(AutogradError::InvalidRankForOperation {
1077                    op: "flatten",
1078                    expected: 2,
1079                    got: shape.len(),
1080                });
1081            }
1082            let n = shape[0];
1083            let flat = v.len() / n;
1084            (v.reshape(vec![n, flat])?, self.nodes[input.0].requires_grad)
1085        };
1086        Ok(self.push_node(value, requires_grad, Op::Flatten(input)))
1087    }
1088
1089    /// Reduces one node by summing along a single axis (removing that dimension).
1090    pub fn sum_axis(&mut self, input: NodeId, axis: usize) -> Result<NodeId, AutogradError> {
1091        let (value, requires_grad) = {
1092            let v = &self.nodes[input.0].value;
1093            (v.sum_axis(axis)?, self.nodes[input.0].requires_grad)
1094        };
1095        Ok(self.push_node(
1096            value,
1097            requires_grad,
1098            Op::SumAxis {
1099                input,
1100                axis: axis as u16,
1101            },
1102        ))
1103    }
1104
1105    /// Reduces one node by averaging along a single axis (removing that dimension).
1106    pub fn mean_axis(&mut self, input: NodeId, axis: usize) -> Result<NodeId, AutogradError> {
1107        let (value, requires_grad) = {
1108            let v = &self.nodes[input.0].value;
1109            (v.mean_axis(axis)?, self.nodes[input.0].requires_grad)
1110        };
1111        Ok(self.push_node(
1112            value,
1113            requires_grad,
1114            Op::MeanAxis {
1115                input,
1116                axis: axis as u16,
1117            },
1118        ))
1119    }
1120
1121    /// NHWC depthwise 2-D convolution forward.
1122    /// `input` shape `[N,H,W,C]`, `weight` shape `[KH,KW,C,1]`,
1123    /// optional `bias` shape `[C]`.
1124    pub fn depthwise_conv2d_nhwc(
1125        &mut self,
1126        input: NodeId,
1127        weight: NodeId,
1128        bias: Option<NodeId>,
1129        stride_h: usize,
1130        stride_w: usize,
1131    ) -> Result<NodeId, AutogradError> {
1132        let (value, requires_grad) = {
1133            let iv = &self.nodes[input.0].value;
1134            let wv = &self.nodes[weight.0].value;
1135            let bv: Option<&Tensor> = bias.map(|b| &self.nodes[b.0].value);
1136            let rg = self.nodes[input.0].requires_grad
1137                || self.nodes[weight.0].requires_grad
1138                || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
1139            let result = if let Some(ref backend) = self.backend {
1140                backend.depthwise_conv2d_nhwc(iv, wv, bv, stride_h, stride_w)?
1141            } else {
1142                depthwise_conv2d_nhwc(iv, wv, bv, stride_h, stride_w)?
1143            };
1144            (result, rg)
1145        };
1146        Ok(self.push_node(
1147            value,
1148            requires_grad,
1149            Op::DepthwiseConv2dNhwc {
1150                input,
1151                weight,
1152                bias,
1153                stride_h: stride_h as u16,
1154                stride_w: stride_w as u16,
1155            },
1156        ))
1157    }
1158
1159    /// Scatter: write values from `src` into `input` at row positions given by `indices`.
1160    /// input shape: `[N, D]`, indices shape: `[M]`, src shape: `[M, D]`.
1161    /// Result: input with rows at indices replaced by src rows.
1162    pub fn scatter(
1163        &mut self,
1164        input: NodeId,
1165        indices: NodeId,
1166        src: NodeId,
1167    ) -> Result<NodeId, AutogradError> {
1168        let (value, requires_grad) = {
1169            let iv = &self.nodes[input.0].value;
1170            let idx = &self.nodes[indices.0].value;
1171            let sv = &self.nodes[src.0].value;
1172            let shape = iv.shape();
1173            let d = shape[1];
1174            let mut out_data = iv.data().to_vec();
1175            for (i, &raw_idx) in idx.data().iter().enumerate() {
1176                let row = raw_idx as usize;
1177                let src_offset = i * d;
1178                let dst_offset = row * d;
1179                out_data[dst_offset..dst_offset + d]
1180                    .copy_from_slice(&sv.data()[src_offset..src_offset + d]);
1181            }
1182            let rg = self.nodes[input.0].requires_grad || self.nodes[src.0].requires_grad;
1183            (Tensor::from_vec(shape.to_vec(), out_data)?, rg)
1184        };
1185        Ok(self.push_node(
1186            value,
1187            requires_grad,
1188            Op::Scatter {
1189                input,
1190                indices,
1191                src,
1192            },
1193        ))
1194    }
1195
1196    /// Embedding lookup: gather rows from weight matrix at given indices.
1197    /// weight shape: `[vocab_size, embed_dim]`, indices shape: `[seq_len]`.
1198    /// Result shape: `[seq_len, embed_dim]`.
1199    pub fn embedding_lookup(
1200        &mut self,
1201        weight: NodeId,
1202        indices: NodeId,
1203    ) -> Result<NodeId, AutogradError> {
1204        let (value, requires_grad) = {
1205            let wv = &self.nodes[weight.0].value;
1206            let idx = &self.nodes[indices.0].value;
1207            let embed_dim = wv.shape()[1];
1208            let seq_len = idx.data().len();
1209            let w_data = wv.data();
1210            let mut out_data = vec![0.0f32; seq_len * embed_dim];
1211            for (i, &raw_idx) in idx.data().iter().enumerate() {
1212                let row = raw_idx as usize;
1213                let src_offset = row * embed_dim;
1214                let dst_offset = i * embed_dim;
1215                out_data[dst_offset..dst_offset + embed_dim]
1216                    .copy_from_slice(&w_data[src_offset..src_offset + embed_dim]);
1217            }
1218            (
1219                Tensor::from_vec(vec![seq_len, embed_dim], out_data)?,
1220                self.nodes[weight.0].requires_grad,
1221            )
1222        };
1223        Ok(self.push_node(
1224            value,
1225            requires_grad,
1226            Op::EmbeddingLookup { weight, indices },
1227        ))
1228    }
1229
1230    /// NLC 1-D convolution forward.
1231    /// `input` shape \[N,L,C_in\], `weight` shape \[K,C_in,C_out\],
1232    /// optional `bias` shape \[C_out\].
1233    pub fn conv1d_nlc(
1234        &mut self,
1235        input: NodeId,
1236        weight: NodeId,
1237        bias: Option<NodeId>,
1238        stride: usize,
1239    ) -> Result<NodeId, AutogradError> {
1240        let (value, requires_grad) = {
1241            let iv = &self.nodes[input.0].value;
1242            let wv = &self.nodes[weight.0].value;
1243            let bv: Option<&Tensor> = bias.map(|b| &self.nodes[b.0].value);
1244            let rg = self.nodes[input.0].requires_grad
1245                || self.nodes[weight.0].requires_grad
1246                || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
1247            let in_shape = iv.shape();
1248            let w_shape = wv.shape();
1249            let (batch, length, _c_in) = (in_shape[0], in_shape[1], in_shape[2]);
1250            let (kernel_size, c_in, c_out) = (w_shape[0], w_shape[1], w_shape[2]);
1251            let out_len = (length - kernel_size) / stride + 1;
1252            let in_data = iv.data();
1253            let w_data = wv.data();
1254            let mut out = vec![0.0f32; batch * out_len * c_out];
1255            for b in 0..batch {
1256                for ol in 0..out_len {
1257                    let start = ol * stride;
1258                    for oc in 0..c_out {
1259                        let mut sum = 0.0f32;
1260                        for k in 0..kernel_size {
1261                            for ci in 0..c_in {
1262                                sum += in_data[(b * length + start + k) * c_in + ci]
1263                                    * w_data[(k * c_in + ci) * c_out + oc];
1264                            }
1265                        }
1266                        if let Some(bv) = bv {
1267                            sum += bv.data()[oc];
1268                        }
1269                        out[(b * out_len + ol) * c_out + oc] = sum;
1270                    }
1271                }
1272            }
1273            (Tensor::from_vec(vec![batch, out_len, c_out], out)?, rg)
1274        };
1275        Ok(self.push_node(
1276            value,
1277            requires_grad,
1278            Op::Conv1dNlc {
1279                input,
1280                weight,
1281                bias,
1282                stride: stride as u16,
1283            },
1284        ))
1285    }
1286
1287    /// NDHWC 3-D convolution forward (no padding).
1288    /// `input` shape \[N,D,H,W,C_in\], `weight` shape \[KD,KH,KW,C_in,C_out\],
1289    /// optional `bias` shape \[C_out\].
1290    pub fn conv3d_ndhwc(
1291        &mut self,
1292        input: NodeId,
1293        weight: NodeId,
1294        bias: Option<NodeId>,
1295        stride_d: usize,
1296        stride_h: usize,
1297        stride_w: usize,
1298    ) -> Result<NodeId, AutogradError> {
1299        let (value, requires_grad) = {
1300            let iv = &self.nodes[input.0].value;
1301            let wv = &self.nodes[weight.0].value;
1302            let bv: Option<&Tensor> = bias.map(|b| &self.nodes[b.0].value);
1303            let rg = self.nodes[input.0].requires_grad
1304                || self.nodes[weight.0].requires_grad
1305                || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
1306            let (out_data, out_shape) = conv3d(
1307                iv.data(),
1308                iv.shape(),
1309                wv.data(),
1310                wv.shape(),
1311                (stride_d, stride_h, stride_w),
1312                (0, 0, 0),
1313            );
1314            let mut result = Tensor::from_vec(out_shape, out_data)?;
1315            if let Some(bv) = bv {
1316                let c_out = wv.shape()[4];
1317                let data = result.data_mut();
1318                let bias_data = bv.data();
1319                for pixel in data.chunks_mut(c_out) {
1320                    for (v, &bval) in pixel.iter_mut().zip(bias_data.iter()) {
1321                        *v += bval;
1322                    }
1323                }
1324            }
1325            (result, rg)
1326        };
1327        Ok(self.push_node(
1328            value,
1329            requires_grad,
1330            Op::Conv3dNdhwc {
1331                input,
1332                weight,
1333                bias,
1334                stride_d: stride_d as u16,
1335                stride_h: stride_h as u16,
1336                stride_w: stride_w as u16,
1337            },
1338        ))
1339    }
1340
1341    /// Scaled dot-product attention forward.
1342    /// `query` shape `[seq_q, d_k]`, `key` shape `[seq_k, d_k]`, `value` shape `[seq_k, d_v]`.
1343    /// Returns `[seq_q, d_v]`.
1344    pub fn scaled_dot_product_attention(
1345        &mut self,
1346        query: NodeId,
1347        key: NodeId,
1348        value: NodeId,
1349    ) -> Result<NodeId, AutogradError> {
1350        let (output, attn_weights, requires_grad) = {
1351            let qv = &self.nodes[query.0].value;
1352            let kv = &self.nodes[key.0].value;
1353            let vv = &self.nodes[value.0].value;
1354            let rg = self.nodes[query.0].requires_grad
1355                || self.nodes[key.0].requires_grad
1356                || self.nodes[value.0].requires_grad;
1357            let d_k = qv.shape()[1];
1358            let scale = (d_k as f32).sqrt().recip();
1359
1360            // scores = Q @ K^T, scaled
1361            let kt = kv.transpose_2d()?;
1362            let scores = matmul_2d(qv, &kt)?;
1363            let scaled = scores.scale(scale);
1364
1365            // softmax along last dim
1366            let weights = yscv_kernels::softmax_last_dim(&scaled)?;
1367
1368            // output = weights @ V
1369            let out = matmul_2d(&weights, vv)?;
1370            (out, weights, rg)
1371        };
1372        Ok(self.push_node_with_aux(
1373            output,
1374            requires_grad,
1375            Op::ScaledDotProductAttention { query, key, value },
1376            AuxData::AttentionWeights(attn_weights),
1377        ))
1378    }
1379
1380    /// NHWC transposed 2-D convolution forward.
1381    /// `input` shape \[N,H,W,C_in\], `weight` shape \[KH,KW,C_out,C_in\],
1382    /// optional `bias` shape \[C_out\].
1383    /// Output shape: `[N, (H-1)*stride_h + KH, (W-1)*stride_w + KW, C_out]`.
1384    pub fn conv_transpose2d_nhwc(
1385        &mut self,
1386        input: NodeId,
1387        weight: NodeId,
1388        bias: Option<NodeId>,
1389        stride_h: usize,
1390        stride_w: usize,
1391    ) -> Result<NodeId, AutogradError> {
1392        let (value, requires_grad) = {
1393            let iv = &self.nodes[input.0].value;
1394            let wv = &self.nodes[weight.0].value;
1395            let rg = self.nodes[input.0].requires_grad
1396                || self.nodes[weight.0].requires_grad
1397                || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
1398            let in_shape = iv.shape();
1399            let w_shape = wv.shape();
1400            let (n, h, w_dim, c_in) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
1401            let (kh, kw, c_out, _) = (w_shape[0], w_shape[1], w_shape[2], w_shape[3]);
1402            let out_h = (h - 1) * stride_h + kh;
1403            let out_w = (w_dim - 1) * stride_w + kw;
1404            let in_data = iv.data();
1405            let w_data = wv.data();
1406            let mut out = vec![0.0f32; n * out_h * out_w * c_out];
1407            for batch in 0..n {
1408                for ih in 0..h {
1409                    for iw in 0..w_dim {
1410                        for ic in 0..c_in {
1411                            let val = in_data[((batch * h + ih) * w_dim + iw) * c_in + ic];
1412                            for ki in 0..kh {
1413                                for kj in 0..kw {
1414                                    let oh = ih * stride_h + ki;
1415                                    let ow = iw * stride_w + kj;
1416                                    for oc in 0..c_out {
1417                                        let w_idx = ((ki * kw + kj) * c_out + oc) * c_in + ic;
1418                                        out[((batch * out_h + oh) * out_w + ow) * c_out + oc] +=
1419                                            val * w_data[w_idx];
1420                                    }
1421                                }
1422                            }
1423                        }
1424                    }
1425                }
1426            }
1427            if let Some(b_id) = bias {
1428                let bv = &self.nodes[b_id.0].value;
1429                let bd = bv.data();
1430                for i in 0..(n * out_h * out_w) {
1431                    for oc in 0..c_out {
1432                        out[i * c_out + oc] += bd[oc];
1433                    }
1434                }
1435            }
1436            (Tensor::from_vec(vec![n, out_h, out_w, c_out], out)?, rg)
1437        };
1438        Ok(self.push_node(
1439            value,
1440            requires_grad,
1441            Op::ConvTranspose2dNhwc {
1442                input,
1443                weight,
1444                bias,
1445                stride_h: stride_h as u16,
1446                stride_w: stride_w as u16,
1447            },
1448        ))
1449    }
1450
1451    /// NHWC adaptive average pool 2d forward.
1452    /// `input` shape `[N,H,W,C]`, output shape `[N,out_h,out_w,C]`.
1453    pub fn adaptive_avg_pool2d_nhwc(
1454        &mut self,
1455        input: NodeId,
1456        out_h: usize,
1457        out_w: usize,
1458    ) -> Result<NodeId, AutogradError> {
1459        let (value, requires_grad) = {
1460            let iv = &self.nodes[input.0].value;
1461            let shape = iv.shape();
1462            if shape.len() != 4 {
1463                return Err(AutogradError::InvalidRankForOperation {
1464                    op: "adaptive_avg_pool2d_nhwc",
1465                    expected: 4,
1466                    got: shape.len(),
1467                });
1468            }
1469            let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1470            let data = iv.data();
1471            let mut out = vec![0.0f32; n * out_h * out_w * c];
1472            for b in 0..n {
1473                for oh in 0..out_h {
1474                    let h_start = oh * h / out_h;
1475                    let h_end = ((oh + 1) * h / out_h).max(h_start + 1);
1476                    for ow in 0..out_w {
1477                        let w_start = ow * w / out_w;
1478                        let w_end = ((ow + 1) * w / out_w).max(w_start + 1);
1479                        let count = (h_end - h_start) * (w_end - w_start);
1480                        for ch in 0..c {
1481                            let mut sum = 0.0f32;
1482                            for ih in h_start..h_end {
1483                                for iw in w_start..w_end {
1484                                    sum += data[((b * h + ih) * w + iw) * c + ch];
1485                                }
1486                            }
1487                            out[((b * out_h + oh) * out_w + ow) * c + ch] = sum / count as f32;
1488                        }
1489                    }
1490                }
1491            }
1492            (
1493                Tensor::from_vec(vec![n, out_h, out_w, c], out)?,
1494                self.nodes[input.0].requires_grad,
1495            )
1496        };
1497        Ok(self.push_node(
1498            value,
1499            requires_grad,
1500            Op::AdaptiveAvgPool2dNhwc {
1501                input,
1502                out_h: out_h as u16,
1503                out_w: out_w as u16,
1504            },
1505        ))
1506    }
1507
1508    /// NHWC adaptive max pool 2d forward with argmax tracking for backward.
1509    pub fn adaptive_max_pool2d_nhwc(
1510        &mut self,
1511        input: NodeId,
1512        out_h: usize,
1513        out_w: usize,
1514    ) -> Result<NodeId, AutogradError> {
1515        let (value, requires_grad, indices) = {
1516            let iv = &self.nodes[input.0].value;
1517            let shape = iv.shape();
1518            if shape.len() != 4 {
1519                return Err(AutogradError::InvalidRankForOperation {
1520                    op: "adaptive_max_pool2d_nhwc",
1521                    expected: 4,
1522                    got: shape.len(),
1523                });
1524            }
1525            let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1526            let data = iv.data();
1527            let out_len = n * out_h * out_w * c;
1528            let mut out = vec![f32::NEG_INFINITY; out_len];
1529            let mut indices = vec![0usize; out_len];
1530            for b in 0..n {
1531                for oh in 0..out_h {
1532                    let h_start = oh * h / out_h;
1533                    let h_end = ((oh + 1) * h / out_h).max(h_start + 1);
1534                    for ow in 0..out_w {
1535                        let w_start = ow * w / out_w;
1536                        let w_end = ((ow + 1) * w / out_w).max(w_start + 1);
1537                        for ch in 0..c {
1538                            let out_idx = ((b * out_h + oh) * out_w + ow) * c + ch;
1539                            let mut best_val = f32::NEG_INFINITY;
1540                            let mut best_in = 0usize;
1541                            for ih in h_start..h_end {
1542                                for iw in w_start..w_end {
1543                                    let in_idx = ((b * h + ih) * w + iw) * c + ch;
1544                                    let v = data[in_idx];
1545                                    if v > best_val {
1546                                        best_val = v;
1547                                        best_in = in_idx;
1548                                    }
1549                                }
1550                            }
1551                            out[out_idx] = best_val;
1552                            indices[out_idx] = best_in;
1553                        }
1554                    }
1555                }
1556            }
1557            (
1558                Tensor::from_vec(vec![n, out_h, out_w, c], out)?,
1559                self.nodes[input.0].requires_grad,
1560                indices,
1561            )
1562        };
1563        Ok(self.push_node_with_aux(
1564            value,
1565            requires_grad,
1566            Op::AdaptiveMaxPool2dNhwc {
1567                input,
1568                out_h: out_h as u16,
1569                out_w: out_w as u16,
1570            },
1571            AuxData::MaxPoolIndices(indices),
1572        ))
1573    }
1574
1575    /// Instance normalization (NHWC) forward.
1576    /// Normalizes per (N,C) pair across H*W spatial dimensions.
1577    /// `gamma` and `beta` must have shape `[C]`.
1578    pub fn instance_norm_nhwc(
1579        &mut self,
1580        input: NodeId,
1581        gamma: NodeId,
1582        beta: NodeId,
1583        epsilon: f32,
1584    ) -> Result<NodeId, AutogradError> {
1585        let eps_bits = epsilon.to_bits();
1586        let (value, requires_grad, norm_tensor) = {
1587            let iv = &self.nodes[input.0].value;
1588            let gv = &self.nodes[gamma.0].value;
1589            let bv = &self.nodes[beta.0].value;
1590            let shape = iv.shape();
1591            if shape.len() != 4 {
1592                return Err(AutogradError::InvalidRankForOperation {
1593                    op: "instance_norm_nhwc",
1594                    expected: 4,
1595                    got: shape.len(),
1596                });
1597            }
1598            let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1599            let spatial = h * w;
1600            let data = iv.data();
1601            let gamma_data = gv.data();
1602            let beta_data = bv.data();
1603            let mut out = vec![0.0f32; data.len()];
1604            let mut normalized = vec![0.0f32; data.len()];
1605
1606            for ni in 0..n {
1607                for ch in 0..c {
1608                    let mut sum = 0.0f32;
1609                    for s in 0..spatial {
1610                        let idx = (ni * h * w + s) * c + ch;
1611                        sum += data[idx];
1612                    }
1613                    let mean = sum / spatial as f32;
1614                    let mut var_sum = 0.0f32;
1615                    for s in 0..spatial {
1616                        let idx = (ni * h * w + s) * c + ch;
1617                        let d = data[idx] - mean;
1618                        var_sum += d * d;
1619                    }
1620                    let inv_std = 1.0 / (var_sum / spatial as f32 + epsilon).sqrt();
1621                    for s in 0..spatial {
1622                        let idx = (ni * h * w + s) * c + ch;
1623                        let x_hat = (data[idx] - mean) * inv_std;
1624                        normalized[idx] = x_hat;
1625                        out[idx] = x_hat * gamma_data[ch] + beta_data[ch];
1626                    }
1627                }
1628            }
1629            let value = Tensor::from_vec(shape.to_vec(), out)?;
1630            let norm_tensor = Tensor::from_vec(shape.to_vec(), normalized)?;
1631            let rg = self.nodes[input.0].requires_grad
1632                || self.nodes[gamma.0].requires_grad
1633                || self.nodes[beta.0].requires_grad;
1634            (value, rg, norm_tensor)
1635        };
1636        Ok(self.push_node_with_aux(
1637            value,
1638            requires_grad,
1639            Op::InstanceNormNhwc {
1640                input,
1641                gamma,
1642                beta,
1643                eps_bits,
1644            },
1645            AuxData::NormNormalized(norm_tensor),
1646        ))
1647    }
1648
1649    /// PReLU activation forward.
1650    /// `alpha` is a parameter node with shape `[C]` or `[1]`.
1651    /// For NHWC inputs, channels are the last dimension.
1652    pub fn prelu(&mut self, input: NodeId, alpha: NodeId) -> Result<NodeId, AutogradError> {
1653        let (value, requires_grad) = {
1654            let iv = &self.nodes[input.0].value;
1655            let av = &self.nodes[alpha.0].value;
1656            let in_data = iv.data();
1657            let alpha_data = av.data();
1658            let alpha_len = alpha_data.len();
1659            let out: Vec<f32> = in_data
1660                .iter()
1661                .enumerate()
1662                .map(|(i, &x)| {
1663                    if x > 0.0 {
1664                        x
1665                    } else {
1666                        let a = if alpha_len == 1 {
1667                            alpha_data[0]
1668                        } else {
1669                            alpha_data[i % alpha_len]
1670                        };
1671                        a * x
1672                    }
1673                })
1674                .collect();
1675            let rg = self.nodes[input.0].requires_grad || self.nodes[alpha.0].requires_grad;
1676            (Tensor::from_vec(iv.shape().to_vec(), out)?, rg)
1677        };
1678        Ok(self.push_node(value, requires_grad, Op::PRelu { input, alpha }))
1679    }
1680
1681    /// Pixel shuffle forward: rearranges [N, H, W, C*r^2] -> [N, H*r, W*r, C].
1682    pub fn pixel_shuffle(
1683        &mut self,
1684        input: NodeId,
1685        upscale_factor: usize,
1686    ) -> Result<NodeId, AutogradError> {
1687        let (value, requires_grad) = {
1688            let iv = &self.nodes[input.0].value;
1689            let shape = iv.shape();
1690            if shape.len() != 4 {
1691                return Err(AutogradError::InvalidRankForOperation {
1692                    op: "pixel_shuffle",
1693                    expected: 4,
1694                    got: shape.len(),
1695                });
1696            }
1697            let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1698            let r = upscale_factor;
1699            let out_c = c / (r * r);
1700            let out_h = h * r;
1701            let out_w = w * r;
1702            let data = iv.data();
1703            let mut out = vec![0.0f32; batch * out_h * out_w * out_c];
1704            for b in 0..batch {
1705                for ih in 0..h {
1706                    for iw in 0..w {
1707                        for oc in 0..out_c {
1708                            for ry in 0..r {
1709                                for rx in 0..r {
1710                                    let ic = oc * r * r + ry * r + rx;
1711                                    let oh = ih * r + ry;
1712                                    let ow = iw * r + rx;
1713                                    out[((b * out_h + oh) * out_w + ow) * out_c + oc] =
1714                                        data[((b * h + ih) * w + iw) * c + ic];
1715                                }
1716                            }
1717                        }
1718                    }
1719                }
1720            }
1721            (
1722                Tensor::from_vec(vec![batch, out_h, out_w, out_c], out)?,
1723                self.nodes[input.0].requires_grad,
1724            )
1725        };
1726        Ok(self.push_node(
1727            value,
1728            requires_grad,
1729            Op::PixelShuffle {
1730                input,
1731                upscale_factor: upscale_factor as u16,
1732            },
1733        ))
1734    }
1735
1736    /// Nearest-neighbor upsample forward: [N, H, W, C] -> [N, H*r, W*r, C].
1737    pub fn upsample_nearest(
1738        &mut self,
1739        input: NodeId,
1740        scale_factor: usize,
1741    ) -> Result<NodeId, AutogradError> {
1742        let (value, requires_grad) = {
1743            let iv = &self.nodes[input.0].value;
1744            let shape = iv.shape();
1745            if shape.len() != 4 {
1746                return Err(AutogradError::InvalidRankForOperation {
1747                    op: "upsample_nearest",
1748                    expected: 4,
1749                    got: shape.len(),
1750                });
1751            }
1752            let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1753            let r = scale_factor;
1754            let out_h = h * r;
1755            let out_w = w * r;
1756            let data = iv.data();
1757            let mut out = vec![0.0f32; batch * out_h * out_w * c];
1758            for b in 0..batch {
1759                for oh in 0..out_h {
1760                    let ih = oh / r;
1761                    for ow in 0..out_w {
1762                        let iw = ow / r;
1763                        let src = ((b * h + ih) * w + iw) * c;
1764                        let dst = ((b * out_h + oh) * out_w + ow) * c;
1765                        out[dst..dst + c].copy_from_slice(&data[src..src + c]);
1766                    }
1767                }
1768            }
1769            (
1770                Tensor::from_vec(vec![batch, out_h, out_w, c], out)?,
1771                self.nodes[input.0].requires_grad,
1772            )
1773        };
1774        Ok(self.push_node(
1775            value,
1776            requires_grad,
1777            Op::UpsampleNearest {
1778                input,
1779                scale_factor: scale_factor as u16,
1780            },
1781        ))
1782    }
1783
1784    /// RNN forward pass through all timesteps (for BPTT).
1785    /// input: `[seq_len, input_size]`, w_ih: `[input_size, hidden_size]`,
1786    /// w_hh: `[hidden_size, hidden_size]`, bias: `[hidden_size]`.
1787    /// Returns output `[seq_len, hidden_size]`.
1788    pub fn rnn_forward(
1789        &mut self,
1790        input: NodeId,
1791        w_ih: NodeId,
1792        w_hh: NodeId,
1793        bias: NodeId,
1794    ) -> Result<NodeId, AutogradError> {
1795        let (value, requires_grad, hidden_states) = {
1796            let iv = &self.nodes[input.0].value;
1797            let wih = &self.nodes[w_ih.0].value;
1798            let whh = &self.nodes[w_hh.0].value;
1799            let bv = &self.nodes[bias.0].value;
1800            let shape = iv.shape();
1801            let seq_len = shape[0];
1802            let hidden_size = wih.shape()[1];
1803            let in_data = iv.data();
1804            let wih_data = wih.data();
1805            let whh_data = whh.data();
1806            let b_data = bv.data();
1807            let input_size = shape[1];
1808
1809            let mut hidden_states = Vec::with_capacity(seq_len + 1);
1810            // h_0 = zeros
1811            hidden_states.push(Tensor::zeros(vec![hidden_size])?);
1812            let mut output_data = vec![0.0f32; seq_len * hidden_size];
1813
1814            for t in 0..seq_len {
1815                let h_prev = hidden_states[t].data();
1816                let x_base = t * input_size;
1817                let mut h_new = vec![0.0f32; hidden_size];
1818                for j in 0..hidden_size {
1819                    let mut sum = b_data[j];
1820                    for i in 0..input_size {
1821                        sum += in_data[x_base + i] * wih_data[i * hidden_size + j];
1822                    }
1823                    for i in 0..hidden_size {
1824                        sum += h_prev[i] * whh_data[i * hidden_size + j];
1825                    }
1826                    h_new[j] = sum.tanh();
1827                }
1828                output_data[t * hidden_size..(t + 1) * hidden_size].copy_from_slice(&h_new);
1829                hidden_states.push(Tensor::from_vec(vec![hidden_size], h_new)?);
1830            }
1831
1832            let rg = self.nodes[input.0].requires_grad
1833                || self.nodes[w_ih.0].requires_grad
1834                || self.nodes[w_hh.0].requires_grad
1835                || self.nodes[bias.0].requires_grad;
1836            (
1837                Tensor::from_vec(vec![seq_len, hidden_size], output_data)?,
1838                rg,
1839                hidden_states,
1840            )
1841        };
1842        Ok(self.push_node_with_aux(
1843            value,
1844            requires_grad,
1845            Op::Rnn {
1846                input,
1847                w_ih,
1848                w_hh,
1849                bias,
1850            },
1851            AuxData::RnnHiddenStates(hidden_states),
1852        ))
1853    }
1854
1855    /// LSTM forward pass through all timesteps (for BPTT).
1856    /// input: `[seq_len, input_size]`, w_ih: `[input_size, 4*hidden_size]`,
1857    /// w_hh: `[hidden_size, 4*hidden_size]`, bias: `[4*hidden_size]`.
1858    /// Returns output `[seq_len, hidden_size]`.
1859    pub fn lstm_forward(
1860        &mut self,
1861        input: NodeId,
1862        w_ih: NodeId,
1863        w_hh: NodeId,
1864        bias: NodeId,
1865    ) -> Result<NodeId, AutogradError> {
1866        let (value, requires_grad, hidden_states, cell_states, gates) = {
1867            let iv = &self.nodes[input.0].value;
1868            let wih = &self.nodes[w_ih.0].value;
1869            let whh = &self.nodes[w_hh.0].value;
1870            let bv = &self.nodes[bias.0].value;
1871            let shape = iv.shape();
1872            let seq_len = shape[0];
1873            let input_size = shape[1];
1874            let hidden_size = wih.shape()[1] / 4;
1875            let in_data = iv.data();
1876            let wih_data = wih.data();
1877            let whh_data = whh.data();
1878            let b_data = bv.data();
1879
1880            let mut hidden_states = Vec::with_capacity(seq_len + 1);
1881            let mut cell_states = Vec::with_capacity(seq_len + 1);
1882            let mut gates_vec: Vec<(Tensor, Tensor, Tensor, Tensor)> = Vec::with_capacity(seq_len);
1883            hidden_states.push(Tensor::zeros(vec![hidden_size])?);
1884            cell_states.push(Tensor::zeros(vec![hidden_size])?);
1885            let mut output_data = vec![0.0f32; seq_len * hidden_size];
1886
1887            for t in 0..seq_len {
1888                let h_prev = hidden_states[t].data();
1889                let c_prev = cell_states[t].data();
1890                let x_base = t * input_size;
1891                let h4 = 4 * hidden_size;
1892
1893                // Compute gates: [i, f, g, o] = x @ W_ih + h @ W_hh + bias
1894                let mut raw_gates = vec![0.0f32; h4];
1895                for j in 0..h4 {
1896                    let mut sum = b_data[j];
1897                    for i in 0..input_size {
1898                        sum += in_data[x_base + i] * wih_data[i * h4 + j];
1899                    }
1900                    for i in 0..hidden_size {
1901                        sum += h_prev[i] * whh_data[i * h4 + j];
1902                    }
1903                    raw_gates[j] = sum;
1904                }
1905
1906                let mut i_gate = vec![0.0f32; hidden_size];
1907                let mut f_gate = vec![0.0f32; hidden_size];
1908                let mut g_gate = vec![0.0f32; hidden_size];
1909                let mut o_gate = vec![0.0f32; hidden_size];
1910                let mut c_new = vec![0.0f32; hidden_size];
1911                let mut h_new = vec![0.0f32; hidden_size];
1912
1913                for j in 0..hidden_size {
1914                    i_gate[j] = sigmoid_f32(raw_gates[j]);
1915                    f_gate[j] = sigmoid_f32(raw_gates[hidden_size + j]);
1916                    g_gate[j] = raw_gates[2 * hidden_size + j].tanh();
1917                    o_gate[j] = sigmoid_f32(raw_gates[3 * hidden_size + j]);
1918                    c_new[j] = f_gate[j] * c_prev[j] + i_gate[j] * g_gate[j];
1919                    h_new[j] = o_gate[j] * c_new[j].tanh();
1920                }
1921
1922                output_data[t * hidden_size..(t + 1) * hidden_size].copy_from_slice(&h_new);
1923                hidden_states.push(Tensor::from_vec(vec![hidden_size], h_new)?);
1924                cell_states.push(Tensor::from_vec(vec![hidden_size], c_new)?);
1925                gates_vec.push((
1926                    Tensor::from_vec(vec![hidden_size], i_gate)?,
1927                    Tensor::from_vec(vec![hidden_size], f_gate)?,
1928                    Tensor::from_vec(vec![hidden_size], g_gate)?,
1929                    Tensor::from_vec(vec![hidden_size], o_gate)?,
1930                ));
1931            }
1932
1933            let rg = self.nodes[input.0].requires_grad
1934                || self.nodes[w_ih.0].requires_grad
1935                || self.nodes[w_hh.0].requires_grad
1936                || self.nodes[bias.0].requires_grad;
1937            (
1938                Tensor::from_vec(vec![seq_len, hidden_size], output_data)?,
1939                rg,
1940                hidden_states,
1941                cell_states,
1942                gates_vec,
1943            )
1944        };
1945        Ok(self.push_node_with_aux(
1946            value,
1947            requires_grad,
1948            Op::Lstm {
1949                input,
1950                w_ih,
1951                w_hh,
1952                bias,
1953            },
1954            AuxData::LstmStates {
1955                hidden_states,
1956                cell_states,
1957                gates,
1958            },
1959        ))
1960    }
1961
1962    /// GRU forward pass through all timesteps (for BPTT).
1963    /// input: `[seq_len, input_size]`, w_ih: `[input_size, 3*hidden_size]`,
1964    /// w_hh: `[hidden_size, 3*hidden_size]`, bias_ih: `[3*hidden_size]`, bias_hh: `[3*hidden_size]`.
1965    /// Returns output `[seq_len, hidden_size]`.
1966    pub fn gru_forward(
1967        &mut self,
1968        input: NodeId,
1969        w_ih: NodeId,
1970        w_hh: NodeId,
1971        bias_ih: NodeId,
1972        bias_hh: NodeId,
1973    ) -> Result<NodeId, AutogradError> {
1974        let (value, requires_grad, hidden_states, gates) = {
1975            let iv = &self.nodes[input.0].value;
1976            let wih = &self.nodes[w_ih.0].value;
1977            let whh = &self.nodes[w_hh.0].value;
1978            let bih = &self.nodes[bias_ih.0].value;
1979            let bhh = &self.nodes[bias_hh.0].value;
1980            let shape = iv.shape();
1981            let seq_len = shape[0];
1982            let input_size = shape[1];
1983            let hidden_size = wih.shape()[1] / 3;
1984            let in_data = iv.data();
1985            let wih_data = wih.data();
1986            let whh_data = whh.data();
1987            let bih_data = bih.data();
1988            let bhh_data = bhh.data();
1989
1990            let mut hidden_states = Vec::with_capacity(seq_len + 1);
1991            let mut gates_vec: Vec<(Tensor, Tensor, Tensor)> = Vec::with_capacity(seq_len);
1992            hidden_states.push(Tensor::zeros(vec![hidden_size])?);
1993            let mut output_data = vec![0.0f32; seq_len * hidden_size];
1994
1995            for t in 0..seq_len {
1996                let h_prev = hidden_states[t].data();
1997                let x_base = t * input_size;
1998                let h3 = 3 * hidden_size;
1999
2000                // x_proj = x @ W_ih + bias_ih
2001                let mut x_proj = vec![0.0f32; h3];
2002                for j in 0..h3 {
2003                    let mut sum = bih_data[j];
2004                    for i in 0..input_size {
2005                        sum += in_data[x_base + i] * wih_data[i * h3 + j];
2006                    }
2007                    x_proj[j] = sum;
2008                }
2009                // h_proj = h @ W_hh + bias_hh
2010                let mut h_proj = vec![0.0f32; h3];
2011                for j in 0..h3 {
2012                    let mut sum = bhh_data[j];
2013                    for i in 0..hidden_size {
2014                        sum += h_prev[i] * whh_data[i * h3 + j];
2015                    }
2016                    h_proj[j] = sum;
2017                }
2018
2019                let mut r_gate = vec![0.0f32; hidden_size];
2020                let mut z_gate = vec![0.0f32; hidden_size];
2021                let mut n_candidate = vec![0.0f32; hidden_size];
2022                let mut h_new = vec![0.0f32; hidden_size];
2023
2024                for j in 0..hidden_size {
2025                    r_gate[j] = sigmoid_f32(x_proj[j] + h_proj[j]);
2026                    z_gate[j] = sigmoid_f32(x_proj[hidden_size + j] + h_proj[hidden_size + j]);
2027                    n_candidate[j] = (x_proj[2 * hidden_size + j]
2028                        + r_gate[j] * h_proj[2 * hidden_size + j])
2029                        .tanh();
2030                    h_new[j] = (1.0 - z_gate[j]) * n_candidate[j] + z_gate[j] * h_prev[j];
2031                }
2032
2033                output_data[t * hidden_size..(t + 1) * hidden_size].copy_from_slice(&h_new);
2034                hidden_states.push(Tensor::from_vec(vec![hidden_size], h_new)?);
2035                gates_vec.push((
2036                    Tensor::from_vec(vec![hidden_size], r_gate)?,
2037                    Tensor::from_vec(vec![hidden_size], z_gate)?,
2038                    Tensor::from_vec(vec![hidden_size], n_candidate)?,
2039                ));
2040            }
2041
2042            let rg = self.nodes[input.0].requires_grad
2043                || self.nodes[w_ih.0].requires_grad
2044                || self.nodes[w_hh.0].requires_grad
2045                || self.nodes[bias_ih.0].requires_grad
2046                || self.nodes[bias_hh.0].requires_grad;
2047            (
2048                Tensor::from_vec(vec![seq_len, hidden_size], output_data)?,
2049                rg,
2050                hidden_states,
2051                gates_vec,
2052            )
2053        };
2054        Ok(self.push_node_with_aux(
2055            value,
2056            requires_grad,
2057            Op::Gru {
2058                input,
2059                w_ih,
2060                w_hh,
2061                bias_ih,
2062                bias_hh,
2063            },
2064            AuxData::GruStates {
2065                hidden_states,
2066                gates,
2067            },
2068        ))
2069    }
2070
2071    /// Deformable conv2d NHWC forward.
2072    /// input: \[N,H,W,C_in\], weight: \[KH,KW,C_in,C_out\], offsets: \[N,OH,OW,KH\*KW\*2\].
2073    pub fn deformable_conv2d_nhwc(
2074        &mut self,
2075        input: NodeId,
2076        weight: NodeId,
2077        offsets: NodeId,
2078        bias: Option<NodeId>,
2079        stride: usize,
2080        padding: usize,
2081    ) -> Result<NodeId, AutogradError> {
2082        let (value, requires_grad) = {
2083            let iv = &self.nodes[input.0].value;
2084            let wv = &self.nodes[weight.0].value;
2085            let ov = &self.nodes[offsets.0].value;
2086            let bv: Option<&Tensor> = bias.map(|b| &self.nodes[b.0].value);
2087            let rg = self.nodes[input.0].requires_grad
2088                || self.nodes[weight.0].requires_grad
2089                || self.nodes[offsets.0].requires_grad
2090                || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
2091            let result = yscv_kernels::deformable_conv2d_nhwc(iv, wv, ov, bv, stride, padding)?;
2092            (result, rg)
2093        };
2094        Ok(self.push_node(
2095            value,
2096            requires_grad,
2097            Op::DeformableConv2dNhwc {
2098                input,
2099                weight,
2100                offsets,
2101                bias,
2102                stride: stride as u16,
2103                padding: padding as u16,
2104            },
2105        ))
2106    }
2107
2108    pub(crate) fn push_node(&mut self, value: Tensor, requires_grad: bool, op: Op) -> NodeId {
2109        let id = NodeId(self.nodes.len());
2110        self.nodes.push(Node {
2111            value,
2112            grad: None,
2113            requires_grad,
2114            op,
2115            aux: None,
2116        });
2117        id
2118    }
2119
2120    pub(crate) fn push_node_with_aux(
2121        &mut self,
2122        value: Tensor,
2123        requires_grad: bool,
2124        op: Op,
2125        aux: super::node::AuxData,
2126    ) -> NodeId {
2127        let id = NodeId(self.nodes.len());
2128        self.nodes.push(Node {
2129            value,
2130            grad: None,
2131            requires_grad,
2132            op,
2133            aux: Some(aux),
2134        });
2135        id
2136    }
2137
2138    pub(crate) fn node(&self, id: NodeId) -> Result<&Node, AutogradError> {
2139        self.nodes
2140            .get(id.0)
2141            .ok_or(AutogradError::NodeNotFound { id: id.0 })
2142    }
2143
2144    pub(crate) fn node_mut(&mut self, id: NodeId) -> Result<&mut Node, AutogradError> {
2145        self.nodes
2146            .get_mut(id.0)
2147            .ok_or(AutogradError::NodeNotFound { id: id.0 })
2148    }
2149
2150    /// Clips gradients by global L2 norm. Returns the original norm before clipping.
2151    /// If total_norm > max_norm, scales all gradients by max_norm / total_norm.
2152    pub fn clip_grad_norm(&mut self, param_nodes: &[NodeId], max_norm: f32) -> f32 {
2153        let mut total_norm_sq = 0.0f32;
2154        for &node_id in param_nodes {
2155            if let Some(grad) = &self.nodes[node_id.0].grad {
2156                for &v in grad.data() {
2157                    total_norm_sq += v * v;
2158                }
2159            }
2160        }
2161        let total_norm = total_norm_sq.sqrt();
2162
2163        if total_norm > max_norm {
2164            let scale = max_norm / total_norm;
2165            for &node_id in param_nodes {
2166                if let Some(grad) = &mut self.nodes[node_id.0].grad {
2167                    for v in grad.data_mut() {
2168                        *v *= scale;
2169                    }
2170                }
2171            }
2172        }
2173        total_norm
2174    }
2175
2176    /// Clips each gradient element to [-max_value, max_value].
2177    pub fn clip_grad_value(&mut self, param_nodes: &[NodeId], max_value: f32) {
2178        for &node_id in param_nodes {
2179            if let Some(grad) = &mut self.nodes[node_id.0].grad {
2180                for v in grad.data_mut() {
2181                    *v = v.clamp(-max_value, max_value);
2182                }
2183            }
2184        }
2185    }
2186}
2187
2188/// Softmax along last dimension (local utility to avoid tight coupling with kernel version).
2189fn softmax_last_dim(input: &Tensor) -> Tensor {
2190    let shape = input.shape();
2191    if shape.is_empty() {
2192        return input.clone();
2193    }
2194    let last = *shape.last().expect("non-empty shape");
2195    let outer = input.len() / last;
2196    let data = input.data();
2197    let mut out = vec![0.0f32; input.len()];
2198
2199    for o in 0..outer {
2200        let base = o * last;
2201        let max_val = data[base..base + last]
2202            .iter()
2203            .copied()
2204            .fold(f32::NEG_INFINITY, f32::max);
2205        let mut sum = 0.0f32;
2206        for i in 0..last {
2207            let e = (data[base + i] - max_val).exp();
2208            out[base + i] = e;
2209            sum += e;
2210        }
2211        let inv = 1.0 / sum;
2212        for i in 0..last {
2213            out[base + i] *= inv;
2214        }
2215    }
2216
2217    Tensor::from_vec(shape.to_vec(), out).expect("softmax_last_dim preserves shape")
2218}
2219
/// Logistic sigmoid `1 / (1 + e^(-x))` for a single `f32`.
#[inline]
fn sigmoid_f32(x: f32) -> f32 {
    let decay = (-x).exp();
    (1.0 + decay).recip()
}
2224
/// Copies a dense row-major block `src` (shaped `src_shape`) into `dst` (shaped
/// `dst_shape`), placing source coordinate `c` at destination coordinate `c + offsets`.
///
/// Both buffers are addressed with row-major strides derived from their shapes.
/// A rank-0 `src_shape` is a no-op.
///
/// # Panics
/// Panics on out-of-bounds indexing when `dst_shape` or `offsets` has fewer entries
/// than `src_shape`, or when the shifted region does not fit inside `dst`.
// Axis indices address several parallel slices at once, so range loops are intentional.
#[allow(clippy::needless_range_loop)]
fn copy_region_nd(
    src: &[f32],
    src_shape: &[usize],
    dst: &mut [f32],
    dst_shape: &[usize],
    offsets: &[u32],
) {
    let rank = src_shape.len();
    if rank == 0 {
        return;
    }
    let total: usize = src_shape.iter().product();

    // Row-major strides over the first `rank` axes of `shape`.
    let row_major_strides = |shape: &[usize]| {
        let mut strides = vec![1usize; rank];
        for axis in (1..rank).rev() {
            strides[axis - 1] = strides[axis] * shape[axis];
        }
        strides
    };
    let src_strides = row_major_strides(src_shape);
    let dst_strides = row_major_strides(dst_shape);

    for flat in 0..total {
        // Peel one coordinate per axis off the flat source index, shift it by the
        // axis offset, and accumulate the matching flat destination index.
        let (dst_flat, _) = (0..rank).fold((0usize, flat), |(acc, rem), axis| {
            let stride = src_strides[axis];
            let coord = rem / stride + offsets[axis] as usize;
            (acc + coord * dst_strides[axis], rem % stride)
        });
        dst[dst_flat] = src[flat];
    }
}