1use crate::backend::{Backend, BinaryOp, ReduceOp, UnaryOp};
34
/// Unique identifier assigned to every tensor in the computation graph.
///
/// Ids are handed out from a process-wide atomic counter (see
/// [`TensorId::new`]), so two distinct tensors never share an id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId(pub(crate) u64);
38
/// `Default` mints a *fresh* unique id rather than a fixed value:
/// `TensorId::default()` behaves exactly like `TensorId::new()`.
impl Default for TensorId {
    fn default() -> Self {
        Self::new()
    }
}
44
45impl TensorId {
46 pub fn new() -> Self {
48 use std::sync::atomic::{AtomicU64, Ordering};
49 static COUNTER: AtomicU64 = AtomicU64::new(0);
50 TensorId(COUNTER.fetch_add(1, Ordering::Relaxed))
51 }
52}
53
/// The operation that produced a tensor, recorded so the computation
/// graph can be traversed later (e.g. for backpropagation). Each variant
/// stores the input tensor(s) plus the parameters needed to revisit the op.
pub enum Op<B: Backend> {
    /// Leaf tensor: not the result of any recorded operation.
    None,

    /// Elementwise binary operation `op` applied to `lhs` and `rhs`.
    Binary {
        lhs: crate::Tensor<B>,
        rhs: crate::Tensor<B>,
        op: BinaryOp,
    },

    /// Elementwise unary operation `op` applied to `input`.
    Unary {
        input: crate::Tensor<B>,
        op: UnaryOp,
    },

    /// Reduction `op` over the dimensions listed in `dims`.
    /// `keep_dim` mirrors the usual keep-dims flag — presumably reduced
    /// axes are retained as size 1 when true; confirm against the backend.
    Reduce {
        input: crate::Tensor<B>,
        op: ReduceOp,
        dims: Vec<usize>,
        keep_dim: bool,
    },

    /// Matrix multiplication `lhs @ rhs`.
    Matmul {
        lhs: crate::Tensor<B>,
        rhs: crate::Tensor<B>,
    },

    /// Reshape; `src_shape` is the shape of `input` *before* the reshape,
    /// kept so the gradient can be reshaped back.
    Reshape {
        input: crate::Tensor<B>,
        src_shape: crate::Shape,
    },

    /// Swap of dimensions `dim0` and `dim1`.
    Transpose {
        input: crate::Tensor<B>,
        dim0: usize,
        dim1: usize,
    },

    /// Slice of length `len` starting at `start` along dimension `dim`.
    Narrow {
        input: crate::Tensor<B>,
        dim: usize,
        start: usize,
        len: usize,
    },

    /// Elementwise affine map: `input * mul + add`.
    Affine {
        input: crate::Tensor<B>,
        mul: f64,
        add: f64,
    },

    /// Materialization of `input` into contiguous memory layout.
    Contiguous { input: crate::Tensor<B> },

    /// 2-D convolution of `input` with `weight` (optional `bias`),
    /// with per-axis `stride` and `padding`.
    Conv2d {
        input: crate::Tensor<B>,
        weight: crate::Tensor<B>,
        bias: Option<crate::Tensor<B>>,
        stride: [usize; 2],
        padding: [usize; 2],
    },

    /// 2-D max pooling. `indices` are recorded by the forward pass —
    /// presumably the argmax positions needed to route gradients in the
    /// backward pass; confirm against the backend implementation.
    MaxPool2d {
        input: crate::Tensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        indices: Vec<usize>,
    },

    /// Concatenation of `inputs` along `dim`. `sizes` appears to hold the
    /// per-input extents along `dim` (to split gradients back) — confirm.
    Cat {
        inputs: Vec<crate::Tensor<B>>,
        dim: usize,
        sizes: Vec<usize>,
    },

    /// Elementwise power: `input ^ exponent`.
    Powf {
        input: crate::Tensor<B>,
        exponent: f64,
    },

    /// Elementwise clamp of `input` into `[min, max]`.
    Clamp {
        input: crate::Tensor<B>,
        min: f64,
        max: f64,
    },

    /// Elementwise select: values from `on_true` where `mask` holds,
    /// `on_false` otherwise.
    WhereCond {
        mask: crate::Tensor<B>,
        on_true: crate::Tensor<B>,
        on_false: crate::Tensor<B>,
    },

    /// Gather of `input` values along `dim` at positions given by `index`.
    Gather {
        input: crate::Tensor<B>,
        index: crate::Tensor<B>,
        dim: usize,
    },

    /// Padding of `input`; one `[usize; 2]` pair per dimension —
    /// presumably `[before, after]` amounts; confirm against the backend.
    Pad {
        input: crate::Tensor<B>,
        padding: Vec<[usize; 2]>,
    },

    /// 2-D average pooling with the given kernel, stride and padding.
    AvgPool2d {
        input: crate::Tensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    },

    /// 1-D convolution of `input` with `weight` (optional `bias`).
    Conv1d {
        input: crate::Tensor<B>,
        weight: crate::Tensor<B>,
        bias: Option<crate::Tensor<B>>,
        stride: usize,
        padding: usize,
    },

    /// Row selection along `dim` using the index tensor `indices`.
    IndexSelect {
        input: crate::Tensor<B>,
        indices: crate::Tensor<B>,
        dim: usize,
    },

    /// Dtype cast; `src_dtype` is the dtype of `input` *before* the cast,
    /// kept so the gradient can be cast back.
    ToDtype {
        input: crate::Tensor<B>,
        src_dtype: crate::dtype::DType,
    },
}
224
// Hand-written `Clone`: `#[derive(Clone)]` would add a `B: Clone` bound on
// the backend type parameter, which is unnecessary — only the tensors and
// parameters stored in each variant need cloning. Each arm is a plain
// field-by-field clone (`Copy` fields are dereferenced, owned fields cloned).
impl<B: Backend> Clone for Op<B> {
    fn clone(&self) -> Self {
        match self {
            Op::None => Op::None,
            Op::Binary { lhs, rhs, op } => Op::Binary {
                lhs: lhs.clone(),
                rhs: rhs.clone(),
                op: *op,
            },
            Op::Unary { input, op } => Op::Unary {
                input: input.clone(),
                op: *op,
            },
            Op::Reduce {
                input,
                op,
                dims,
                keep_dim,
            } => Op::Reduce {
                input: input.clone(),
                op: *op,
                dims: dims.clone(),
                keep_dim: *keep_dim,
            },
            Op::Matmul { lhs, rhs } => Op::Matmul {
                lhs: lhs.clone(),
                rhs: rhs.clone(),
            },
            Op::Reshape { input, src_shape } => Op::Reshape {
                input: input.clone(),
                src_shape: src_shape.clone(),
            },
            Op::Transpose { input, dim0, dim1 } => Op::Transpose {
                input: input.clone(),
                dim0: *dim0,
                dim1: *dim1,
            },
            Op::Narrow {
                input,
                dim,
                start,
                len,
            } => Op::Narrow {
                input: input.clone(),
                dim: *dim,
                start: *start,
                len: *len,
            },
            Op::Affine { input, mul, add } => Op::Affine {
                input: input.clone(),
                mul: *mul,
                add: *add,
            },
            Op::Contiguous { input } => Op::Contiguous {
                input: input.clone(),
            },
            Op::Conv2d {
                input,
                weight,
                bias,
                stride,
                padding,
            } => Op::Conv2d {
                input: input.clone(),
                weight: weight.clone(),
                bias: bias.clone(),
                stride: *stride,
                padding: *padding,
            },
            Op::MaxPool2d {
                input,
                kernel_size,
                stride,
                padding,
                indices,
            } => Op::MaxPool2d {
                input: input.clone(),
                kernel_size: *kernel_size,
                stride: *stride,
                padding: *padding,
                indices: indices.clone(),
            },
            Op::Cat { inputs, dim, sizes } => Op::Cat {
                inputs: inputs.clone(),
                dim: *dim,
                sizes: sizes.clone(),
            },
            Op::Powf { input, exponent } => Op::Powf {
                input: input.clone(),
                exponent: *exponent,
            },
            Op::Clamp { input, min, max } => Op::Clamp {
                input: input.clone(),
                min: *min,
                max: *max,
            },
            Op::WhereCond {
                mask,
                on_true,
                on_false,
            } => Op::WhereCond {
                mask: mask.clone(),
                on_true: on_true.clone(),
                on_false: on_false.clone(),
            },
            Op::Gather { input, index, dim } => Op::Gather {
                input: input.clone(),
                index: index.clone(),
                dim: *dim,
            },
            Op::Pad { input, padding } => Op::Pad {
                input: input.clone(),
                padding: padding.clone(),
            },
            Op::AvgPool2d {
                input,
                kernel_size,
                stride,
                padding,
            } => Op::AvgPool2d {
                input: input.clone(),
                kernel_size: *kernel_size,
                stride: *stride,
                padding: *padding,
            },
            Op::Conv1d {
                input,
                weight,
                bias,
                stride,
                padding,
            } => Op::Conv1d {
                input: input.clone(),
                weight: weight.clone(),
                bias: bias.clone(),
                stride: *stride,
                padding: *padding,
            },
            Op::IndexSelect {
                input,
                indices,
                dim,
            } => Op::IndexSelect {
                input: input.clone(),
                indices: indices.clone(),
                dim: *dim,
            },
            Op::ToDtype { input, src_dtype } => Op::ToDtype {
                input: input.clone(),
                src_dtype: *src_dtype,
            },
        }
    }
}
381
// Hand-written `Debug`: a derived impl would require `B: Debug` and would
// recurse into the input tensors. Instead, each arm prints the op's
// parameters plus the `TensorId`s of its inputs (via `.id()`), which keeps
// graph dumps compact while still identifying the operands.
impl<B: Backend> std::fmt::Debug for Op<B> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Op::None => write!(f, "None"),
            Op::Binary { lhs, rhs, op } => {
                write!(f, "Binary({:?}, id={:?}, id={:?})", op, lhs.id(), rhs.id())
            }
            Op::Unary { input, op } => {
                write!(f, "Unary({:?}, id={:?})", op, input.id())
            }
            // `keep_dim` is intentionally omitted from the output.
            Op::Reduce {
                input, op, dims, ..
            } => {
                write!(f, "Reduce({:?}, dims={:?}, id={:?})", op, dims, input.id())
            }
            Op::Matmul { lhs, rhs } => {
                write!(f, "Matmul(id={:?}, id={:?})", lhs.id(), rhs.id())
            }
            // Only the source shape is recorded, hence the "?" target.
            Op::Reshape { input, src_shape } => {
                write!(f, "Reshape({} → ?, id={:?})", src_shape, input.id())
            }
            Op::Transpose { input, dim0, dim1 } => {
                write!(f, "Transpose({}, {}, id={:?})", dim0, dim1, input.id())
            }
            Op::Narrow {
                input,
                dim,
                start,
                len,
            } => {
                write!(
                    f,
                    "Narrow(dim={}, {}..{}, id={:?})",
                    dim,
                    start,
                    start + len,
                    input.id()
                )
            }
            Op::Affine { input, mul, add } => {
                write!(f, "Affine(*{} +{}, id={:?})", mul, add, input.id())
            }
            Op::Contiguous { input } => {
                write!(f, "Contiguous(id={:?})", input.id())
            }
            Op::Conv2d {
                input,
                weight,
                bias,
                stride,
                padding,
            } => {
                write!(
                    f,
                    "Conv2d(in={:?}, w={:?}, bias={}, s={:?}, p={:?})",
                    input.id(),
                    weight.id(),
                    bias.is_some(),
                    stride,
                    padding
                )
            }
            // The recorded pooling `indices` are omitted (potentially large).
            Op::MaxPool2d {
                input,
                kernel_size,
                stride,
                padding,
                ..
            } => {
                write!(
                    f,
                    "MaxPool2d(in={:?}, k={:?}, s={:?}, p={:?})",
                    input.id(),
                    kernel_size,
                    stride,
                    padding
                )
            }
            Op::Cat { inputs, dim, .. } => {
                let ids: Vec<_> = inputs.iter().map(|t| t.id()).collect();
                write!(f, "Cat(dim={}, ids={:?})", dim, ids)
            }
            Op::Powf { input, exponent } => {
                write!(f, "Powf(exp={}, id={:?})", exponent, input.id())
            }
            Op::Clamp { input, min, max } => {
                write!(f, "Clamp(min={}, max={}, id={:?})", min, max, input.id())
            }
            Op::WhereCond {
                mask,
                on_true,
                on_false,
            } => {
                write!(
                    f,
                    "WhereCond(mask={:?}, true={:?}, false={:?})",
                    mask.id(),
                    on_true.id(),
                    on_false.id()
                )
            }
            Op::Gather { input, index, dim } => {
                write!(
                    f,
                    "Gather(dim={}, input={:?}, index={:?})",
                    dim,
                    input.id(),
                    index.id()
                )
            }
            Op::Pad { input, padding } => {
                write!(f, "Pad(pad={:?}, id={:?})", padding, input.id())
            }
            Op::AvgPool2d {
                input,
                kernel_size,
                stride,
                padding,
                ..
            } => {
                write!(
                    f,
                    "AvgPool2d(in={:?}, k={:?}, s={:?}, p={:?})",
                    input.id(),
                    kernel_size,
                    stride,
                    padding
                )
            }
            Op::Conv1d {
                input,
                weight,
                bias,
                stride,
                padding,
            } => {
                write!(
                    f,
                    "Conv1d(in={:?}, w={:?}, bias={}, s={}, p={})",
                    input.id(),
                    weight.id(),
                    bias.is_some(),
                    stride,
                    padding
                )
            }
            Op::IndexSelect {
                input,
                indices,
                dim,
            } => {
                write!(
                    f,
                    "IndexSelect(dim={}, input={:?}, indices={:?})",
                    dim,
                    input.id(),
                    indices.id()
                )
            }
            Op::ToDtype { input, src_dtype } => {
                write!(f, "ToDtype(from={:?}, id={:?})", src_dtype, input.id())
            }
        }
    }
}
548
549impl<B: Backend> Op<B> {
550 pub fn inputs(&self) -> Vec<&crate::Tensor<B>> {
553 match self {
554 Op::None => vec![],
555 Op::Binary { lhs, rhs, .. } | Op::Matmul { lhs, rhs } => vec![lhs, rhs],
556 Op::Unary { input, .. }
557 | Op::Reduce { input, .. }
558 | Op::Reshape { input, .. }
559 | Op::Transpose { input, .. }
560 | Op::Narrow { input, .. }
561 | Op::Affine { input, .. }
562 | Op::Contiguous { input }
563 | Op::MaxPool2d { input, .. }
564 | Op::AvgPool2d { input, .. }
565 | Op::Powf { input, .. }
566 | Op::Clamp { input, .. } => vec![input],
567 Op::Conv2d {
568 input,
569 weight,
570 bias,
571 ..
572 } => {
573 let mut v = vec![input, weight];
574 if let Some(b) = bias {
575 v.push(b);
576 }
577 v
578 }
579 Op::Conv1d {
580 input,
581 weight,
582 bias,
583 ..
584 } => {
585 let mut v = vec![input, weight];
586 if let Some(b) = bias {
587 v.push(b);
588 }
589 v
590 }
591 Op::Cat { inputs, .. } => inputs.iter().collect(),
592 Op::WhereCond {
593 mask,
594 on_true,
595 on_false,
596 } => {
597 vec![mask, on_true, on_false]
598 }
599 Op::Gather { input, index, .. } => vec![input, index],
600 Op::IndexSelect { input, indices, .. } => vec![input, indices],
601 Op::ToDtype { input, .. } => vec![input],
602 Op::Pad { input, .. } => vec![input],
603 }
604 }
605}