Skip to main content

shrew_core/
op.rs

1// Op — Computational graph node for automatic differentiation
2//
3// Every tensor that results from a computation records HOW it was created
4// via the Op enum. This forms a directed acyclic graph (DAG) that backward()
5// traverses to compute gradients.
6//
7// Example: c = a + b
8//   a.op = Op::None (leaf variable)
9//   b.op = Op::None (leaf variable)
10//   c.op = Op::Binary { lhs: a, rhs: b, op: Add }
11//
12// When we call c.backward():
13//   1. Start with grad_c = 1.0 (by convention, dL/dL = 1)
14//   2. Look at c.op → Binary { lhs: a, rhs: b, Add }
15//   3. grad_a += grad_c * d(a+b)/da = grad_c * 1 = grad_c
16//   4. grad_b += grad_c * d(a+b)/db = grad_c * 1 = grad_c
17//
18// WHY STORE Tensor<B> INSTEAD OF TensorId?
19//
20// In Phase 1, Op stored only TensorIds. Now in Phase 2, each Op variant stores
21// the actual Tensor<B> references to its inputs. Since Tensor<B> is Arc-wrapped,
22// cloning is cheap (just increment refcount). This means:
23//
24//   1. backward() can directly access input values for gradient computation
25//      (e.g., d(a*b)/da = b — we need the actual value of b)
26//   2. The computation graph keeps input tensors alive as long as the output
27//      tensor exists (correct: we need them for backward)
28//   3. No separate tensor registry needed — the graph IS the references
29//
30// MEMORY: The graph forms a DAG (no cycles), so Arc handles cleanup correctly.
31// When the loss tensor is dropped, all intermediate tensors' refcounts decrease.
32
33use crate::backend::{Backend, BinaryOp, ReduceOp, UnaryOp};
34
/// Opaque identifier of a tensor, used as the key type in `GradStore`.
///
/// IDs are handed out from one process-wide counter, so every call to
/// [`TensorId::new`] (or [`TensorId::default`]) yields a value no other
/// tensor in the process has received.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId(pub(crate) u64);

impl TensorId {
    /// Allocates a fresh ID from a global atomic counter.
    ///
    /// `Relaxed` ordering suffices: uniqueness only requires that each
    /// `fetch_add` observes a distinct counter value, which atomicity alone
    /// guarantees. No other memory is synchronized through this counter.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self(raw)
    }
}

impl Default for TensorId {
    /// Equivalent to [`TensorId::new`]. Unlike most `Default` impls this does
    /// not return a fixed value: every call produces a new, distinct ID.
    fn default() -> Self {
        Self::new()
    }
}
53
54/// Records the operation that produced a tensor, storing references to inputs.
55///
56/// Each variant holds the actual input Tensor(s) (Arc-wrapped, cheap to clone)
57/// plus the operation parameters. backward() uses these to compute gradients
58/// via the chain rule.
59///
60/// Op<B> is generic over the Backend because it stores Tensor<B>.
61pub enum Op<B: Backend> {
62    /// No operation — this is a leaf tensor (input data or trainable parameter).
63    None,
64
65    /// Element-wise binary: result = op(lhs, rhs)
66    Binary {
67        lhs: crate::Tensor<B>,
68        rhs: crate::Tensor<B>,
69        op: BinaryOp,
70    },
71
72    /// Element-wise unary: result = op(input)
73    Unary {
74        input: crate::Tensor<B>,
75        op: UnaryOp,
76    },
77
78    /// Reduction: result = reduce(input, dims)
79    Reduce {
80        input: crate::Tensor<B>,
81        op: ReduceOp,
82        dims: Vec<usize>,
83        keep_dim: bool,
84    },
85
86    /// Matrix multiplication: result = lhs @ rhs
87    Matmul {
88        lhs: crate::Tensor<B>,
89        rhs: crate::Tensor<B>,
90    },
91
92    /// Reshape (includes squeeze/unsqueeze): same data, different shape.
93    /// src_shape records the original shape so backward can reshape gradients back.
94    Reshape {
95        input: crate::Tensor<B>,
96        src_shape: crate::Shape,
97    },
98
99    /// Transpose: swap two dimensions
100    Transpose {
101        input: crate::Tensor<B>,
102        dim0: usize,
103        dim1: usize,
104    },
105
106    /// Narrow/slice along a dimension
107    Narrow {
108        input: crate::Tensor<B>,
109        dim: usize,
110        start: usize,
111        len: usize,
112    },
113
114    /// Affine transform: result = input * mul + add
115    Affine {
116        input: crate::Tensor<B>,
117        mul: f64,
118        add: f64,
119    },
120
121    /// Contiguous copy: same logical values, but data is now contiguous in memory.
122    /// Gradient passes through unchanged.
123    Contiguous { input: crate::Tensor<B> },
124
125    /// 2D convolution: result = conv2d(input, weight) + bias
126    /// input: [N, C_in, H, W], weight: [C_out, C_in, kH, kW]
127    Conv2d {
128        input: crate::Tensor<B>,
129        weight: crate::Tensor<B>,
130        bias: Option<crate::Tensor<B>>,
131        stride: [usize; 2],
132        padding: [usize; 2],
133    },
134
135    /// 2D max-pooling.
136    /// input: [N, C, H, W]
137    /// indices stores the argmax positions for backward.
138    MaxPool2d {
139        input: crate::Tensor<B>,
140        kernel_size: [usize; 2],
141        stride: [usize; 2],
142        padding: [usize; 2],
143        indices: Vec<usize>,
144    },
145
146    /// Concatenation along a dimension.
147    /// `inputs` are the original tensors that were concatenated.
148    /// `dim` is the concatenation dimension.
149    /// `sizes` stores the size of each input along `dim` (needed by backward
150    /// to slice the gradient back into per-input pieces via narrow).
151    Cat {
152        inputs: Vec<crate::Tensor<B>>,
153        dim: usize,
154        sizes: Vec<usize>,
155    },
156
157    /// Element-wise power: result = input ^ exponent.
158    Powf {
159        input: crate::Tensor<B>,
160        exponent: f64,
161    },
162
163    /// Element-wise clamp: result = clamp(input, min, max).
164    Clamp {
165        input: crate::Tensor<B>,
166        min: f64,
167        max: f64,
168    },
169
170    /// Conditional select: result[i] = if mask[i] { on_true[i] } else { on_false[i] }.
171    WhereCond {
172        mask: crate::Tensor<B>,
173        on_true: crate::Tensor<B>,
174        on_false: crate::Tensor<B>,
175    },
176
177    /// Gather elements along a dimension using index tensor.
178    Gather {
179        input: crate::Tensor<B>,
180        index: crate::Tensor<B>,
181        dim: usize,
182    },
183
184    /// Constant padding.
185    Pad {
186        input: crate::Tensor<B>,
187        padding: Vec<[usize; 2]>,
188    },
189
190    /// 2D average-pooling.
191    /// input: [N, C, H, W]
192    AvgPool2d {
193        input: crate::Tensor<B>,
194        kernel_size: [usize; 2],
195        stride: [usize; 2],
196        padding: [usize; 2],
197    },
198
199    /// 1D convolution: result = conv1d(input, weight) + bias
200    /// input: [N, C_in, L], weight: [C_out, C_in, K]
201    Conv1d {
202        input: crate::Tensor<B>,
203        weight: crate::Tensor<B>,
204        bias: Option<crate::Tensor<B>>,
205        stride: usize,
206        padding: usize,
207    },
208
209    /// Index select along a dimension: result = input.index_select(dim, indices)
210    /// Backward = scatter-add of grad_output into grad_input at index positions.
211    IndexSelect {
212        input: crate::Tensor<B>,
213        indices: crate::Tensor<B>,
214        dim: usize,
215    },
216
217    /// Dtype conversion: result = input.to_dtype(target_dtype)
218    /// Backward casts gradient back to the original dtype.
219    ToDtype {
220        input: crate::Tensor<B>,
221        src_dtype: crate::dtype::DType,
222    },
223}
224
225// Manual Clone implementation because derive can't handle the generic well.
226// All clones are cheap: Tensor clone is just Arc refcount increment.
227impl<B: Backend> Clone for Op<B> {
228    fn clone(&self) -> Self {
229        match self {
230            Op::None => Op::None,
231            Op::Binary { lhs, rhs, op } => Op::Binary {
232                lhs: lhs.clone(),
233                rhs: rhs.clone(),
234                op: *op,
235            },
236            Op::Unary { input, op } => Op::Unary {
237                input: input.clone(),
238                op: *op,
239            },
240            Op::Reduce {
241                input,
242                op,
243                dims,
244                keep_dim,
245            } => Op::Reduce {
246                input: input.clone(),
247                op: *op,
248                dims: dims.clone(),
249                keep_dim: *keep_dim,
250            },
251            Op::Matmul { lhs, rhs } => Op::Matmul {
252                lhs: lhs.clone(),
253                rhs: rhs.clone(),
254            },
255            Op::Reshape { input, src_shape } => Op::Reshape {
256                input: input.clone(),
257                src_shape: src_shape.clone(),
258            },
259            Op::Transpose { input, dim0, dim1 } => Op::Transpose {
260                input: input.clone(),
261                dim0: *dim0,
262                dim1: *dim1,
263            },
264            Op::Narrow {
265                input,
266                dim,
267                start,
268                len,
269            } => Op::Narrow {
270                input: input.clone(),
271                dim: *dim,
272                start: *start,
273                len: *len,
274            },
275            Op::Affine { input, mul, add } => Op::Affine {
276                input: input.clone(),
277                mul: *mul,
278                add: *add,
279            },
280            Op::Contiguous { input } => Op::Contiguous {
281                input: input.clone(),
282            },
283            Op::Conv2d {
284                input,
285                weight,
286                bias,
287                stride,
288                padding,
289            } => Op::Conv2d {
290                input: input.clone(),
291                weight: weight.clone(),
292                bias: bias.clone(),
293                stride: *stride,
294                padding: *padding,
295            },
296            Op::MaxPool2d {
297                input,
298                kernel_size,
299                stride,
300                padding,
301                indices,
302            } => Op::MaxPool2d {
303                input: input.clone(),
304                kernel_size: *kernel_size,
305                stride: *stride,
306                padding: *padding,
307                indices: indices.clone(),
308            },
309            Op::Cat { inputs, dim, sizes } => Op::Cat {
310                inputs: inputs.clone(),
311                dim: *dim,
312                sizes: sizes.clone(),
313            },
314            Op::Powf { input, exponent } => Op::Powf {
315                input: input.clone(),
316                exponent: *exponent,
317            },
318            Op::Clamp { input, min, max } => Op::Clamp {
319                input: input.clone(),
320                min: *min,
321                max: *max,
322            },
323            Op::WhereCond {
324                mask,
325                on_true,
326                on_false,
327            } => Op::WhereCond {
328                mask: mask.clone(),
329                on_true: on_true.clone(),
330                on_false: on_false.clone(),
331            },
332            Op::Gather { input, index, dim } => Op::Gather {
333                input: input.clone(),
334                index: index.clone(),
335                dim: *dim,
336            },
337            Op::Pad { input, padding } => Op::Pad {
338                input: input.clone(),
339                padding: padding.clone(),
340            },
341            Op::AvgPool2d {
342                input,
343                kernel_size,
344                stride,
345                padding,
346            } => Op::AvgPool2d {
347                input: input.clone(),
348                kernel_size: *kernel_size,
349                stride: *stride,
350                padding: *padding,
351            },
352            Op::Conv1d {
353                input,
354                weight,
355                bias,
356                stride,
357                padding,
358            } => Op::Conv1d {
359                input: input.clone(),
360                weight: weight.clone(),
361                bias: bias.clone(),
362                stride: *stride,
363                padding: *padding,
364            },
365            Op::IndexSelect {
366                input,
367                indices,
368                dim,
369            } => Op::IndexSelect {
370                input: input.clone(),
371                indices: indices.clone(),
372                dim: *dim,
373            },
374            Op::ToDtype { input, src_dtype } => Op::ToDtype {
375                input: input.clone(),
376                src_dtype: *src_dtype,
377            },
378        }
379    }
380}
381
382// Concise Debug: show op type and tensor IDs only (not full tensor data).
383impl<B: Backend> std::fmt::Debug for Op<B> {
384    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385        match self {
386            Op::None => write!(f, "None"),
387            Op::Binary { lhs, rhs, op } => {
388                write!(f, "Binary({:?}, id={:?}, id={:?})", op, lhs.id(), rhs.id())
389            }
390            Op::Unary { input, op } => {
391                write!(f, "Unary({:?}, id={:?})", op, input.id())
392            }
393            Op::Reduce {
394                input, op, dims, ..
395            } => {
396                write!(f, "Reduce({:?}, dims={:?}, id={:?})", op, dims, input.id())
397            }
398            Op::Matmul { lhs, rhs } => {
399                write!(f, "Matmul(id={:?}, id={:?})", lhs.id(), rhs.id())
400            }
401            Op::Reshape { input, src_shape } => {
402                write!(f, "Reshape({} → ?, id={:?})", src_shape, input.id())
403            }
404            Op::Transpose { input, dim0, dim1 } => {
405                write!(f, "Transpose({}, {}, id={:?})", dim0, dim1, input.id())
406            }
407            Op::Narrow {
408                input,
409                dim,
410                start,
411                len,
412            } => {
413                write!(
414                    f,
415                    "Narrow(dim={}, {}..{}, id={:?})",
416                    dim,
417                    start,
418                    start + len,
419                    input.id()
420                )
421            }
422            Op::Affine { input, mul, add } => {
423                write!(f, "Affine(*{} +{}, id={:?})", mul, add, input.id())
424            }
425            Op::Contiguous { input } => {
426                write!(f, "Contiguous(id={:?})", input.id())
427            }
428            Op::Conv2d {
429                input,
430                weight,
431                bias,
432                stride,
433                padding,
434            } => {
435                write!(
436                    f,
437                    "Conv2d(in={:?}, w={:?}, bias={}, s={:?}, p={:?})",
438                    input.id(),
439                    weight.id(),
440                    bias.is_some(),
441                    stride,
442                    padding
443                )
444            }
445            Op::MaxPool2d {
446                input,
447                kernel_size,
448                stride,
449                padding,
450                ..
451            } => {
452                write!(
453                    f,
454                    "MaxPool2d(in={:?}, k={:?}, s={:?}, p={:?})",
455                    input.id(),
456                    kernel_size,
457                    stride,
458                    padding
459                )
460            }
461            Op::Cat { inputs, dim, .. } => {
462                let ids: Vec<_> = inputs.iter().map(|t| t.id()).collect();
463                write!(f, "Cat(dim={}, ids={:?})", dim, ids)
464            }
465            Op::Powf { input, exponent } => {
466                write!(f, "Powf(exp={}, id={:?})", exponent, input.id())
467            }
468            Op::Clamp { input, min, max } => {
469                write!(f, "Clamp(min={}, max={}, id={:?})", min, max, input.id())
470            }
471            Op::WhereCond {
472                mask,
473                on_true,
474                on_false,
475            } => {
476                write!(
477                    f,
478                    "WhereCond(mask={:?}, true={:?}, false={:?})",
479                    mask.id(),
480                    on_true.id(),
481                    on_false.id()
482                )
483            }
484            Op::Gather { input, index, dim } => {
485                write!(
486                    f,
487                    "Gather(dim={}, input={:?}, index={:?})",
488                    dim,
489                    input.id(),
490                    index.id()
491                )
492            }
493            Op::Pad { input, padding } => {
494                write!(f, "Pad(pad={:?}, id={:?})", padding, input.id())
495            }
496            Op::AvgPool2d {
497                input,
498                kernel_size,
499                stride,
500                padding,
501                ..
502            } => {
503                write!(
504                    f,
505                    "AvgPool2d(in={:?}, k={:?}, s={:?}, p={:?})",
506                    input.id(),
507                    kernel_size,
508                    stride,
509                    padding
510                )
511            }
512            Op::Conv1d {
513                input,
514                weight,
515                bias,
516                stride,
517                padding,
518            } => {
519                write!(
520                    f,
521                    "Conv1d(in={:?}, w={:?}, bias={}, s={}, p={})",
522                    input.id(),
523                    weight.id(),
524                    bias.is_some(),
525                    stride,
526                    padding
527                )
528            }
529            Op::IndexSelect {
530                input,
531                indices,
532                dim,
533            } => {
534                write!(
535                    f,
536                    "IndexSelect(dim={}, input={:?}, indices={:?})",
537                    dim,
538                    input.id(),
539                    indices.id()
540                )
541            }
542            Op::ToDtype { input, src_dtype } => {
543                write!(f, "ToDtype(from={:?}, id={:?})", src_dtype, input.id())
544            }
545        }
546    }
547}
548
549impl<B: Backend> Op<B> {
550    /// Return references to all input tensors of this operation.
551    /// Used by topological sort in backward() to traverse the graph.
552    pub fn inputs(&self) -> Vec<&crate::Tensor<B>> {
553        match self {
554            Op::None => vec![],
555            Op::Binary { lhs, rhs, .. } | Op::Matmul { lhs, rhs } => vec![lhs, rhs],
556            Op::Unary { input, .. }
557            | Op::Reduce { input, .. }
558            | Op::Reshape { input, .. }
559            | Op::Transpose { input, .. }
560            | Op::Narrow { input, .. }
561            | Op::Affine { input, .. }
562            | Op::Contiguous { input }
563            | Op::MaxPool2d { input, .. }
564            | Op::AvgPool2d { input, .. }
565            | Op::Powf { input, .. }
566            | Op::Clamp { input, .. } => vec![input],
567            Op::Conv2d {
568                input,
569                weight,
570                bias,
571                ..
572            } => {
573                let mut v = vec![input, weight];
574                if let Some(b) = bias {
575                    v.push(b);
576                }
577                v
578            }
579            Op::Conv1d {
580                input,
581                weight,
582                bias,
583                ..
584            } => {
585                let mut v = vec![input, weight];
586                if let Some(b) = bias {
587                    v.push(b);
588                }
589                v
590            }
591            Op::Cat { inputs, .. } => inputs.iter().collect(),
592            Op::WhereCond {
593                mask,
594                on_true,
595                on_false,
596            } => {
597                vec![mask, on_true, on_false]
598            }
599            Op::Gather { input, index, .. } => vec![input, index],
600            Op::IndexSelect { input, indices, .. } => vec![input, indices],
601            Op::ToDtype { input, .. } => vec![input],
602            Op::Pad { input, .. } => vec![input],
603        }
604    }
605}