1use std::num::NonZeroUsize;
2
3use yscv_kernels::{
4 BackwardOps, BatchNorm2dParams, ThreadedCpuBackend, add as kernel_add, avg_pool2d_nhwc,
5 batch_norm2d_nhwc, conv2d_nhwc, conv3d, depthwise_conv2d_nhwc, gelu as kernel_gelu, matmul_2d,
6 mish as kernel_mish, mul as kernel_mul, relu, sigmoid as kernel_sigmoid, silu as kernel_silu,
7 sub as kernel_sub,
8};
9use yscv_tensor::Tensor;
10
11use super::error::AutogradError;
12use super::node::{AuxData, Node, NodeId, Op};
13
/// An autograd computation graph that records each operation as it is evaluated.
///
/// Every recorded value lives in `nodes`; a [`NodeId`] is an index into that vector.
/// When `backend` is `Some`, forward ops that support it dispatch through the
/// [`BackwardOps`] backend; otherwise they call the free kernel functions directly.
pub struct Graph {
    /// All recorded nodes in creation order, indexed by `NodeId.0`.
    pub(crate) nodes: Vec<Node>,
    /// Optional compute backend; `None` means the plain kernel functions are used.
    pub(crate) backend: Option<Box<dyn BackwardOps>>,
}
19
20impl Default for Graph {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl Graph {
27 pub fn new() -> Self {
31 let threads = std::thread::available_parallelism()
32 .unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
33 let backend = ThreadedCpuBackend::new(threads)
34 .ok()
35 .map(|b| Box::new(b) as Box<dyn BackwardOps>);
36 Self {
37 nodes: Vec::new(),
38 backend,
39 }
40 }
41
42 pub fn new_single_threaded() -> Self {
44 Self {
45 nodes: Vec::new(),
46 backend: None,
47 }
48 }
49
50 pub fn set_backend(&mut self, backend: Box<dyn BackwardOps>) {
54 self.backend = Some(backend);
55 }
56
57 pub fn clear_backend(&mut self) {
59 self.backend = None;
60 }
61
62 pub fn variable(&mut self, value: Tensor) -> NodeId {
64 self.push_node(value, true, Op::Leaf)
65 }
66
67 pub fn constant(&mut self, value: Tensor) -> NodeId {
69 self.push_node(value, false, Op::Leaf)
70 }
71
72 pub fn value(&self, node: NodeId) -> Result<&Tensor, AutogradError> {
74 Ok(&self.node(node)?.value)
75 }
76
77 pub fn value_mut(&mut self, node: NodeId) -> Result<&mut Tensor, AutogradError> {
79 Ok(&mut self.node_mut(node)?.value)
80 }
81
82 pub fn requires_grad(&self, node: NodeId) -> Result<bool, AutogradError> {
84 Ok(self.node(node)?.requires_grad)
85 }
86
87 pub fn grad(&self, node: NodeId) -> Result<Option<&Tensor>, AutogradError> {
89 Ok(self.node(node)?.grad.as_ref())
90 }
91
92 pub fn grad_mut(&mut self, node: NodeId) -> Result<Option<&mut Tensor>, AutogradError> {
94 Ok(self.node_mut(node)?.grad.as_mut())
95 }
96
97 pub fn set_grad(&mut self, node: NodeId, grad: Tensor) -> Result<(), AutogradError> {
99 self.node_mut(node)?.grad = Some(grad);
100 Ok(())
101 }
102
103 pub fn node_count(&self) -> usize {
105 self.nodes.len()
106 }
107
108 pub fn zero_grads(&mut self) {
110 for node in &mut self.nodes {
111 node.grad = None;
112 }
113 }
114
115 pub fn truncate(&mut self, keep_nodes: usize) -> Result<(), AutogradError> {
117 if keep_nodes > self.nodes.len() {
118 return Err(AutogradError::InvalidTruncate {
119 requested: keep_nodes,
120 available: self.nodes.len(),
121 });
122 }
123 self.nodes.truncate(keep_nodes);
124 Ok(())
125 }
126
127 pub fn add(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, AutogradError> {
129 let (value, requires_grad) = {
130 let lv = &self.nodes[left.0].value;
131 let rv = &self.nodes[right.0].value;
132 let result = if let Some(ref backend) = self.backend {
133 backend.add(lv, rv)?
134 } else {
135 kernel_add(lv, rv)?
136 };
137 (
138 result,
139 self.nodes[left.0].requires_grad || self.nodes[right.0].requires_grad,
140 )
141 };
142 Ok(self.push_node(value, requires_grad, Op::Add(left, right)))
143 }
144
145 pub fn sub(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, AutogradError> {
147 let (value, requires_grad) = {
148 let lv = &self.nodes[left.0].value;
149 let rv = &self.nodes[right.0].value;
150 let result = if let Some(ref backend) = self.backend {
151 backend.sub(lv, rv)?
152 } else {
153 kernel_sub(lv, rv)?
154 };
155 (
156 result,
157 self.nodes[left.0].requires_grad || self.nodes[right.0].requires_grad,
158 )
159 };
160 Ok(self.push_node(value, requires_grad, Op::Sub(left, right)))
161 }
162
163 pub fn mul(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, AutogradError> {
165 let (value, requires_grad) = {
166 let lv = &self.nodes[left.0].value;
167 let rv = &self.nodes[right.0].value;
168 let result = if let Some(ref backend) = self.backend {
169 backend.mul(lv, rv)?
170 } else {
171 kernel_mul(lv, rv)?
172 };
173 (
174 result,
175 self.nodes[left.0].requires_grad || self.nodes[right.0].requires_grad,
176 )
177 };
178 Ok(self.push_node(value, requires_grad, Op::Mul(left, right)))
179 }
180
181 pub fn relu(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
183 let (value, requires_grad) = {
184 let v = &self.nodes[input.0].value;
185 let result = if let Some(ref backend) = self.backend {
186 backend.relu(v)
187 } else {
188 relu(v)
189 };
190 (result, self.nodes[input.0].requires_grad)
191 };
192 Ok(self.push_node(value, requires_grad, Op::Relu(input)))
193 }
194
195 pub fn matmul_2d(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, AutogradError> {
197 let (value, requires_grad) = {
198 let lv = &self.nodes[left.0].value;
199 let rv = &self.nodes[right.0].value;
200 let result = if let Some(ref backend) = self.backend {
201 backend.matmul_2d(lv, rv)?
202 } else {
203 matmul_2d(lv, rv)?
204 };
205 (
206 result,
207 self.nodes[left.0].requires_grad || self.nodes[right.0].requires_grad,
208 )
209 };
210 Ok(self.push_node(value, requires_grad, Op::MatMul2D(left, right)))
211 }
212
213 pub fn div(&mut self, left: NodeId, right: NodeId) -> Result<NodeId, AutogradError> {
215 let (value, requires_grad) = {
216 let lv = &self.nodes[left.0].value;
217 let rv = &self.nodes[right.0].value;
218 (
219 lv.div(rv)?,
220 self.nodes[left.0].requires_grad || self.nodes[right.0].requires_grad,
221 )
222 };
223 Ok(self.push_node(value, requires_grad, Op::Div(left, right)))
224 }
225
226 pub fn neg(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
228 let (value, requires_grad) = {
229 let v = &self.nodes[input.0].value;
230 (v.neg(), self.nodes[input.0].requires_grad)
231 };
232 Ok(self.push_node(value, requires_grad, Op::Neg(input)))
233 }
234
235 pub fn exp(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
237 let (value, requires_grad) = {
238 let v = &self.nodes[input.0].value;
239 (v.exp(), self.nodes[input.0].requires_grad)
240 };
241 Ok(self.push_node(value, requires_grad, Op::Exp(input)))
242 }
243
244 pub fn log(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
246 let (value, requires_grad) = {
247 let v = &self.nodes[input.0].value;
248 (v.ln(), self.nodes[input.0].requires_grad)
249 };
250 Ok(self.push_node(value, requires_grad, Op::Log(input)))
251 }
252
253 pub fn sqrt(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
255 let (value, requires_grad) = {
256 let v = &self.nodes[input.0].value;
257 (v.sqrt(), self.nodes[input.0].requires_grad)
258 };
259 Ok(self.push_node(value, requires_grad, Op::Sqrt(input)))
260 }
261
262 pub fn sigmoid(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
264 let (value, requires_grad) = {
265 let v = &self.nodes[input.0].value;
266 let result = if let Some(ref backend) = self.backend {
267 backend.sigmoid(v)
268 } else {
269 kernel_sigmoid(v)
270 };
271 (result, self.nodes[input.0].requires_grad)
272 };
273 Ok(self.push_node(value, requires_grad, Op::Sigmoid(input)))
274 }
275
276 pub fn gelu(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
278 let (value, requires_grad) = {
279 let v = &self.nodes[input.0].value;
280 (kernel_gelu(v), self.nodes[input.0].requires_grad)
281 };
282 Ok(self.push_node(value, requires_grad, Op::Gelu(input)))
283 }
284
285 pub fn silu(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
287 let (value, requires_grad) = {
288 let v = &self.nodes[input.0].value;
289 (kernel_silu(v), self.nodes[input.0].requires_grad)
290 };
291 Ok(self.push_node(value, requires_grad, Op::Silu(input)))
292 }
293
294 pub fn mish(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
296 let (value, requires_grad) = {
297 let v = &self.nodes[input.0].value;
298 (kernel_mish(v), self.nodes[input.0].requires_grad)
299 };
300 Ok(self.push_node(value, requires_grad, Op::Mish(input)))
301 }
302
303 pub fn tanh(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
305 let (value, requires_grad) = {
306 let v = &self.nodes[input.0].value;
307 let data: Vec<f32> = v.data().iter().map(|&x| x.tanh()).collect();
308 (
309 Tensor::from_vec(v.shape().to_vec(), data)?,
310 self.nodes[input.0].requires_grad,
311 )
312 };
313 Ok(self.push_node(value, requires_grad, Op::Tanh(input)))
314 }
315
316 pub fn abs(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
318 let (value, requires_grad) = {
319 let v = &self.nodes[input.0].value;
320 let data: Vec<f32> = v.data().iter().map(|&x| x.abs()).collect();
321 (
322 Tensor::from_vec(v.shape().to_vec(), data)?,
323 self.nodes[input.0].requires_grad,
324 )
325 };
326 Ok(self.push_node(value, requires_grad, Op::Abs(input)))
327 }
328
329 pub fn pow(&mut self, base: NodeId, exponent: NodeId) -> Result<NodeId, AutogradError> {
331 let (value, requires_grad) = {
332 let bv = &self.nodes[base.0].value;
333 let ev = &self.nodes[exponent.0].value;
334 (
335 bv.pow(ev)?,
336 self.nodes[base.0].requires_grad || self.nodes[exponent.0].requires_grad,
337 )
338 };
339 Ok(self.push_node(value, requires_grad, Op::Pow(base, exponent)))
340 }
341
342 pub fn clamp(
344 &mut self,
345 input: NodeId,
346 min_val: f32,
347 max_val: f32,
348 ) -> Result<NodeId, AutogradError> {
349 let (value, requires_grad) = {
350 let v = &self.nodes[input.0].value;
351 (v.clamp(min_val, max_val), self.nodes[input.0].requires_grad)
352 };
353 Ok(self.push_node(
354 value,
355 requires_grad,
356 Op::Clamp {
357 input,
358 min_bits: min_val.to_bits(),
359 max_bits: max_val.to_bits(),
360 },
361 ))
362 }
363
364 pub fn leaky_relu(
366 &mut self,
367 input: NodeId,
368 negative_slope: f32,
369 ) -> Result<NodeId, AutogradError> {
370 let (value, requires_grad) = {
371 let v = &self.nodes[input.0].value;
372 let data: Vec<f32> = v
373 .data()
374 .iter()
375 .map(|&x| if x >= 0.0 { x } else { negative_slope * x })
376 .collect();
377 (
378 Tensor::from_vec(v.shape().to_vec(), data)?,
379 self.nodes[input.0].requires_grad,
380 )
381 };
382 Ok(self.push_node(
383 value,
384 requires_grad,
385 Op::LeakyRelu {
386 input,
387 negative_slope: negative_slope.to_bits(),
388 },
389 ))
390 }
391
392 pub fn softmax(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
394 let (value, requires_grad) = {
395 let v = &self.nodes[input.0].value;
396 let result = if let Some(ref backend) = self.backend {
397 backend.softmax_last_dim(v)?
398 } else {
399 softmax_last_dim(v)
400 };
401 (result, self.nodes[input.0].requires_grad)
402 };
403 Ok(self.push_node(value, requires_grad, Op::Softmax(input)))
404 }
405
406 pub fn log_softmax(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
408 let (value, requires_grad) = {
409 let v = &self.nodes[input.0].value;
410 let sm = softmax_last_dim(v);
411 let data: Vec<f32> = sm.data().iter().map(|&x| x.max(1e-12).ln()).collect();
412 (
413 Tensor::from_vec(sm.shape().to_vec(), data)?,
414 self.nodes[input.0].requires_grad,
415 )
416 };
417 Ok(self.push_node(value, requires_grad, Op::LogSoftmax(input)))
418 }
419
420 pub fn transpose_2d(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
422 let (value, requires_grad) = {
423 let v = &self.nodes[input.0].value;
424 (v.transpose_2d()?, self.nodes[input.0].requires_grad)
425 };
426 Ok(self.push_node(value, requires_grad, Op::Transpose2D(input)))
427 }
428
429 pub fn reshape(
431 &mut self,
432 input: NodeId,
433 new_shape: Vec<usize>,
434 ) -> Result<NodeId, AutogradError> {
435 let (value, requires_grad) = {
436 let v = &self.nodes[input.0].value;
437 (v.reshape(new_shape)?, self.nodes[input.0].requires_grad)
438 };
439 Ok(self.push_node(value, requires_grad, Op::ReshapeView { input }))
440 }
441
442 pub fn unsqueeze(&mut self, input: NodeId, axis: usize) -> Result<NodeId, AutogradError> {
444 let (value, requires_grad) = {
445 let v = &self.nodes[input.0].value;
446 (v.unsqueeze(axis)?, self.nodes[input.0].requires_grad)
447 };
448 Ok(self.push_node(
449 value,
450 requires_grad,
451 Op::UnsqueezeView {
452 input,
453 axis: axis as u16,
454 },
455 ))
456 }
457
458 pub fn squeeze(&mut self, input: NodeId, axis: usize) -> Result<NodeId, AutogradError> {
460 let (value, requires_grad) = {
461 let v = &self.nodes[input.0].value;
462 (v.squeeze(axis)?, self.nodes[input.0].requires_grad)
463 };
464 Ok(self.push_node(
465 value,
466 requires_grad,
467 Op::SqueezeView {
468 input,
469 axis: axis as u16,
470 },
471 ))
472 }
473
474 pub fn cat(&mut self, inputs: &[NodeId], axis: usize) -> Result<NodeId, AutogradError> {
476 if inputs.is_empty() {
477 return Err(AutogradError::InvalidRankForOperation {
478 op: "cat",
479 expected: 1,
480 got: 0,
481 });
482 }
483 let tensors: Vec<&Tensor> = inputs.iter().map(|&id| &self.nodes[id.0].value).collect();
484 let value = Tensor::cat(&tensors, axis)?;
485 let requires_grad = inputs.iter().any(|&id| self.nodes[id.0].requires_grad);
486 Ok(self.push_node(
487 value,
488 requires_grad,
489 Op::Cat {
490 inputs: inputs.to_vec(),
491 axis: axis as u16,
492 },
493 ))
494 }
495
496 pub fn select(
498 &mut self,
499 input: NodeId,
500 axis: usize,
501 index: usize,
502 ) -> Result<NodeId, AutogradError> {
503 let (value, requires_grad) = {
504 let v = &self.nodes[input.0].value;
505 (v.select(axis, index)?, self.nodes[input.0].requires_grad)
506 };
507 Ok(self.push_node(
508 value,
509 requires_grad,
510 Op::Select {
511 input,
512 axis: axis as u16,
513 index: index as u32,
514 },
515 ))
516 }
517
518 pub fn narrow(
520 &mut self,
521 input: NodeId,
522 axis: usize,
523 start: usize,
524 length: usize,
525 ) -> Result<NodeId, AutogradError> {
526 let (value, requires_grad) = {
527 let v = &self.nodes[input.0].value;
528 (
529 v.narrow(axis, start, length)?,
530 self.nodes[input.0].requires_grad,
531 )
532 };
533 Ok(self.push_node(
534 value,
535 requires_grad,
536 Op::Narrow {
537 input,
538 axis: axis as u16,
539 start: start as u32,
540 len: length as u32,
541 },
542 ))
543 }
544
545 pub fn gather(
549 &mut self,
550 input: NodeId,
551 axis: usize,
552 index: NodeId,
553 ) -> Result<NodeId, AutogradError> {
554 let (value, requires_grad) = {
555 let iv = &self.nodes[input.0].value;
556 let idx = &self.nodes[index.0].value;
557 (iv.gather(axis, idx)?, self.nodes[input.0].requires_grad)
558 };
559 Ok(self.push_node(
560 value,
561 requires_grad,
562 Op::Gather {
563 input,
564 axis: axis as u16,
565 index,
566 },
567 ))
568 }
569
570 pub fn scatter_add(
574 &mut self,
575 input: NodeId,
576 index: NodeId,
577 src: NodeId,
578 axis: usize,
579 ) -> Result<NodeId, AutogradError> {
580 let (value, requires_grad) = {
581 let iv = &self.nodes[input.0].value;
582 let idx = &self.nodes[index.0].value;
583 let sv = &self.nodes[src.0].value;
584 (
585 iv.scatter_add(axis, idx, sv)?,
586 self.nodes[input.0].requires_grad || self.nodes[src.0].requires_grad,
587 )
588 };
589 Ok(self.push_node(
590 value,
591 requires_grad,
592 Op::ScatterAdd {
593 input,
594 axis: axis as u16,
595 index,
596 src,
597 },
598 ))
599 }
600
601 pub fn pad(
605 &mut self,
606 input: NodeId,
607 padding: &[usize],
608 value: f32,
609 ) -> Result<NodeId, AutogradError> {
610 let (result, requires_grad, pad_before, pad_after) = {
611 let iv = &self.nodes[input.0].value;
612 let shape = iv.shape();
613 let rank = shape.len();
614 if padding.len() != rank * 2 {
615 return Err(AutogradError::InvalidRankForOperation {
616 op: "pad",
617 expected: rank * 2,
618 got: padding.len(),
619 });
620 }
621 let mut new_shape = Vec::with_capacity(rank);
622 let mut pad_before = Vec::with_capacity(rank);
623 let mut pad_after = Vec::with_capacity(rank);
624 for d in 0..rank {
625 let pb = padding[d * 2];
626 let pa = padding[d * 2 + 1];
627 pad_before.push(pb as u32);
628 pad_after.push(pa as u32);
629 new_shape.push(shape[d] + pb + pa);
630 }
631 let total: usize = new_shape.iter().product();
632 let mut out_data = vec![value; total];
633 let data = iv.data();
634 copy_region_nd(data, shape, &mut out_data, &new_shape, &pad_before);
635
636 let result = Tensor::from_vec(new_shape, out_data)?;
637 (
638 result,
639 self.nodes[input.0].requires_grad,
640 pad_before,
641 pad_after,
642 )
643 };
644 Ok(self.push_node(
645 result,
646 requires_grad,
647 Op::Pad {
648 input,
649 pad_before,
650 pad_after,
651 },
652 ))
653 }
654
655 pub fn repeat(&mut self, input: NodeId, repeats: &[usize]) -> Result<NodeId, AutogradError> {
657 let (result, requires_grad) = {
658 let v = &self.nodes[input.0].value;
659 (v.repeat(repeats)?, self.nodes[input.0].requires_grad)
660 };
661 let reps: Vec<u32> = repeats.iter().map(|&r| r as u32).collect();
662 Ok(self.push_node(
663 result,
664 requires_grad,
665 Op::Repeat {
666 input,
667 repeats: reps,
668 },
669 ))
670 }
671
672 pub fn sum(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
674 let (value, requires_grad) = {
675 let v = &self.nodes[input.0].value;
676 (Tensor::scalar(v.sum()), self.nodes[input.0].requires_grad)
677 };
678 Ok(self.push_node(value, requires_grad, Op::Sum(input)))
679 }
680
681 pub fn mean(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
683 let (value, requires_grad) = {
684 let v = &self.nodes[input.0].value;
685 (Tensor::scalar(v.mean()), self.nodes[input.0].requires_grad)
686 };
687 Ok(self.push_node(value, requires_grad, Op::Mean(input)))
688 }
689
690 pub fn conv2d_nhwc(
694 &mut self,
695 input: NodeId,
696 weight: NodeId,
697 bias: Option<NodeId>,
698 stride_h: usize,
699 stride_w: usize,
700 ) -> Result<NodeId, AutogradError> {
701 let (value, requires_grad) = {
702 let iv = &self.nodes[input.0].value;
703 let wv = &self.nodes[weight.0].value;
704 let bv: Option<&Tensor> = bias.map(|b| &self.nodes[b.0].value);
705 let rg = self.nodes[input.0].requires_grad
706 || self.nodes[weight.0].requires_grad
707 || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
708 let result = if let Some(ref backend) = self.backend {
709 backend.conv2d_nhwc(iv, wv, bv, stride_h, stride_w)?
710 } else {
711 conv2d_nhwc(iv, wv, bv, stride_h, stride_w)?
712 };
713 (result, rg)
714 };
715 Ok(self.push_node(
716 value,
717 requires_grad,
718 Op::Conv2dNhwc {
719 input,
720 weight,
721 bias,
722 stride_h: stride_h as u16,
723 stride_w: stride_w as u16,
724 },
725 ))
726 }
727
728 pub fn max_pool2d_nhwc(
730 &mut self,
731 input: NodeId,
732 kernel_h: usize,
733 kernel_w: usize,
734 stride_h: usize,
735 stride_w: usize,
736 ) -> Result<NodeId, AutogradError> {
737 let (value, requires_grad, indices) = {
738 let iv = &self.nodes[input.0].value;
739 let shape = iv.shape();
740 if shape.len() != 4 {
741 return Err(AutogradError::InvalidRankForOperation {
742 op: "max_pool2d_nhwc",
743 expected: 4,
744 got: shape.len(),
745 });
746 }
747 let (n, ih, iw, c) = (shape[0], shape[1], shape[2], shape[3]);
748 let oh = (ih - kernel_h) / stride_h + 1;
749 let ow = (iw - kernel_w) / stride_w + 1;
750
751 let mut indices = vec![0usize; n * oh * ow * c];
753 let in_data = iv.data();
754
755 for batch in 0..n {
756 for row in 0..oh {
757 for col in 0..ow {
758 for ch in 0..c {
759 let out_idx = ((batch * oh + row) * ow + col) * c + ch;
760 let mut best_val = f32::NEG_INFINITY;
761 let mut best_offset = 0usize;
762 for kh in 0..kernel_h {
763 for kw in 0..kernel_w {
764 let ih_pos = row * stride_h + kh;
765 let iw_pos = col * stride_w + kw;
766 let in_idx = ((batch * ih + ih_pos) * iw + iw_pos) * c + ch;
767 let v = in_data[in_idx];
768 if v > best_val {
769 best_val = v;
770 best_offset = in_idx;
771 }
772 }
773 }
774 indices[out_idx] = best_offset;
775 }
776 }
777 }
778 }
779
780 let value = if let Some(ref backend) = self.backend {
781 backend.max_pool2d_nhwc(iv, kernel_h, kernel_w, stride_h, stride_w)?
782 } else {
783 let mut out_data = vec![f32::NEG_INFINITY; n * oh * ow * c];
784 for batch in 0..n {
785 for row in 0..oh {
786 for col in 0..ow {
787 for ch in 0..c {
788 let out_idx = ((batch * oh + row) * ow + col) * c + ch;
789 out_data[out_idx] = in_data[indices[out_idx]];
790 }
791 }
792 }
793 }
794 Tensor::from_vec(vec![n, oh, ow, c], out_data)?
795 };
796 (value, self.nodes[input.0].requires_grad, indices)
797 };
798 Ok(self.push_node_with_aux(
799 value,
800 requires_grad,
801 Op::MaxPool2dNhwc {
802 input,
803 kernel_h: kernel_h as u16,
804 kernel_w: kernel_w as u16,
805 stride_h: stride_h as u16,
806 stride_w: stride_w as u16,
807 },
808 AuxData::MaxPoolIndices(indices),
809 ))
810 }
811
812 pub fn avg_pool2d_nhwc(
814 &mut self,
815 input: NodeId,
816 kernel_h: usize,
817 kernel_w: usize,
818 stride_h: usize,
819 stride_w: usize,
820 ) -> Result<NodeId, AutogradError> {
821 let (value, requires_grad) = {
822 let v = &self.nodes[input.0].value;
823 let result = if let Some(ref backend) = self.backend {
824 backend.avg_pool2d_nhwc(v, kernel_h, kernel_w, stride_h, stride_w)?
825 } else {
826 avg_pool2d_nhwc(v, kernel_h, kernel_w, stride_h, stride_w)?
827 };
828 (result, self.nodes[input.0].requires_grad)
829 };
830 Ok(self.push_node(
831 value,
832 requires_grad,
833 Op::AvgPool2dNhwc {
834 input,
835 kernel_h: kernel_h as u16,
836 kernel_w: kernel_w as u16,
837 stride_h: stride_h as u16,
838 stride_w: stride_w as u16,
839 },
840 ))
841 }
842
843 pub fn batch_norm2d_nhwc(
846 &mut self,
847 input: NodeId,
848 gamma: NodeId,
849 beta: NodeId,
850 running_mean: NodeId,
851 running_var: NodeId,
852 epsilon: f32,
853 ) -> Result<NodeId, AutogradError> {
854 let eps_bits = epsilon.to_bits();
855 let (value, requires_grad, norm_tensor) = {
856 let iv = &self.nodes[input.0].value;
857 let gv = &self.nodes[gamma.0].value;
858 let bv = &self.nodes[beta.0].value;
859 let mv = &self.nodes[running_mean.0].value;
860 let vv = &self.nodes[running_var.0].value;
861
862 let shape = iv.shape();
863 if shape.len() != 4 {
864 return Err(AutogradError::InvalidRankForOperation {
865 op: "batch_norm2d_nhwc",
866 expected: 4,
867 got: shape.len(),
868 });
869 }
870 let c = shape[3];
872 let total = iv.len();
873 let in_data = iv.data();
874 let mean_data = mv.data();
875 let var_data = vv.data();
876 let mut normalized = vec![0.0f32; total];
877 for i in 0..total {
878 let ch = i % c;
879 let inv_std = 1.0 / (var_data[ch] + epsilon).sqrt();
880 normalized[i] = (in_data[i] - mean_data[ch]) * inv_std;
881 }
882 let norm_tensor = Tensor::from_vec(shape.to_vec(), normalized)?;
883
884 let params = BatchNorm2dParams {
885 gamma: gv,
886 beta: bv,
887 mean: mv,
888 variance: vv,
889 epsilon,
890 };
891 let value = if let Some(ref backend) = self.backend {
892 backend.batch_norm2d_nhwc(iv, params)?
893 } else {
894 batch_norm2d_nhwc(iv, params)?
895 };
896 let rg = self.nodes[input.0].requires_grad
897 || self.nodes[gamma.0].requires_grad
898 || self.nodes[beta.0].requires_grad;
899 (value, rg, norm_tensor)
900 };
901 Ok(self.push_node_with_aux(
902 value,
903 requires_grad,
904 Op::BatchNorm2dNhwc {
905 input,
906 gamma,
907 beta,
908 running_mean,
909 running_var,
910 epsilon: eps_bits,
911 },
912 AuxData::BatchNormNormalized(norm_tensor),
913 ))
914 }
915
916 pub fn layer_norm(
921 &mut self,
922 input: NodeId,
923 gamma: NodeId,
924 beta: NodeId,
925 epsilon: f32,
926 ) -> Result<NodeId, AutogradError> {
927 let eps_bits = epsilon.to_bits();
928 let (value, requires_grad, norm_tensor) = {
929 let iv = &self.nodes[input.0].value;
930 let gv = &self.nodes[gamma.0].value;
931 let bv = &self.nodes[beta.0].value;
932 let shape = iv.shape();
933 let last_dim = *shape.last().ok_or(AutogradError::InvalidRankForOperation {
934 op: "layer_norm",
935 expected: 1,
936 got: 0,
937 })?;
938 let data = iv.data();
939 let gamma_data = gv.data();
940 let beta_data = bv.data();
941 let num_groups = data.len() / last_dim;
942 let mut out = vec![0.0f32; data.len()];
943 let mut normalized = vec![0.0f32; data.len()];
944 for g in 0..num_groups {
945 let base = g * last_dim;
946 let slice = &data[base..base + last_dim];
947 let mean = slice.iter().sum::<f32>() / last_dim as f32;
948 let var =
949 slice.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / last_dim as f32;
950 let inv_std = 1.0 / (var + epsilon).sqrt();
951 for i in 0..last_dim {
952 let x_hat = (slice[i] - mean) * inv_std;
953 normalized[base + i] = x_hat;
954 out[base + i] = x_hat * gamma_data[i] + beta_data[i];
955 }
956 }
957 let value = Tensor::from_vec(shape.to_vec(), out)?;
958 let norm_tensor = Tensor::from_vec(shape.to_vec(), normalized)?;
959 let rg = self.nodes[input.0].requires_grad
960 || self.nodes[gamma.0].requires_grad
961 || self.nodes[beta.0].requires_grad;
962 (value, rg, norm_tensor)
963 };
964 Ok(self.push_node_with_aux(
965 value,
966 requires_grad,
967 Op::LayerNorm {
968 input,
969 gamma,
970 beta,
971 eps_bits,
972 },
973 AuxData::NormNormalized(norm_tensor),
974 ))
975 }
976
977 pub fn group_norm(
982 &mut self,
983 input: NodeId,
984 gamma: NodeId,
985 beta: NodeId,
986 num_groups: usize,
987 epsilon: f32,
988 ) -> Result<NodeId, AutogradError> {
989 let eps_bits = epsilon.to_bits();
990 let (value, requires_grad, norm_tensor) = {
991 let iv = &self.nodes[input.0].value;
992 let gv = &self.nodes[gamma.0].value;
993 let bv = &self.nodes[beta.0].value;
994 let shape = iv.shape();
995 if shape.len() != 4 {
996 return Err(AutogradError::InvalidRankForOperation {
997 op: "group_norm",
998 expected: 4,
999 got: shape.len(),
1000 });
1001 }
1002 let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1003 let channels_per_group = c / num_groups;
1004 let spatial = h * w;
1005 let data = iv.data();
1006 let gamma_data = gv.data();
1007 let beta_data = bv.data();
1008 let mut out = vec![0.0f32; data.len()];
1009 let mut normalized = vec![0.0f32; data.len()];
1010
1011 for ni in 0..n {
1012 for gi in 0..num_groups {
1013 let c_start = gi * channels_per_group;
1014 let c_end = c_start + channels_per_group;
1015 let group_size = spatial * channels_per_group;
1016 let mut sum = 0.0f32;
1017 for hi in 0..h {
1018 for wi in 0..w {
1019 let base = ((ni * h + hi) * w + wi) * c;
1020 for ci in c_start..c_end {
1021 sum += data[base + ci];
1022 }
1023 }
1024 }
1025 let mean = sum / group_size as f32;
1026 let mut var_sum = 0.0f32;
1027 for hi in 0..h {
1028 for wi in 0..w {
1029 let base = ((ni * h + hi) * w + wi) * c;
1030 for ci in c_start..c_end {
1031 let d = data[base + ci] - mean;
1032 var_sum += d * d;
1033 }
1034 }
1035 }
1036 let inv_std = 1.0 / (var_sum / group_size as f32 + epsilon).sqrt();
1037 for hi in 0..h {
1038 for wi in 0..w {
1039 let base = ((ni * h + hi) * w + wi) * c;
1040 for ci in c_start..c_end {
1041 let x_hat = (data[base + ci] - mean) * inv_std;
1042 normalized[base + ci] = x_hat;
1043 out[base + ci] = x_hat * gamma_data[ci] + beta_data[ci];
1044 }
1045 }
1046 }
1047 }
1048 }
1049 let value = Tensor::from_vec(shape.to_vec(), out)?;
1050 let norm_tensor = Tensor::from_vec(shape.to_vec(), normalized)?;
1051 let rg = self.nodes[input.0].requires_grad
1052 || self.nodes[gamma.0].requires_grad
1053 || self.nodes[beta.0].requires_grad;
1054 (value, rg, norm_tensor)
1055 };
1056 Ok(self.push_node_with_aux(
1057 value,
1058 requires_grad,
1059 Op::GroupNorm {
1060 input,
1061 gamma,
1062 beta,
1063 num_groups: num_groups as u16,
1064 eps_bits,
1065 },
1066 AuxData::NormNormalized(norm_tensor),
1067 ))
1068 }
1069
1070 pub fn flatten(&mut self, input: NodeId) -> Result<NodeId, AutogradError> {
1072 let (value, requires_grad) = {
1073 let v = &self.nodes[input.0].value;
1074 let shape = v.shape();
1075 if shape.len() < 2 {
1076 return Err(AutogradError::InvalidRankForOperation {
1077 op: "flatten",
1078 expected: 2,
1079 got: shape.len(),
1080 });
1081 }
1082 let n = shape[0];
1083 let flat = v.len() / n;
1084 (v.reshape(vec![n, flat])?, self.nodes[input.0].requires_grad)
1085 };
1086 Ok(self.push_node(value, requires_grad, Op::Flatten(input)))
1087 }
1088
1089 pub fn sum_axis(&mut self, input: NodeId, axis: usize) -> Result<NodeId, AutogradError> {
1091 let (value, requires_grad) = {
1092 let v = &self.nodes[input.0].value;
1093 (v.sum_axis(axis)?, self.nodes[input.0].requires_grad)
1094 };
1095 Ok(self.push_node(
1096 value,
1097 requires_grad,
1098 Op::SumAxis {
1099 input,
1100 axis: axis as u16,
1101 },
1102 ))
1103 }
1104
1105 pub fn mean_axis(&mut self, input: NodeId, axis: usize) -> Result<NodeId, AutogradError> {
1107 let (value, requires_grad) = {
1108 let v = &self.nodes[input.0].value;
1109 (v.mean_axis(axis)?, self.nodes[input.0].requires_grad)
1110 };
1111 Ok(self.push_node(
1112 value,
1113 requires_grad,
1114 Op::MeanAxis {
1115 input,
1116 axis: axis as u16,
1117 },
1118 ))
1119 }
1120
1121 pub fn depthwise_conv2d_nhwc(
1125 &mut self,
1126 input: NodeId,
1127 weight: NodeId,
1128 bias: Option<NodeId>,
1129 stride_h: usize,
1130 stride_w: usize,
1131 ) -> Result<NodeId, AutogradError> {
1132 let (value, requires_grad) = {
1133 let iv = &self.nodes[input.0].value;
1134 let wv = &self.nodes[weight.0].value;
1135 let bv: Option<&Tensor> = bias.map(|b| &self.nodes[b.0].value);
1136 let rg = self.nodes[input.0].requires_grad
1137 || self.nodes[weight.0].requires_grad
1138 || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
1139 let result = if let Some(ref backend) = self.backend {
1140 backend.depthwise_conv2d_nhwc(iv, wv, bv, stride_h, stride_w)?
1141 } else {
1142 depthwise_conv2d_nhwc(iv, wv, bv, stride_h, stride_w)?
1143 };
1144 (result, rg)
1145 };
1146 Ok(self.push_node(
1147 value,
1148 requires_grad,
1149 Op::DepthwiseConv2dNhwc {
1150 input,
1151 weight,
1152 bias,
1153 stride_h: stride_h as u16,
1154 stride_w: stride_w as u16,
1155 },
1156 ))
1157 }
1158
1159 pub fn scatter(
1163 &mut self,
1164 input: NodeId,
1165 indices: NodeId,
1166 src: NodeId,
1167 ) -> Result<NodeId, AutogradError> {
1168 let (value, requires_grad) = {
1169 let iv = &self.nodes[input.0].value;
1170 let idx = &self.nodes[indices.0].value;
1171 let sv = &self.nodes[src.0].value;
1172 let shape = iv.shape();
1173 let d = shape[1];
1174 let mut out_data = iv.data().to_vec();
1175 for (i, &raw_idx) in idx.data().iter().enumerate() {
1176 let row = raw_idx as usize;
1177 let src_offset = i * d;
1178 let dst_offset = row * d;
1179 out_data[dst_offset..dst_offset + d]
1180 .copy_from_slice(&sv.data()[src_offset..src_offset + d]);
1181 }
1182 let rg = self.nodes[input.0].requires_grad || self.nodes[src.0].requires_grad;
1183 (Tensor::from_vec(shape.to_vec(), out_data)?, rg)
1184 };
1185 Ok(self.push_node(
1186 value,
1187 requires_grad,
1188 Op::Scatter {
1189 input,
1190 indices,
1191 src,
1192 },
1193 ))
1194 }
1195
1196 pub fn embedding_lookup(
1200 &mut self,
1201 weight: NodeId,
1202 indices: NodeId,
1203 ) -> Result<NodeId, AutogradError> {
1204 let (value, requires_grad) = {
1205 let wv = &self.nodes[weight.0].value;
1206 let idx = &self.nodes[indices.0].value;
1207 let embed_dim = wv.shape()[1];
1208 let seq_len = idx.data().len();
1209 let w_data = wv.data();
1210 let mut out_data = vec![0.0f32; seq_len * embed_dim];
1211 for (i, &raw_idx) in idx.data().iter().enumerate() {
1212 let row = raw_idx as usize;
1213 let src_offset = row * embed_dim;
1214 let dst_offset = i * embed_dim;
1215 out_data[dst_offset..dst_offset + embed_dim]
1216 .copy_from_slice(&w_data[src_offset..src_offset + embed_dim]);
1217 }
1218 (
1219 Tensor::from_vec(vec![seq_len, embed_dim], out_data)?,
1220 self.nodes[weight.0].requires_grad,
1221 )
1222 };
1223 Ok(self.push_node(
1224 value,
1225 requires_grad,
1226 Op::EmbeddingLookup { weight, indices },
1227 ))
1228 }
1229
1230 pub fn conv1d_nlc(
1234 &mut self,
1235 input: NodeId,
1236 weight: NodeId,
1237 bias: Option<NodeId>,
1238 stride: usize,
1239 ) -> Result<NodeId, AutogradError> {
1240 let (value, requires_grad) = {
1241 let iv = &self.nodes[input.0].value;
1242 let wv = &self.nodes[weight.0].value;
1243 let bv: Option<&Tensor> = bias.map(|b| &self.nodes[b.0].value);
1244 let rg = self.nodes[input.0].requires_grad
1245 || self.nodes[weight.0].requires_grad
1246 || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
1247 let in_shape = iv.shape();
1248 let w_shape = wv.shape();
1249 let (batch, length, _c_in) = (in_shape[0], in_shape[1], in_shape[2]);
1250 let (kernel_size, c_in, c_out) = (w_shape[0], w_shape[1], w_shape[2]);
1251 let out_len = (length - kernel_size) / stride + 1;
1252 let in_data = iv.data();
1253 let w_data = wv.data();
1254 let mut out = vec![0.0f32; batch * out_len * c_out];
1255 for b in 0..batch {
1256 for ol in 0..out_len {
1257 let start = ol * stride;
1258 for oc in 0..c_out {
1259 let mut sum = 0.0f32;
1260 for k in 0..kernel_size {
1261 for ci in 0..c_in {
1262 sum += in_data[(b * length + start + k) * c_in + ci]
1263 * w_data[(k * c_in + ci) * c_out + oc];
1264 }
1265 }
1266 if let Some(bv) = bv {
1267 sum += bv.data()[oc];
1268 }
1269 out[(b * out_len + ol) * c_out + oc] = sum;
1270 }
1271 }
1272 }
1273 (Tensor::from_vec(vec![batch, out_len, c_out], out)?, rg)
1274 };
1275 Ok(self.push_node(
1276 value,
1277 requires_grad,
1278 Op::Conv1dNlc {
1279 input,
1280 weight,
1281 bias,
1282 stride: stride as u16,
1283 },
1284 ))
1285 }
1286
1287 pub fn conv3d_ndhwc(
1291 &mut self,
1292 input: NodeId,
1293 weight: NodeId,
1294 bias: Option<NodeId>,
1295 stride_d: usize,
1296 stride_h: usize,
1297 stride_w: usize,
1298 ) -> Result<NodeId, AutogradError> {
1299 let (value, requires_grad) = {
1300 let iv = &self.nodes[input.0].value;
1301 let wv = &self.nodes[weight.0].value;
1302 let bv: Option<&Tensor> = bias.map(|b| &self.nodes[b.0].value);
1303 let rg = self.nodes[input.0].requires_grad
1304 || self.nodes[weight.0].requires_grad
1305 || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
1306 let (out_data, out_shape) = conv3d(
1307 iv.data(),
1308 iv.shape(),
1309 wv.data(),
1310 wv.shape(),
1311 (stride_d, stride_h, stride_w),
1312 (0, 0, 0),
1313 );
1314 let mut result = Tensor::from_vec(out_shape, out_data)?;
1315 if let Some(bv) = bv {
1316 let c_out = wv.shape()[4];
1317 let data = result.data_mut();
1318 let bias_data = bv.data();
1319 for pixel in data.chunks_mut(c_out) {
1320 for (v, &bval) in pixel.iter_mut().zip(bias_data.iter()) {
1321 *v += bval;
1322 }
1323 }
1324 }
1325 (result, rg)
1326 };
1327 Ok(self.push_node(
1328 value,
1329 requires_grad,
1330 Op::Conv3dNdhwc {
1331 input,
1332 weight,
1333 bias,
1334 stride_d: stride_d as u16,
1335 stride_h: stride_h as u16,
1336 stride_w: stride_w as u16,
1337 },
1338 ))
1339 }
1340
1341 pub fn scaled_dot_product_attention(
1345 &mut self,
1346 query: NodeId,
1347 key: NodeId,
1348 value: NodeId,
1349 ) -> Result<NodeId, AutogradError> {
1350 let (output, attn_weights, requires_grad) = {
1351 let qv = &self.nodes[query.0].value;
1352 let kv = &self.nodes[key.0].value;
1353 let vv = &self.nodes[value.0].value;
1354 let rg = self.nodes[query.0].requires_grad
1355 || self.nodes[key.0].requires_grad
1356 || self.nodes[value.0].requires_grad;
1357 let d_k = qv.shape()[1];
1358 let scale = (d_k as f32).sqrt().recip();
1359
1360 let kt = kv.transpose_2d()?;
1362 let scores = matmul_2d(qv, &kt)?;
1363 let scaled = scores.scale(scale);
1364
1365 let weights = yscv_kernels::softmax_last_dim(&scaled)?;
1367
1368 let out = matmul_2d(&weights, vv)?;
1370 (out, weights, rg)
1371 };
1372 Ok(self.push_node_with_aux(
1373 output,
1374 requires_grad,
1375 Op::ScaledDotProductAttention { query, key, value },
1376 AuxData::AttentionWeights(attn_weights),
1377 ))
1378 }
1379
1380 pub fn conv_transpose2d_nhwc(
1385 &mut self,
1386 input: NodeId,
1387 weight: NodeId,
1388 bias: Option<NodeId>,
1389 stride_h: usize,
1390 stride_w: usize,
1391 ) -> Result<NodeId, AutogradError> {
1392 let (value, requires_grad) = {
1393 let iv = &self.nodes[input.0].value;
1394 let wv = &self.nodes[weight.0].value;
1395 let rg = self.nodes[input.0].requires_grad
1396 || self.nodes[weight.0].requires_grad
1397 || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
1398 let in_shape = iv.shape();
1399 let w_shape = wv.shape();
1400 let (n, h, w_dim, c_in) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
1401 let (kh, kw, c_out, _) = (w_shape[0], w_shape[1], w_shape[2], w_shape[3]);
1402 let out_h = (h - 1) * stride_h + kh;
1403 let out_w = (w_dim - 1) * stride_w + kw;
1404 let in_data = iv.data();
1405 let w_data = wv.data();
1406 let mut out = vec![0.0f32; n * out_h * out_w * c_out];
1407 for batch in 0..n {
1408 for ih in 0..h {
1409 for iw in 0..w_dim {
1410 for ic in 0..c_in {
1411 let val = in_data[((batch * h + ih) * w_dim + iw) * c_in + ic];
1412 for ki in 0..kh {
1413 for kj in 0..kw {
1414 let oh = ih * stride_h + ki;
1415 let ow = iw * stride_w + kj;
1416 for oc in 0..c_out {
1417 let w_idx = ((ki * kw + kj) * c_out + oc) * c_in + ic;
1418 out[((batch * out_h + oh) * out_w + ow) * c_out + oc] +=
1419 val * w_data[w_idx];
1420 }
1421 }
1422 }
1423 }
1424 }
1425 }
1426 }
1427 if let Some(b_id) = bias {
1428 let bv = &self.nodes[b_id.0].value;
1429 let bd = bv.data();
1430 for i in 0..(n * out_h * out_w) {
1431 for oc in 0..c_out {
1432 out[i * c_out + oc] += bd[oc];
1433 }
1434 }
1435 }
1436 (Tensor::from_vec(vec![n, out_h, out_w, c_out], out)?, rg)
1437 };
1438 Ok(self.push_node(
1439 value,
1440 requires_grad,
1441 Op::ConvTranspose2dNhwc {
1442 input,
1443 weight,
1444 bias,
1445 stride_h: stride_h as u16,
1446 stride_w: stride_w as u16,
1447 },
1448 ))
1449 }
1450
1451 pub fn adaptive_avg_pool2d_nhwc(
1454 &mut self,
1455 input: NodeId,
1456 out_h: usize,
1457 out_w: usize,
1458 ) -> Result<NodeId, AutogradError> {
1459 let (value, requires_grad) = {
1460 let iv = &self.nodes[input.0].value;
1461 let shape = iv.shape();
1462 if shape.len() != 4 {
1463 return Err(AutogradError::InvalidRankForOperation {
1464 op: "adaptive_avg_pool2d_nhwc",
1465 expected: 4,
1466 got: shape.len(),
1467 });
1468 }
1469 let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1470 let data = iv.data();
1471 let mut out = vec![0.0f32; n * out_h * out_w * c];
1472 for b in 0..n {
1473 for oh in 0..out_h {
1474 let h_start = oh * h / out_h;
1475 let h_end = ((oh + 1) * h / out_h).max(h_start + 1);
1476 for ow in 0..out_w {
1477 let w_start = ow * w / out_w;
1478 let w_end = ((ow + 1) * w / out_w).max(w_start + 1);
1479 let count = (h_end - h_start) * (w_end - w_start);
1480 for ch in 0..c {
1481 let mut sum = 0.0f32;
1482 for ih in h_start..h_end {
1483 for iw in w_start..w_end {
1484 sum += data[((b * h + ih) * w + iw) * c + ch];
1485 }
1486 }
1487 out[((b * out_h + oh) * out_w + ow) * c + ch] = sum / count as f32;
1488 }
1489 }
1490 }
1491 }
1492 (
1493 Tensor::from_vec(vec![n, out_h, out_w, c], out)?,
1494 self.nodes[input.0].requires_grad,
1495 )
1496 };
1497 Ok(self.push_node(
1498 value,
1499 requires_grad,
1500 Op::AdaptiveAvgPool2dNhwc {
1501 input,
1502 out_h: out_h as u16,
1503 out_w: out_w as u16,
1504 },
1505 ))
1506 }
1507
1508 pub fn adaptive_max_pool2d_nhwc(
1510 &mut self,
1511 input: NodeId,
1512 out_h: usize,
1513 out_w: usize,
1514 ) -> Result<NodeId, AutogradError> {
1515 let (value, requires_grad, indices) = {
1516 let iv = &self.nodes[input.0].value;
1517 let shape = iv.shape();
1518 if shape.len() != 4 {
1519 return Err(AutogradError::InvalidRankForOperation {
1520 op: "adaptive_max_pool2d_nhwc",
1521 expected: 4,
1522 got: shape.len(),
1523 });
1524 }
1525 let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1526 let data = iv.data();
1527 let out_len = n * out_h * out_w * c;
1528 let mut out = vec![f32::NEG_INFINITY; out_len];
1529 let mut indices = vec![0usize; out_len];
1530 for b in 0..n {
1531 for oh in 0..out_h {
1532 let h_start = oh * h / out_h;
1533 let h_end = ((oh + 1) * h / out_h).max(h_start + 1);
1534 for ow in 0..out_w {
1535 let w_start = ow * w / out_w;
1536 let w_end = ((ow + 1) * w / out_w).max(w_start + 1);
1537 for ch in 0..c {
1538 let out_idx = ((b * out_h + oh) * out_w + ow) * c + ch;
1539 let mut best_val = f32::NEG_INFINITY;
1540 let mut best_in = 0usize;
1541 for ih in h_start..h_end {
1542 for iw in w_start..w_end {
1543 let in_idx = ((b * h + ih) * w + iw) * c + ch;
1544 let v = data[in_idx];
1545 if v > best_val {
1546 best_val = v;
1547 best_in = in_idx;
1548 }
1549 }
1550 }
1551 out[out_idx] = best_val;
1552 indices[out_idx] = best_in;
1553 }
1554 }
1555 }
1556 }
1557 (
1558 Tensor::from_vec(vec![n, out_h, out_w, c], out)?,
1559 self.nodes[input.0].requires_grad,
1560 indices,
1561 )
1562 };
1563 Ok(self.push_node_with_aux(
1564 value,
1565 requires_grad,
1566 Op::AdaptiveMaxPool2dNhwc {
1567 input,
1568 out_h: out_h as u16,
1569 out_w: out_w as u16,
1570 },
1571 AuxData::MaxPoolIndices(indices),
1572 ))
1573 }
1574
1575 pub fn instance_norm_nhwc(
1579 &mut self,
1580 input: NodeId,
1581 gamma: NodeId,
1582 beta: NodeId,
1583 epsilon: f32,
1584 ) -> Result<NodeId, AutogradError> {
1585 let eps_bits = epsilon.to_bits();
1586 let (value, requires_grad, norm_tensor) = {
1587 let iv = &self.nodes[input.0].value;
1588 let gv = &self.nodes[gamma.0].value;
1589 let bv = &self.nodes[beta.0].value;
1590 let shape = iv.shape();
1591 if shape.len() != 4 {
1592 return Err(AutogradError::InvalidRankForOperation {
1593 op: "instance_norm_nhwc",
1594 expected: 4,
1595 got: shape.len(),
1596 });
1597 }
1598 let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1599 let spatial = h * w;
1600 let data = iv.data();
1601 let gamma_data = gv.data();
1602 let beta_data = bv.data();
1603 let mut out = vec![0.0f32; data.len()];
1604 let mut normalized = vec![0.0f32; data.len()];
1605
1606 for ni in 0..n {
1607 for ch in 0..c {
1608 let mut sum = 0.0f32;
1609 for s in 0..spatial {
1610 let idx = (ni * h * w + s) * c + ch;
1611 sum += data[idx];
1612 }
1613 let mean = sum / spatial as f32;
1614 let mut var_sum = 0.0f32;
1615 for s in 0..spatial {
1616 let idx = (ni * h * w + s) * c + ch;
1617 let d = data[idx] - mean;
1618 var_sum += d * d;
1619 }
1620 let inv_std = 1.0 / (var_sum / spatial as f32 + epsilon).sqrt();
1621 for s in 0..spatial {
1622 let idx = (ni * h * w + s) * c + ch;
1623 let x_hat = (data[idx] - mean) * inv_std;
1624 normalized[idx] = x_hat;
1625 out[idx] = x_hat * gamma_data[ch] + beta_data[ch];
1626 }
1627 }
1628 }
1629 let value = Tensor::from_vec(shape.to_vec(), out)?;
1630 let norm_tensor = Tensor::from_vec(shape.to_vec(), normalized)?;
1631 let rg = self.nodes[input.0].requires_grad
1632 || self.nodes[gamma.0].requires_grad
1633 || self.nodes[beta.0].requires_grad;
1634 (value, rg, norm_tensor)
1635 };
1636 Ok(self.push_node_with_aux(
1637 value,
1638 requires_grad,
1639 Op::InstanceNormNhwc {
1640 input,
1641 gamma,
1642 beta,
1643 eps_bits,
1644 },
1645 AuxData::NormNormalized(norm_tensor),
1646 ))
1647 }
1648
1649 pub fn prelu(&mut self, input: NodeId, alpha: NodeId) -> Result<NodeId, AutogradError> {
1653 let (value, requires_grad) = {
1654 let iv = &self.nodes[input.0].value;
1655 let av = &self.nodes[alpha.0].value;
1656 let in_data = iv.data();
1657 let alpha_data = av.data();
1658 let alpha_len = alpha_data.len();
1659 let out: Vec<f32> = in_data
1660 .iter()
1661 .enumerate()
1662 .map(|(i, &x)| {
1663 if x > 0.0 {
1664 x
1665 } else {
1666 let a = if alpha_len == 1 {
1667 alpha_data[0]
1668 } else {
1669 alpha_data[i % alpha_len]
1670 };
1671 a * x
1672 }
1673 })
1674 .collect();
1675 let rg = self.nodes[input.0].requires_grad || self.nodes[alpha.0].requires_grad;
1676 (Tensor::from_vec(iv.shape().to_vec(), out)?, rg)
1677 };
1678 Ok(self.push_node(value, requires_grad, Op::PRelu { input, alpha }))
1679 }
1680
1681 pub fn pixel_shuffle(
1683 &mut self,
1684 input: NodeId,
1685 upscale_factor: usize,
1686 ) -> Result<NodeId, AutogradError> {
1687 let (value, requires_grad) = {
1688 let iv = &self.nodes[input.0].value;
1689 let shape = iv.shape();
1690 if shape.len() != 4 {
1691 return Err(AutogradError::InvalidRankForOperation {
1692 op: "pixel_shuffle",
1693 expected: 4,
1694 got: shape.len(),
1695 });
1696 }
1697 let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1698 let r = upscale_factor;
1699 let out_c = c / (r * r);
1700 let out_h = h * r;
1701 let out_w = w * r;
1702 let data = iv.data();
1703 let mut out = vec![0.0f32; batch * out_h * out_w * out_c];
1704 for b in 0..batch {
1705 for ih in 0..h {
1706 for iw in 0..w {
1707 for oc in 0..out_c {
1708 for ry in 0..r {
1709 for rx in 0..r {
1710 let ic = oc * r * r + ry * r + rx;
1711 let oh = ih * r + ry;
1712 let ow = iw * r + rx;
1713 out[((b * out_h + oh) * out_w + ow) * out_c + oc] =
1714 data[((b * h + ih) * w + iw) * c + ic];
1715 }
1716 }
1717 }
1718 }
1719 }
1720 }
1721 (
1722 Tensor::from_vec(vec![batch, out_h, out_w, out_c], out)?,
1723 self.nodes[input.0].requires_grad,
1724 )
1725 };
1726 Ok(self.push_node(
1727 value,
1728 requires_grad,
1729 Op::PixelShuffle {
1730 input,
1731 upscale_factor: upscale_factor as u16,
1732 },
1733 ))
1734 }
1735
1736 pub fn upsample_nearest(
1738 &mut self,
1739 input: NodeId,
1740 scale_factor: usize,
1741 ) -> Result<NodeId, AutogradError> {
1742 let (value, requires_grad) = {
1743 let iv = &self.nodes[input.0].value;
1744 let shape = iv.shape();
1745 if shape.len() != 4 {
1746 return Err(AutogradError::InvalidRankForOperation {
1747 op: "upsample_nearest",
1748 expected: 4,
1749 got: shape.len(),
1750 });
1751 }
1752 let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
1753 let r = scale_factor;
1754 let out_h = h * r;
1755 let out_w = w * r;
1756 let data = iv.data();
1757 let mut out = vec![0.0f32; batch * out_h * out_w * c];
1758 for b in 0..batch {
1759 for oh in 0..out_h {
1760 let ih = oh / r;
1761 for ow in 0..out_w {
1762 let iw = ow / r;
1763 let src = ((b * h + ih) * w + iw) * c;
1764 let dst = ((b * out_h + oh) * out_w + ow) * c;
1765 out[dst..dst + c].copy_from_slice(&data[src..src + c]);
1766 }
1767 }
1768 }
1769 (
1770 Tensor::from_vec(vec![batch, out_h, out_w, c], out)?,
1771 self.nodes[input.0].requires_grad,
1772 )
1773 };
1774 Ok(self.push_node(
1775 value,
1776 requires_grad,
1777 Op::UpsampleNearest {
1778 input,
1779 scale_factor: scale_factor as u16,
1780 },
1781 ))
1782 }
1783
1784 pub fn rnn_forward(
1789 &mut self,
1790 input: NodeId,
1791 w_ih: NodeId,
1792 w_hh: NodeId,
1793 bias: NodeId,
1794 ) -> Result<NodeId, AutogradError> {
1795 let (value, requires_grad, hidden_states) = {
1796 let iv = &self.nodes[input.0].value;
1797 let wih = &self.nodes[w_ih.0].value;
1798 let whh = &self.nodes[w_hh.0].value;
1799 let bv = &self.nodes[bias.0].value;
1800 let shape = iv.shape();
1801 let seq_len = shape[0];
1802 let hidden_size = wih.shape()[1];
1803 let in_data = iv.data();
1804 let wih_data = wih.data();
1805 let whh_data = whh.data();
1806 let b_data = bv.data();
1807 let input_size = shape[1];
1808
1809 let mut hidden_states = Vec::with_capacity(seq_len + 1);
1810 hidden_states.push(Tensor::zeros(vec![hidden_size])?);
1812 let mut output_data = vec![0.0f32; seq_len * hidden_size];
1813
1814 for t in 0..seq_len {
1815 let h_prev = hidden_states[t].data();
1816 let x_base = t * input_size;
1817 let mut h_new = vec![0.0f32; hidden_size];
1818 for j in 0..hidden_size {
1819 let mut sum = b_data[j];
1820 for i in 0..input_size {
1821 sum += in_data[x_base + i] * wih_data[i * hidden_size + j];
1822 }
1823 for i in 0..hidden_size {
1824 sum += h_prev[i] * whh_data[i * hidden_size + j];
1825 }
1826 h_new[j] = sum.tanh();
1827 }
1828 output_data[t * hidden_size..(t + 1) * hidden_size].copy_from_slice(&h_new);
1829 hidden_states.push(Tensor::from_vec(vec![hidden_size], h_new)?);
1830 }
1831
1832 let rg = self.nodes[input.0].requires_grad
1833 || self.nodes[w_ih.0].requires_grad
1834 || self.nodes[w_hh.0].requires_grad
1835 || self.nodes[bias.0].requires_grad;
1836 (
1837 Tensor::from_vec(vec![seq_len, hidden_size], output_data)?,
1838 rg,
1839 hidden_states,
1840 )
1841 };
1842 Ok(self.push_node_with_aux(
1843 value,
1844 requires_grad,
1845 Op::Rnn {
1846 input,
1847 w_ih,
1848 w_hh,
1849 bias,
1850 },
1851 AuxData::RnnHiddenStates(hidden_states),
1852 ))
1853 }
1854
1855 pub fn lstm_forward(
1860 &mut self,
1861 input: NodeId,
1862 w_ih: NodeId,
1863 w_hh: NodeId,
1864 bias: NodeId,
1865 ) -> Result<NodeId, AutogradError> {
1866 let (value, requires_grad, hidden_states, cell_states, gates) = {
1867 let iv = &self.nodes[input.0].value;
1868 let wih = &self.nodes[w_ih.0].value;
1869 let whh = &self.nodes[w_hh.0].value;
1870 let bv = &self.nodes[bias.0].value;
1871 let shape = iv.shape();
1872 let seq_len = shape[0];
1873 let input_size = shape[1];
1874 let hidden_size = wih.shape()[1] / 4;
1875 let in_data = iv.data();
1876 let wih_data = wih.data();
1877 let whh_data = whh.data();
1878 let b_data = bv.data();
1879
1880 let mut hidden_states = Vec::with_capacity(seq_len + 1);
1881 let mut cell_states = Vec::with_capacity(seq_len + 1);
1882 let mut gates_vec: Vec<(Tensor, Tensor, Tensor, Tensor)> = Vec::with_capacity(seq_len);
1883 hidden_states.push(Tensor::zeros(vec![hidden_size])?);
1884 cell_states.push(Tensor::zeros(vec![hidden_size])?);
1885 let mut output_data = vec![0.0f32; seq_len * hidden_size];
1886
1887 for t in 0..seq_len {
1888 let h_prev = hidden_states[t].data();
1889 let c_prev = cell_states[t].data();
1890 let x_base = t * input_size;
1891 let h4 = 4 * hidden_size;
1892
1893 let mut raw_gates = vec![0.0f32; h4];
1895 for j in 0..h4 {
1896 let mut sum = b_data[j];
1897 for i in 0..input_size {
1898 sum += in_data[x_base + i] * wih_data[i * h4 + j];
1899 }
1900 for i in 0..hidden_size {
1901 sum += h_prev[i] * whh_data[i * h4 + j];
1902 }
1903 raw_gates[j] = sum;
1904 }
1905
1906 let mut i_gate = vec![0.0f32; hidden_size];
1907 let mut f_gate = vec![0.0f32; hidden_size];
1908 let mut g_gate = vec![0.0f32; hidden_size];
1909 let mut o_gate = vec![0.0f32; hidden_size];
1910 let mut c_new = vec![0.0f32; hidden_size];
1911 let mut h_new = vec![0.0f32; hidden_size];
1912
1913 for j in 0..hidden_size {
1914 i_gate[j] = sigmoid_f32(raw_gates[j]);
1915 f_gate[j] = sigmoid_f32(raw_gates[hidden_size + j]);
1916 g_gate[j] = raw_gates[2 * hidden_size + j].tanh();
1917 o_gate[j] = sigmoid_f32(raw_gates[3 * hidden_size + j]);
1918 c_new[j] = f_gate[j] * c_prev[j] + i_gate[j] * g_gate[j];
1919 h_new[j] = o_gate[j] * c_new[j].tanh();
1920 }
1921
1922 output_data[t * hidden_size..(t + 1) * hidden_size].copy_from_slice(&h_new);
1923 hidden_states.push(Tensor::from_vec(vec![hidden_size], h_new)?);
1924 cell_states.push(Tensor::from_vec(vec![hidden_size], c_new)?);
1925 gates_vec.push((
1926 Tensor::from_vec(vec![hidden_size], i_gate)?,
1927 Tensor::from_vec(vec![hidden_size], f_gate)?,
1928 Tensor::from_vec(vec![hidden_size], g_gate)?,
1929 Tensor::from_vec(vec![hidden_size], o_gate)?,
1930 ));
1931 }
1932
1933 let rg = self.nodes[input.0].requires_grad
1934 || self.nodes[w_ih.0].requires_grad
1935 || self.nodes[w_hh.0].requires_grad
1936 || self.nodes[bias.0].requires_grad;
1937 (
1938 Tensor::from_vec(vec![seq_len, hidden_size], output_data)?,
1939 rg,
1940 hidden_states,
1941 cell_states,
1942 gates_vec,
1943 )
1944 };
1945 Ok(self.push_node_with_aux(
1946 value,
1947 requires_grad,
1948 Op::Lstm {
1949 input,
1950 w_ih,
1951 w_hh,
1952 bias,
1953 },
1954 AuxData::LstmStates {
1955 hidden_states,
1956 cell_states,
1957 gates,
1958 },
1959 ))
1960 }
1961
1962 pub fn gru_forward(
1967 &mut self,
1968 input: NodeId,
1969 w_ih: NodeId,
1970 w_hh: NodeId,
1971 bias_ih: NodeId,
1972 bias_hh: NodeId,
1973 ) -> Result<NodeId, AutogradError> {
1974 let (value, requires_grad, hidden_states, gates) = {
1975 let iv = &self.nodes[input.0].value;
1976 let wih = &self.nodes[w_ih.0].value;
1977 let whh = &self.nodes[w_hh.0].value;
1978 let bih = &self.nodes[bias_ih.0].value;
1979 let bhh = &self.nodes[bias_hh.0].value;
1980 let shape = iv.shape();
1981 let seq_len = shape[0];
1982 let input_size = shape[1];
1983 let hidden_size = wih.shape()[1] / 3;
1984 let in_data = iv.data();
1985 let wih_data = wih.data();
1986 let whh_data = whh.data();
1987 let bih_data = bih.data();
1988 let bhh_data = bhh.data();
1989
1990 let mut hidden_states = Vec::with_capacity(seq_len + 1);
1991 let mut gates_vec: Vec<(Tensor, Tensor, Tensor)> = Vec::with_capacity(seq_len);
1992 hidden_states.push(Tensor::zeros(vec![hidden_size])?);
1993 let mut output_data = vec![0.0f32; seq_len * hidden_size];
1994
1995 for t in 0..seq_len {
1996 let h_prev = hidden_states[t].data();
1997 let x_base = t * input_size;
1998 let h3 = 3 * hidden_size;
1999
2000 let mut x_proj = vec![0.0f32; h3];
2002 for j in 0..h3 {
2003 let mut sum = bih_data[j];
2004 for i in 0..input_size {
2005 sum += in_data[x_base + i] * wih_data[i * h3 + j];
2006 }
2007 x_proj[j] = sum;
2008 }
2009 let mut h_proj = vec![0.0f32; h3];
2011 for j in 0..h3 {
2012 let mut sum = bhh_data[j];
2013 for i in 0..hidden_size {
2014 sum += h_prev[i] * whh_data[i * h3 + j];
2015 }
2016 h_proj[j] = sum;
2017 }
2018
2019 let mut r_gate = vec![0.0f32; hidden_size];
2020 let mut z_gate = vec![0.0f32; hidden_size];
2021 let mut n_candidate = vec![0.0f32; hidden_size];
2022 let mut h_new = vec![0.0f32; hidden_size];
2023
2024 for j in 0..hidden_size {
2025 r_gate[j] = sigmoid_f32(x_proj[j] + h_proj[j]);
2026 z_gate[j] = sigmoid_f32(x_proj[hidden_size + j] + h_proj[hidden_size + j]);
2027 n_candidate[j] = (x_proj[2 * hidden_size + j]
2028 + r_gate[j] * h_proj[2 * hidden_size + j])
2029 .tanh();
2030 h_new[j] = (1.0 - z_gate[j]) * n_candidate[j] + z_gate[j] * h_prev[j];
2031 }
2032
2033 output_data[t * hidden_size..(t + 1) * hidden_size].copy_from_slice(&h_new);
2034 hidden_states.push(Tensor::from_vec(vec![hidden_size], h_new)?);
2035 gates_vec.push((
2036 Tensor::from_vec(vec![hidden_size], r_gate)?,
2037 Tensor::from_vec(vec![hidden_size], z_gate)?,
2038 Tensor::from_vec(vec![hidden_size], n_candidate)?,
2039 ));
2040 }
2041
2042 let rg = self.nodes[input.0].requires_grad
2043 || self.nodes[w_ih.0].requires_grad
2044 || self.nodes[w_hh.0].requires_grad
2045 || self.nodes[bias_ih.0].requires_grad
2046 || self.nodes[bias_hh.0].requires_grad;
2047 (
2048 Tensor::from_vec(vec![seq_len, hidden_size], output_data)?,
2049 rg,
2050 hidden_states,
2051 gates_vec,
2052 )
2053 };
2054 Ok(self.push_node_with_aux(
2055 value,
2056 requires_grad,
2057 Op::Gru {
2058 input,
2059 w_ih,
2060 w_hh,
2061 bias_ih,
2062 bias_hh,
2063 },
2064 AuxData::GruStates {
2065 hidden_states,
2066 gates,
2067 },
2068 ))
2069 }
2070
2071 pub fn deformable_conv2d_nhwc(
2074 &mut self,
2075 input: NodeId,
2076 weight: NodeId,
2077 offsets: NodeId,
2078 bias: Option<NodeId>,
2079 stride: usize,
2080 padding: usize,
2081 ) -> Result<NodeId, AutogradError> {
2082 let (value, requires_grad) = {
2083 let iv = &self.nodes[input.0].value;
2084 let wv = &self.nodes[weight.0].value;
2085 let ov = &self.nodes[offsets.0].value;
2086 let bv: Option<&Tensor> = bias.map(|b| &self.nodes[b.0].value);
2087 let rg = self.nodes[input.0].requires_grad
2088 || self.nodes[weight.0].requires_grad
2089 || self.nodes[offsets.0].requires_grad
2090 || bias.is_some_and(|b| self.nodes[b.0].requires_grad);
2091 let result = yscv_kernels::deformable_conv2d_nhwc(iv, wv, ov, bv, stride, padding)?;
2092 (result, rg)
2093 };
2094 Ok(self.push_node(
2095 value,
2096 requires_grad,
2097 Op::DeformableConv2dNhwc {
2098 input,
2099 weight,
2100 offsets,
2101 bias,
2102 stride: stride as u16,
2103 padding: padding as u16,
2104 },
2105 ))
2106 }
2107
2108 pub(crate) fn push_node(&mut self, value: Tensor, requires_grad: bool, op: Op) -> NodeId {
2109 let id = NodeId(self.nodes.len());
2110 self.nodes.push(Node {
2111 value,
2112 grad: None,
2113 requires_grad,
2114 op,
2115 aux: None,
2116 });
2117 id
2118 }
2119
2120 pub(crate) fn push_node_with_aux(
2121 &mut self,
2122 value: Tensor,
2123 requires_grad: bool,
2124 op: Op,
2125 aux: super::node::AuxData,
2126 ) -> NodeId {
2127 let id = NodeId(self.nodes.len());
2128 self.nodes.push(Node {
2129 value,
2130 grad: None,
2131 requires_grad,
2132 op,
2133 aux: Some(aux),
2134 });
2135 id
2136 }
2137
2138 pub(crate) fn node(&self, id: NodeId) -> Result<&Node, AutogradError> {
2139 self.nodes
2140 .get(id.0)
2141 .ok_or(AutogradError::NodeNotFound { id: id.0 })
2142 }
2143
2144 pub(crate) fn node_mut(&mut self, id: NodeId) -> Result<&mut Node, AutogradError> {
2145 self.nodes
2146 .get_mut(id.0)
2147 .ok_or(AutogradError::NodeNotFound { id: id.0 })
2148 }
2149
2150 pub fn clip_grad_norm(&mut self, param_nodes: &[NodeId], max_norm: f32) -> f32 {
2153 let mut total_norm_sq = 0.0f32;
2154 for &node_id in param_nodes {
2155 if let Some(grad) = &self.nodes[node_id.0].grad {
2156 for &v in grad.data() {
2157 total_norm_sq += v * v;
2158 }
2159 }
2160 }
2161 let total_norm = total_norm_sq.sqrt();
2162
2163 if total_norm > max_norm {
2164 let scale = max_norm / total_norm;
2165 for &node_id in param_nodes {
2166 if let Some(grad) = &mut self.nodes[node_id.0].grad {
2167 for v in grad.data_mut() {
2168 *v *= scale;
2169 }
2170 }
2171 }
2172 }
2173 total_norm
2174 }
2175
2176 pub fn clip_grad_value(&mut self, param_nodes: &[NodeId], max_value: f32) {
2178 for &node_id in param_nodes {
2179 if let Some(grad) = &mut self.nodes[node_id.0].grad {
2180 for v in grad.data_mut() {
2181 *v = v.clamp(-max_value, max_value);
2182 }
2183 }
2184 }
2185 }
2186}
2187
2188fn softmax_last_dim(input: &Tensor) -> Tensor {
2190 let shape = input.shape();
2191 if shape.is_empty() {
2192 return input.clone();
2193 }
2194 let last = *shape.last().expect("non-empty shape");
2195 let outer = input.len() / last;
2196 let data = input.data();
2197 let mut out = vec![0.0f32; input.len()];
2198
2199 for o in 0..outer {
2200 let base = o * last;
2201 let max_val = data[base..base + last]
2202 .iter()
2203 .copied()
2204 .fold(f32::NEG_INFINITY, f32::max);
2205 let mut sum = 0.0f32;
2206 for i in 0..last {
2207 let e = (data[base + i] - max_val).exp();
2208 out[base + i] = e;
2209 sum += e;
2210 }
2211 let inv = 1.0 / sum;
2212 for i in 0..last {
2213 out[base + i] *= inv;
2214 }
2215 }
2216
2217 Tensor::from_vec(shape.to_vec(), out).expect("softmax_last_dim preserves shape")
2218}
2219
2220#[inline]
2221fn sigmoid_f32(x: f32) -> f32 {
2222 1.0 / (1.0 + (-x).exp())
2223}
2224
2225#[allow(clippy::needless_range_loop)]
2227fn copy_region_nd(
2228 src: &[f32],
2229 src_shape: &[usize],
2230 dst: &mut [f32],
2231 dst_shape: &[usize],
2232 offsets: &[u32],
2233) {
2234 let rank = src_shape.len();
2235 if rank == 0 {
2236 return;
2237 }
2238 let total: usize = src_shape.iter().product();
2239 let mut src_strides = vec![1usize; rank];
2240 let mut dst_strides = vec![1usize; rank];
2241 for d in (0..rank - 1).rev() {
2242 src_strides[d] = src_strides[d + 1] * src_shape[d + 1];
2243 dst_strides[d] = dst_strides[d + 1] * dst_shape[d + 1];
2244 }
2245 for flat in 0..total {
2246 let mut rem = flat;
2247 let mut dst_flat = 0usize;
2248 for d in 0..rank {
2249 let coord = rem / src_strides[d];
2250 rem %= src_strides[d];
2251 dst_flat += (coord + offsets[d] as usize) * dst_strides[d];
2252 }
2253 dst[dst_flat] = src[flat];
2254 }
2255}