1use bon::bon;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use smallvec::smallvec;
6use snafu::ResultExt;
7use svod_device::Buffer;
8use svod_dtype::DType;
9use svod_dtype::ext::HasDType;
10use svod_ir::{CallInfo, ConstValue, ConstValueHash, DeviceSpec, Op, SInt, UOp, UOpKey, shape::Shape};
11
12fn sint_vmax(s: &SInt) -> usize {
18 match s {
19 SInt::Const(v) => *v,
20 SInt::Symbolic(uop) => match uop.op() {
21 Op::DefineVar { max_val, .. } => *max_val as usize,
22 Op::Bind { var, .. } => match var.op() {
23 Op::DefineVar { max_val, .. } => *max_val as usize,
24 _ => 1,
25 },
26 _ => 1,
27 },
28 SInt::Infer => panic!("cannot compute vmax of SInt::Infer"),
29 }
30}
31
32fn find_assign_identity(target: &Arc<UOp>, base: &Arc<UOp>) -> Arc<UOp> {
33 let mut identity = target.clone();
34 while !identity.has_buffer_identity() && identity.id != base.id {
35 let sources = identity.op().sources();
36 let Some(next) = sources.first() else {
37 break;
38 };
39 identity = next.clone();
40 }
41 identity
42}
43
44pub mod error;
45use error::*;
46
47pub mod activation;
48pub mod arithmetic;
49pub mod bitwise;
50pub mod broadcast;
51pub mod conditional;
52pub mod config;
53pub mod data;
54pub mod einsum;
55pub mod indexing;
56pub mod math;
57pub mod matmul;
58pub mod memory_planner;
59pub mod nn;
60pub mod rand;
61pub mod realize;
62pub mod reduce;
63pub mod schedule;
64pub(crate) mod schedule_cache;
65pub mod shape_ops;
66pub mod tensor_registry;
67pub mod traits;
68pub mod transformer;
69pub mod variable;
70
71pub use config::PrepareConfig;
73pub use svod_runtime::CpuBackend;
74pub use tensor_registry::apply_map_to_tensors;
75pub use variable::{BoundVariable, Variable};
76
77#[derive(Debug, Clone, Copy)]
79enum CumReduceOp {
80 Add,
81 Mul,
82 #[allow(dead_code)]
83 Max,
84}
85
86impl CumReduceOp {
87 fn identity_value(&self, dtype: DType) -> f64 {
89 match self {
90 CumReduceOp::Add => 0.0,
91 CumReduceOp::Mul => 1.0,
92 CumReduceOp::Max => {
93 if dtype.is_int() {
94 i64::MIN as f64
95 } else {
96 f64::NEG_INFINITY
97 }
98 }
99 }
100 }
101}
102
/// Plain description of a kernel, as returned by [`Tensor::kernels`].
///
/// NOTE(review): the field meanings below are inferred from the field names
/// only; nothing in this file populates the struct. Confirm against the producer.
#[derive(Clone, Debug)]
pub struct KernelInfo {
    /// Presumably the kernel's display name.
    pub name: String,
    /// Presumably the generated source code of the kernel.
    pub code: String,
    /// Presumably the symbol to invoke inside `code`.
    pub entry_point: String,
    /// Presumably the identifier of the backend the kernel targets.
    pub backend: String,
}
117
/// Handle to a tensor tracked by the tensor registry.
///
/// Cloning copies the handle (both fields are reference-counted), not the data.
pub struct Tensor {
    /// Registry entry obtained from `tensor_registry::register_tensor`; it holds
    /// the tensor's current `UOp` behind a lock (see `uop` / `set_uop`).
    entry: Arc<tensor_registry::TensorEntry>,
    /// Buffer attached to this handle. `None` when built through `new`,
    /// `Some` when built through `with_buffer`.
    buffer: Option<Arc<Buffer>>,
}
155
156impl Clone for Tensor {
158 fn clone(&self) -> Self {
159 Self { entry: Arc::clone(&self.entry), buffer: self.buffer.clone() }
160 }
161}
162
163#[bon]
164impl Tensor {
165 fn new(uop: Arc<UOp>) -> Self {
167 let entry = tensor_registry::register_tensor(uop);
168 Self { entry, buffer: None }
169 }
170
171 pub fn from_lazy(uop: Arc<UOp>) -> Self {
174 Self::new(uop)
175 }
176
177 pub fn from_path(path: &std::path::Path) -> Result<Self> {
181 let file_size = std::fs::metadata(path)
182 .map_err(|e| Error::IrConstruction { details: format!("DISK: {}: {e}", path.display()) })?
183 .len() as usize;
184 let canonical = path
185 .canonicalize()
186 .map_err(|e| Error::IrConstruction { details: format!("DISK: {}: {e}", path.display()) })?;
187 let device = svod_dtype::DeviceSpec::Disk { path: canonical };
188 let buffer_uop = UOp::new_buffer(device, file_size, svod_dtype::DType::Scalar(svod_dtype::ScalarDType::UInt8));
189 Ok(Self::new(buffer_uop))
190 }
191
192 pub(crate) fn with_buffer(entry: Arc<tensor_registry::TensorEntry>, buffer: Arc<Buffer>) -> Self {
194 Self { entry, buffer: Some(buffer) }
195 }
196
197 fn has_zero_elements(&self) -> bool {
199 match self.uop().shape() {
200 Ok(Some(shape)) => shape.iter().any(|dim| dim.as_const() == Some(0)),
201 _ => false,
202 }
203 }
204
205 pub(crate) fn ensure_buffer(&self) {
211 let buffer_id = self.uop().base().id;
212 if let Some(buf_arc) = tensor_registry::get_buffer_arc(buffer_id) {
213 self.entry.set_buffer(buf_arc);
214 }
215 }
216
217 pub fn uop(&self) -> Arc<UOp> {
221 self.entry.uop.read().clone()
222 }
223
224 pub fn kernels(&self) -> Vec<KernelInfo> {
229 Vec::new()
231 }
232
233 pub fn empty(shape: &[usize], dtype: DType) -> Self {
239 let numel: usize = shape.iter().product();
240 let buffer_uop = UOp::new_buffer(DeviceSpec::Cpu, numel, dtype);
241 let ir_shape = Shape::from_iter(shape.iter().map(|&d| SInt::Const(d)));
242 let uop = buffer_uop.try_reshape(&ir_shape).expect("shape matches element count");
243 Self::new(uop)
244 }
245
246 pub fn empty_dynamic(shape: &[SInt], dtype: DType) -> Self {
253 let numel: usize = shape.iter().map(sint_vmax).product();
254 let buffer_uop = UOp::new_buffer(DeviceSpec::Cpu, numel, dtype);
255 let ir_shape = Shape::from_iter(shape.iter().cloned());
256 let uop = buffer_uop.try_reshape(&ir_shape).expect("shape valid for reshape");
257 Self::new(uop)
258 }
259
260 pub fn empty_zero(dtype: DType) -> Self {
262 Self::empty(&[0], dtype)
263 }
264
265 pub fn full(shape: &[usize], value: impl Into<ConstValue>, dtype: DType) -> Result<Self> {
267 let scalar = Self::const_(value, dtype);
268 if shape.is_empty() {
269 return Ok(scalar);
270 }
271 let expand_shape: Vec<isize> = shape.iter().map(|&d| d as isize).collect();
272 scalar.try_reshape(vec![1; shape.len()])?.try_expand(&expand_shape)
273 }
274
275 pub fn zeros(shape: &[usize], dtype: DType) -> Result<Self> {
277 Self::full(shape, ConstValue::zero(dtype.base()), dtype)
278 }
279
280 pub fn ones(shape: &[usize], dtype: DType) -> Result<Self> {
282 Self::full(shape, ConstValue::one(dtype.base()), dtype)
283 }
284
285 pub fn full_dynamic(shape: &[SInt], value: impl Into<ConstValue>, dtype: DType) -> Result<Self> {
300 let const_uop = UOp::const_(dtype.clone(), value.into());
301 if shape.is_empty() {
302 return Ok(Self::new(const_uop));
303 }
304 let ones: Shape = vec![SInt::Const(1); shape.len()].into();
307 let target: Shape = shape.to_vec().into();
308 let reshaped = const_uop.try_reshape(&ones).context(error::UOpSnafu)?;
309 let expanded = reshaped.try_expand(&target).context(error::UOpSnafu)?;
310 Ok(Self::new(expanded))
311 }
312
313 pub fn zeros_dynamic(shape: &[SInt], dtype: DType) -> Result<Self> {
315 Self::full_dynamic(shape, ConstValue::zero(dtype.base()), dtype)
316 }
317
318 pub fn ones_dynamic(shape: &[SInt], dtype: DType) -> Result<Self> {
320 Self::full_dynamic(shape, ConstValue::one(dtype.base()), dtype)
321 }
322
323 fn _cumalu(&self, axis: isize, reduce: CumReduceOp) -> Result<Self> {
328 let shape = self.shape()?;
329 let ndim = shape.len();
330 let axis_idx = Self::normalize_axis(axis, ndim)?;
331 let n = shape[axis_idx]
332 .as_const()
333 .ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "_cumalu".to_string() })?;
334
335 if n <= 1 {
336 return Ok(self.clone());
337 }
338
339 let x = if axis_idx != ndim - 1 { self.try_transpose(axis_idx as isize, -1)? } else { self.clone() };
341
342 let identity = reduce.identity_value(self.uop().dtype());
344 let mut padding = vec![(0isize, 0isize); ndim];
345 padding[ndim - 1] = ((n - 1) as isize, 0);
346 let x = x.try_pad_value(&padding, identity)?;
347
348 let x = x.pool(&[n], &[1], &[1])?;
350
351 let x = match reduce {
353 CumReduceOp::Add => x.sum(-1isize)?,
354 CumReduceOp::Mul => x.prod(-1isize)?,
355 CumReduceOp::Max => x.max(-1isize)?,
356 };
357
358 if axis_idx != ndim - 1 { x.try_transpose(axis_idx as isize, -1) } else { Ok(x) }
360 }
361
362 pub fn cumsum(&self, axis: isize) -> Result<Self> {
364 self._cumalu(axis, CumReduceOp::Add)
365 }
366
367 pub fn cumprod(&self, axis: isize) -> Result<Self> {
369 self._cumalu(axis, CumReduceOp::Mul)
370 }
371
372 #[builder]
378 pub fn arange_with_dtype(
379 start: Arc<UOp>,
380 stop: Option<Arc<UOp>>,
381 dtype: DType,
382 #[builder(default = UOp::const_(dtype.clone(), ConstValue::one(dtype.base())))] step: Arc<UOp>,
383 ) -> Result<Self> {
384 let (start, stop) = match stop {
385 Some(s) => (start, s),
386 None => (UOp::const_(dtype.clone(), ConstValue::zero(dtype.base())), start),
387 };
388
389 let step_tensor = if let Op::Const(ConstValueHash(ConstValue::Int(start))) = start.op()
390 && let Op::Const(ConstValueHash(ConstValue::Int(stop))) = stop.op()
391 && let Op::Const(ConstValueHash(s @ ConstValue::Int(step))) = step.op()
392 {
393 let diff = stop - start;
394 let ceildiv = ((diff as f64) / (*step as f64)).ceil() as i64;
395 if ceildiv <= 0 {
396 return Ok(Self::empty_zero(dtype));
397 }
398
399 Self::full(&[ceildiv as usize], *s, dtype.clone())?
400 } else {
401 let diff = stop.sub(&start);
402 let one = UOp::const_(dtype.clone(), ConstValue::one(dtype.base()));
403 let ceildiv = diff.add(&step.sub(&one)).idiv(&step);
404 let output_len_sint = SInt::from(ceildiv.clone());
405 let ones: Shape = vec![SInt::Const(1)].into();
406 let target: Shape = vec![output_len_sint].into();
407 let reshaped = step.try_reshape(&ones).unwrap();
408 Self::new(reshaped.try_expand(&target).unwrap())
409 };
410
411 let cumsum = step_tensor._cumalu(0, CumReduceOp::Add)?;
412 let offset = Self::new(start.sub(&step));
413 cumsum.try_add(&offset)?.cast(dtype)
414 }
415
416 pub fn arange(start: i64, stop: Option<i64>, step: Option<i64>) -> Result<Self> {
418 let dtype = DType::Int32;
419 Self::arange_with_dtype()
420 .start(UOp::const_(dtype.clone(), ConstValue::Int(start)))
421 .maybe_stop(stop.map(|s| UOp::const_(dtype.clone(), ConstValue::Int(s))))
422 .maybe_step(step.map(|s| UOp::const_(dtype.clone(), ConstValue::Int(s))))
423 .dtype(dtype)
424 .call()
425 }
426
427 pub fn arange_f64(start: f64, stop: f64, step: f64, dtype: DType) -> Result<Self> {
429 if step == 0.0 {
430 return Err(Error::SymbolicShapeUnsupported { operation: "arange with step=0".to_string() });
431 }
432 let count = ((stop - start) / step).ceil() as i64;
433 if count <= 0 {
434 return Ok(Self::empty_zero(dtype));
435 }
436 let count = count as usize;
437 let step_tensor = Self::full(&[count], ConstValue::Float(step), dtype.clone())?;
438 let cumsum = step_tensor._cumalu(0, CumReduceOp::Add)?;
439 let offset = Self::const_(ConstValue::Float(start - step), dtype.clone());
440 cumsum.try_add(&offset)?.cast(dtype)
441 }
442
443 pub fn linspace(start: f64, end: f64, steps: usize, dtype: DType) -> Result<Self> {
445 if steps == 0 {
446 return Ok(Self::empty_zero(dtype));
447 }
448 if steps == 1 {
449 return Self::full(&[1], start, dtype);
450 }
451 let t = Self::arange(steps as i64, None, None)?;
452 let scale = Self::const_((end - start) / (steps as f64 - 1.0), DType::Float64);
453 let offset = Tensor::const_(start, DType::Float64);
454 t.cast(DType::Float64)?.try_mul(&scale)?.try_add(&offset)?.cast(dtype)
455 }
456
457 pub fn const_<T: Into<ConstValue>>(value: T, dtype: DType) -> Self {
478 let const_val = value.into();
479 let uop = UOp::const_(dtype, const_val);
480 Self::new(uop)
481 }
482
483 pub fn from_const<T: Into<ConstValue> + HasDType>(value: T) -> Self {
494 let dtype = T::DTYPE;
495 Self::const_(value, dtype)
496 }
497
498 pub fn device(&self) -> DeviceSpec {
510 self.uop().device_spec().unwrap_or(DeviceSpec::Cpu)
511 }
512
513 pub fn to(&self, device: DeviceSpec) -> Self {
525 if self.device() == device {
526 return self.clone();
527 }
528
529 let copy_uop = self.uop().copy_to_device(device);
530 Self::new(copy_uop)
531 }
532
533 pub fn cast(&self, dtype: svod_dtype::DType) -> Result<Self> {
541 let casted = self.uop().cast(dtype);
542 Ok(Self::new(casted))
543 }
544
545 pub fn custom_kernel<F>(&self, others: &[&Tensor], fxn: F) -> Result<Vec<Tensor>>
551 where
552 F: FnOnce(Vec<Arc<UOp>>) -> Arc<UOp>,
553 {
554 self.custom_kernel_with(others, CallInfo::default(), fxn)
555 }
556
557 pub fn custom_kernel_with<F>(&self, others: &[&Tensor], info: CallInfo, fxn: F) -> Result<Vec<Tensor>>
559 where
560 F: FnOnce(Vec<Arc<UOp>>) -> Arc<UOp>,
561 {
562 let mut srcs: Vec<Arc<UOp>> = Vec::with_capacity(1 + others.len());
563 srcs.push(self.uop());
564 srcs.extend(others.iter().map(|t| t.uop()));
565
566 let outputs = UOp::custom_kernel(srcs, fxn, info).context(UOpSnafu)?;
567 Ok(outputs.into_iter().map(Self::from_lazy).collect())
568 }
569
    /// Reinterprets the tensor's elements as `dtype`.
    ///
    /// When source and destination scalar types have the same byte size, this
    /// is a direct `UOp::bitcast`. When the sizes differ, the last dimension is
    /// rescaled by the size ratio and the conversion is built from unsigned
    /// integer shifts:
    ///
    /// * widening (`dst_size > src_size`): groups of `rate` adjacent source
    ///   elements along the last axis are combined into one element, with the
    ///   element at index `i` of a group shifted left by `8 * i * src_size` bits;
    /// * narrowing (`dst_size < src_size`): each source element is split into
    ///   `rate` elements, where output `i` comes from shifting right by
    ///   `8 * i * dst_size` bits before the cast to the destination-sized
    ///   unsigned type.
    ///
    /// # Errors
    /// * `Error::SymbolicShapeUnsupported` when either dtype has no scalar
    ///   type, or (for size-changing casts) when the last dimension is missing
    ///   or not a constant.
    /// * `Error::ReshapeSizeMismatch` when `last_dim * src_size` is not a
    ///   multiple of `dst_size`.
    /// * Errors propagated from the underlying tensor operations.
    ///
    /// # Panics
    /// `uint_for_bytes` panics for byte sizes other than 1, 2, 4 or 8.
    pub fn bitcast(&self, dtype: svod_dtype::DType) -> Result<Self> {
        let src_dt = self.uop().dtype();
        let src_scalar = src_dt.scalar().ok_or_else(|| Error::SymbolicShapeUnsupported {
            operation: "bitcast: non-scalar source dtype".to_string(),
        })?;
        let dst_scalar = dtype.scalar().ok_or_else(|| Error::SymbolicShapeUnsupported {
            operation: "bitcast: non-scalar destination dtype".to_string(),
        })?;
        let src_size = src_scalar.bytes();
        let dst_size = dst_scalar.bytes();

        // Same element width: no shape change is needed.
        if src_size == dst_size {
            return Ok(Self::new(self.uop().bitcast(dtype)));
        }

        let shape = self.shape()?;
        let last_dim = shape.last().and_then(|s| s.as_const()).ok_or_else(|| Error::SymbolicShapeUnsupported {
            operation: "bitcast with size change on symbolic last dim".to_string(),
        })?;
        // The byte length of the last axis must divide evenly into destination elements.
        if last_dim * src_size % dst_size != 0 {
            return Err(Error::ReshapeSizeMismatch {
                operation: format!(
                    "bitcast {src_scalar:?}({src_size}B) → {dst_scalar:?}({dst_size}B): \
                     last dim {last_dim} × {src_size} not divisible by {dst_size}"
                ),
            });
        }

        // Unsigned integer types of matching width, used for the shift arithmetic.
        let src_uint = DType::Scalar(uint_for_bytes(src_size));
        let dst_uint = DType::Scalar(uint_for_bytes(dst_size));

        // View the source as same-width unsigned integers (skipped if it already is one).
        let tmp = if src_dt == src_uint { self.clone() } else { Self::new(self.uop().bitcast(src_uint.clone())) };

        let result = if dst_size > src_size {
            // Widening: `rate` source elements form one destination element.
            let rate = dst_size / src_size;
            // Split the last axis into (last_dim / rate, rate).
            let mut new_shape: Vec<isize> = svod_ir::shape::to_vec_isize(&shape).context(UOpSnafu)?;
            let last_idx = new_shape.len() - 1;
            new_shape[last_idx] = (last_dim / rate) as isize;
            new_shape.push(rate as isize);
            let reshaped = tmp.try_reshape(&new_shape)?;

            let mut acc: Option<Tensor> = None;
            for i in 0..rate {
                // Select index `i` of the trailing axis; `None` leaves the other axes unchanged.
                let mut shrink_ranges: Vec<Option<(isize, isize)>> =
                    std::iter::repeat_n(None, new_shape.len() - 1).collect();
                shrink_ranges.push(Some((i as isize, (i + 1) as isize)));
                let slice = reshaped.try_shrink(shrink_ranges)?;
                let widened = slice.cast(dst_uint.clone())?;
                // Element `i` of the group occupies bits starting at 8 * i * src_size.
                let shift_amount = 8 * i * src_size;
                let term = if shift_amount == 0 {
                    widened
                } else {
                    let shift_t = Tensor::full(
                        &svod_ir::shape::to_vec_usize(&widened.shape()?).context(UOpSnafu)?,
                        ConstValue::UInt(shift_amount as u64),
                        dst_uint.clone(),
                    )?;
                    widened.try_shl(&shift_t)?
                };
                // OR the shifted pieces together.
                acc = Some(match acc {
                    None => term,
                    Some(a) => a.try_bitor(&term)?,
                });
            }
            let summed = acc.expect("rate >= 1");
            // Drop the trailing length-1 axis left over from the slicing.
            summed.try_squeeze(Some(-1))?
        } else {
            // Narrowing: each source element yields `rate` destination elements.
            let rate = src_size / dst_size;
            let mut shifted: Vec<Tensor> = Vec::with_capacity(rate);
            for i in 0..rate {
                // Bring the i-th destination-sized chunk down to the low bits.
                let shift_amount = 8 * i * dst_size;
                let s = if shift_amount == 0 {
                    tmp.clone()
                } else {
                    let shift_t = Tensor::full(
                        &svod_ir::shape::to_vec_usize(&tmp.shape()?).context(UOpSnafu)?,
                        ConstValue::UInt(shift_amount as u64),
                        src_uint.clone(),
                    )?;
                    tmp.try_shr(&shift_t)?
                };
                shifted.push(s);
            }
            // Stack the chunks on a new trailing axis, then merge the last two axes.
            let refs: Vec<&Tensor> = shifted.iter().collect();
            let stacked = Tensor::stack(&refs, -1)?;
            let stacked_shape = stacked.shape()?;
            let nd = stacked_shape.len();
            let mut new_shape: Vec<isize> = svod_ir::shape::to_vec_isize(&stacked_shape).context(UOpSnafu)?;
            let trailing = new_shape[nd - 2] * new_shape[nd - 1];
            new_shape.truncate(nd - 2);
            new_shape.push(trailing);
            let flat = stacked.try_reshape(&new_shape)?;
            // NOTE(review): presumably this cast keeps only the low `dst_size` bytes — confirm cast semantics.
            flat.cast(dst_uint.clone())?
        };

        // Reinterpret the unsigned result as the requested dtype if they differ.
        if result.uop().dtype() == dtype { Ok(result) } else { Ok(Self::new(result.uop().bitcast(dtype))) }
    }
690}
691
692fn uint_for_bytes(n: usize) -> svod_dtype::ScalarDType {
693 use svod_dtype::ScalarDType;
694 match n {
695 1 => ScalarDType::UInt8,
696 2 => ScalarDType::UInt16,
697 4 => ScalarDType::UInt32,
698 8 => ScalarDType::UInt64,
699 _ => panic!("uint_for_bytes: unsupported byte size {n}"),
700 }
701}
702
703#[allow(dead_code)]
704impl Tensor {
    /// Records an assignment of `value` into this tensor.
    ///
    /// `value` is broadcast to this tensor's shape when the shapes differ. The
    /// assignment is expressed as an `after(store(..))` effect on the target
    /// `UOp`; nothing is executed here.
    ///
    /// # Errors
    /// * `Error::IrConstruction` when the target is on a DISK device or when
    ///   the target and value devices differ.
    /// * `Error::TypeMismatch` when the dtypes differ.
    /// * Errors propagated from `shape()` or `broadcast_to`.
    pub fn try_assign(&self, value: &Tensor) -> Result<()> {
        let target_uop = self.uop();
        if self.device().is_disk() {
            return Err(Error::IrConstruction {
                details: "assign to DISK tensors is not supported by Svod runtime".to_string(),
            });
        }

        let target_shape = self.shape()?;
        let value_shape = value.shape()?;
        // Broadcast the value to the target shape when they differ.
        let value = if target_shape != value_shape { value.broadcast_to(&target_shape)? } else { value.clone() };
        if self.device() != value.device() {
            return Err(Error::IrConstruction {
                details: format!("assign device mismatch {:?} != {:?}", self.device(), value.device()),
            });
        }

        let target_dtype = target_uop.dtype();
        let value_dtype = value.uop().dtype();
        if target_dtype != value_dtype {
            return Err(Error::TypeMismatch { expected: target_dtype, actual: value_dtype });
        }

        let value_uop = value.uop();
        // Assigning a tensor to itself is a no-op.
        if Arc::ptr_eq(&target_uop, &value_uop) {
            return Ok(());
        }

        // The target, ordered after a store of the value into it.
        let assign_effect = target_uop.after(smallvec![target_uop.store(value_uop)]);
        let base = target_uop.base();
        // The target is a derived node over a Buffer/After base and has no
        // buffer identity of its own: attach the effect to the identity node
        // found by walking toward the base, and rewrite registered tensors.
        if matches!(base.op(), Op::Buffer { .. } | Op::After { .. })
            && target_uop.id != base.id
            && !target_uop.has_buffer_identity()
        {
            let identity = find_assign_identity(&target_uop, &base);
            let assigned_identity = identity.after(smallvec![assign_effect]);
            #[allow(clippy::mutable_key_type)]
            let mut becomes_map = HashMap::new();
            becomes_map.insert(UOpKey(identity), assigned_identity);
            // NOTE(review): presumably replaces `identity` with `assigned_identity`
            // in every registered tensor's graph — confirm in tensor_registry.
            tensor_registry::apply_map_to_tensors_walk(&becomes_map);
        } else {
            // Otherwise only this tensor's own UOp is replaced.
            self.set_uop(assign_effect);
        }
        Ok(())
    }
765
766 pub fn assign(&self, value: &Tensor) {
767 self.try_assign(value).expect("tensor assign failed");
768 }
769
770 pub(crate) fn set_uop(&self, uop: Arc<UOp>) {
775 *self.entry.uop.write() = uop;
776 }
777
778 pub fn contiguous(&self) -> Self {
792 let uop = self.uop();
793 if matches!(uop.op(), svod_ir::Op::Contiguous { .. }) {
794 return self.clone();
795 }
796 let contiguous_uop = uop.contiguous();
797 Self::new(contiguous_uop)
798 }
799}
800
801impl Tensor {
802 pub(crate) fn broadcast_scalar(&self, value: ConstValue) -> Result<Self> {
804 let shape = self.shape()?;
805 let scalar = Self::new(UOp::const_(self.uop().dtype(), value));
806 scalar.broadcast_to(&shape)
807 }
808
809 pub fn zero(&self) -> Result<Self> {
811 let sdtype = self.uop().dtype().scalar().expect("scalar dtype");
812 self.broadcast_scalar(ConstValue::zero(sdtype))
813 }
814
815 pub fn one(&self) -> Result<Self> {
817 let sdtype = self.uop().dtype().scalar().expect("scalar dtype");
818 self.broadcast_scalar(ConstValue::one(sdtype))
819 }
820
821 pub fn eye(n: usize, m: usize, dtype: DType) -> Result<Self> {
823 let rows = Self::arange(n as i64, None, None)?.try_unsqueeze(-1)?;
824 let cols = Self::arange(m as i64, None, None)?;
825 rows.try_eq(&cols)?.cast(dtype)
826 }
827}
828
829#[bon]
830impl Tensor {
831 #[builder]
833 pub fn cumsum_with(
834 &self,
835 axis: isize,
836 #[builder(default = false)] exclusive: bool,
837 #[builder(default = false)] reverse: bool,
838 ) -> Result<Self> {
839 let shape = self.shape()?;
840 let ndim = shape.len();
841 let axis_idx = Self::normalize_axis(axis, ndim)?;
842 let mut result = self.clone();
843 if reverse {
844 result = result.flip(&[axis_idx as isize])?;
845 }
846 if exclusive {
847 let dim_size = shape[axis_idx].as_const().unwrap() as isize;
848 let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); ndim];
849 pad_spec[axis_idx] = (1, 0);
850 result = result.try_pad(&pad_spec)?;
851 let mut shrink_spec: Vec<(isize, isize)> =
852 result.shape()?.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
853 shrink_spec[axis_idx] = (0, dim_size);
854 result = result.try_shrink(&shrink_spec)?;
855 }
856 result = result.cumsum(axis_idx as isize)?;
857 if reverse {
858 result = result.flip(&[axis_idx as isize])?;
859 }
860 Ok(result)
861 }
862
863 #[builder]
865 pub fn cumprod_with(
866 &self,
867 axis: isize,
868 #[builder(default = false)] exclusive: bool,
869 #[builder(default = false)] reverse: bool,
870 ) -> Result<Self> {
871 let shape = self.shape()?;
872 let ndim = shape.len();
873 let axis_idx = Self::normalize_axis(axis, ndim)?;
874 let mut result = self.clone();
875 if reverse {
876 result = result.flip(&[axis_idx as isize])?;
877 }
878 if exclusive {
879 let dim_size = shape[axis_idx].as_const().unwrap() as isize;
880 let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); ndim];
881 pad_spec[axis_idx] = (1, 0);
882 result = result.try_pad_value(&pad_spec, 1.0)?;
883 let mut shrink_spec: Vec<(isize, isize)> =
884 result.shape()?.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
885 shrink_spec[axis_idx] = (0, dim_size);
886 result = result.try_shrink(&shrink_spec)?;
887 }
888 result = result.cumprod(axis_idx as isize)?;
889 if reverse {
890 result = result.flip(&[axis_idx as isize])?;
891 }
892 Ok(result)
893 }
894}
895
896#[cfg(test)]
897mod test;