Skip to main content

svod_tensor/
lib.rs

1use bon::bon;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use smallvec::smallvec;
6use snafu::ResultExt;
7use svod_device::Buffer;
8use svod_dtype::DType;
9use svod_dtype::ext::HasDType;
10use svod_ir::{CallInfo, ConstValue, ConstValueHash, DeviceSpec, Op, SInt, UOp, UOpKey, shape::Shape};
11
12/// Extract max value from an SInt for buffer allocation.
13///
14/// Concrete dims return their value. Symbolic dims (DefineVar, Bind)
15/// return `max_val` from the underlying Variable, enabling rebinding
16/// without reallocation. Matches Tinygrad's `x.vmax`.
17fn sint_vmax(s: &SInt) -> usize {
18    match s {
19        SInt::Const(v) => *v,
20        SInt::Symbolic(uop) => match uop.op() {
21            Op::DefineVar { max_val, .. } => *max_val as usize,
22            Op::Bind { var, .. } => match var.op() {
23                Op::DefineVar { max_val, .. } => *max_val as usize,
24                _ => 1,
25            },
26            _ => 1,
27        },
28        SInt::Infer => panic!("cannot compute vmax of SInt::Infer"),
29    }
30}
31
32fn find_assign_identity(target: &Arc<UOp>, base: &Arc<UOp>) -> Arc<UOp> {
33    let mut identity = target.clone();
34    while !identity.has_buffer_identity() && identity.id != base.id {
35        let sources = identity.op().sources();
36        let Some(next) = sources.first() else {
37            break;
38        };
39        identity = next.clone();
40    }
41    identity
42}
43
44pub mod error;
45use error::*;
46
47pub mod activation;
48pub mod arithmetic;
49pub mod bitwise;
50pub mod broadcast;
51pub mod conditional;
52pub mod config;
53pub mod data;
54pub mod einsum;
55pub mod indexing;
56pub mod math;
57pub mod matmul;
58pub mod memory_planner;
59pub mod nn;
60pub mod rand;
61pub mod realize;
62pub mod reduce;
63pub mod schedule;
64pub(crate) mod schedule_cache;
65pub mod shape_ops;
66pub mod tensor_registry;
67pub mod traits;
68pub mod transformer;
69pub mod variable;
70
71// Re-export for public API
72pub use config::PrepareConfig;
73pub use svod_runtime::CpuBackend;
74pub use tensor_registry::apply_map_to_tensors;
75pub use variable::{BoundVariable, Variable};
76
/// Associative reduction applied by the cumulative-scan helper `_cumalu`.
#[derive(Debug, Clone, Copy)]
enum CumReduceOp {
    /// Running sum (used by `cumsum`); identity `0`.
    Add,
    /// Running product (used by `cumprod`); identity `1`.
    Mul,
    /// Running maximum. No public method constructs this variant yet, so the
    /// compiler would otherwise flag it as dead code.
    #[allow(dead_code)]
    Max,
}
85
86impl CumReduceOp {
87    /// Identity element for this operation as f64, used as pad fill value.
88    fn identity_value(&self, dtype: DType) -> f64 {
89        match self {
90            CumReduceOp::Add => 0.0,
91            CumReduceOp::Mul => 1.0,
92            CumReduceOp::Max => {
93                if dtype.is_int() {
94                    i64::MIN as f64
95                } else {
96                    f64::NEG_INFINITY
97                }
98            }
99        }
100    }
101}
102
/// Description of one rendered kernel.
///
/// This is the public type returned by `tensor.kernels()`.
#[derive(Clone, Debug)]
pub struct KernelInfo {
    /// Kernel name, e.g. `"kernel"`.
    pub name: String,
    /// Rendered source (LLVM IR, CUDA PTX, …).
    pub code: String,
    /// Name of the function to invoke.
    pub entry_point: String,
    /// Backend that produced `code`.
    pub backend: String,
}
117
118/// Tensor represents a multi-dimensional array with lazy evaluation.
119///
120/// Operations like addition and multiplication build a computation graph
121/// without allocating buffers. Buffers are only allocated when:
122/// - Creating input tensors via `from_slice()`
123/// - Evaluating the computation graph via `realize()`
124///
125/// # Global Graph Substitution
126///
127/// Tensors are registered in a global registry to support atomic graph substitution.
128/// When rangeify transforms a UOp (e.g., NEG → BUFFERIZE(NEG)), all tensors
129/// referencing it are updated atomically via `apply_map_to_tensors()`.
130///
131/// This is critical for diamond patterns (like argmin's NEG feeding both MAX and EQ)
132/// where different consumers must see the same transformed version.
133///
134/// # Buffer Ownership (RAII)
135///
136/// Tensors own their buffers via `Arc<Buffer>`. When all Tensor clones referencing
137/// a buffer are dropped, the buffer is automatically freed. This provides RAII
138/// cleanup without manual buffer management.
139///
140/// # Examples
141///
142/// ```
143/// # use svod_tensor::Tensor;
144/// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
145/// let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0]);
146/// let mut c = &a + &b;  // Lazy - only builds UOp graph
147/// c.realize().unwrap();  // Executes the computation
148/// ```
149pub struct Tensor {
150    /// Registry entry holding the computation graph (supports global substitution)
151    entry: Arc<tensor_registry::TensorEntry>,
152    /// Owned buffer for RAII cleanup. None for lazy tensors.
153    buffer: Option<Arc<Buffer>>,
154}
155
156// Manual Clone impl to share Arc<Buffer> across clones
157impl Clone for Tensor {
158    fn clone(&self) -> Self {
159        Self { entry: Arc::clone(&self.entry), buffer: self.buffer.clone() }
160    }
161}
162
163#[bon]
164impl Tensor {
165    /// Create tensor without buffer (for lazy computation graphs).
166    fn new(uop: Arc<UOp>) -> Self {
167        let entry = tensor_registry::register_tensor(uop);
168        Self { entry, buffer: None }
169    }
170
171    /// Create a lazy tensor from a UOp graph (no buffer allocated).
172    /// Used for deferred computation graphs like ONNX weight views.
173    pub fn from_lazy(uop: Arc<UOp>) -> Self {
174        Self::new(uop)
175    }
176
177    /// Create a file-backed tensor using the DISK device (Tinygrad: `Tensor(pathlib.Path)`).
178    /// The file is memory-mapped lazily — no data is read until the tensor is realized.
179    /// The resulting tensor has dtype `uint8` and shape `(file_size,)`.
180    pub fn from_path(path: &std::path::Path) -> Result<Self> {
181        let file_size = std::fs::metadata(path)
182            .map_err(|e| Error::IrConstruction { details: format!("DISK: {}: {e}", path.display()) })?
183            .len() as usize;
184        let canonical = path
185            .canonicalize()
186            .map_err(|e| Error::IrConstruction { details: format!("DISK: {}: {e}", path.display()) })?;
187        let device = svod_dtype::DeviceSpec::Disk { path: canonical };
188        let buffer_uop = UOp::new_buffer(device, file_size, svod_dtype::DType::Scalar(svod_dtype::ScalarDType::UInt8));
189        Ok(Self::new(buffer_uop))
190    }
191
192    /// Create tensor with existing buffer (for input tensors and realize results).
193    pub(crate) fn with_buffer(entry: Arc<tensor_registry::TensorEntry>, buffer: Arc<Buffer>) -> Self {
194        Self { entry, buffer: Some(buffer) }
195    }
196
197    /// Check if this tensor has zero total elements (any shape dimension is 0).
198    fn has_zero_elements(&self) -> bool {
199        match self.uop().shape() {
200            Ok(Some(shape)) => shape.iter().any(|dim| dim.as_const() == Some(0)),
201            _ => false,
202        }
203    }
204
205    /// Ensure buffer is attached if the UOp has buffer identity.
206    ///
207    /// When `apply_map_to_tensors` substitutes a tensor's UOp with a realized
208    /// BUFFER+RESHAPE, the Tensor struct's `buffer` field isn't updated.
209    /// This method looks up the buffer from the registry and attaches it.
210    pub(crate) fn ensure_buffer(&self) {
211        let buffer_id = self.uop().base().id;
212        if let Some(buf_arc) = tensor_registry::get_buffer_arc(buffer_id) {
213            self.entry.set_buffer(buf_arc);
214        }
215    }
216
217    /// Get the current UOp for this tensor.
218    ///
219    /// This reads from the registry, so it reflects any global substitutions.
220    pub fn uop(&self) -> Arc<UOp> {
221        self.entry.uop.read().clone()
222    }
223
224    /// Get kernels for THIS tensor.
225    ///
226    /// Note: Kernel tracking is not yet implemented with the new registry.
227    /// This returns an empty list for now.
228    pub fn kernels(&self) -> Vec<KernelInfo> {
229        // TODO: Implement kernel tracking with the new registry
230        Vec::new()
231    }
232
233    /// Create an uninitialized buffer-backed tensor with the given shape and dtype.
234    ///
235    /// No device memory is allocated — only the BUFFER UOp is created.
236    /// Use `assign()` to bind real data before `realize()`.
237    /// Matches Tinygrad's `Tensor.empty(*shape)`.
238    pub fn empty(shape: &[usize], dtype: DType) -> Self {
239        let numel: usize = shape.iter().product();
240        let buffer_uop = UOp::new_buffer(DeviceSpec::Cpu, numel, dtype);
241        let ir_shape = Shape::from_iter(shape.iter().map(|&d| SInt::Const(d)));
242        let uop = buffer_uop.try_reshape(&ir_shape).expect("shape matches element count");
243        Self::new(uop)
244    }
245
246    /// Create an uninitialized buffer-backed tensor with symbolic (dynamic) dimensions.
247    ///
248    /// Buffer is sized to `prod(vmax)` — each symbolic dim uses its Variable's
249    /// max_val for allocation. This enables rebinding to any value in [min, max]
250    /// without reallocation. Matches Tinygrad's
251    /// `prod([x.vmax if isinstance(x, UOp) else x for x in shape])`.
252    pub fn empty_dynamic(shape: &[SInt], dtype: DType) -> Self {
253        let numel: usize = shape.iter().map(sint_vmax).product();
254        let buffer_uop = UOp::new_buffer(DeviceSpec::Cpu, numel, dtype);
255        let ir_shape = Shape::from_iter(shape.iter().cloned());
256        let uop = buffer_uop.try_reshape(&ir_shape).expect("shape valid for reshape");
257        Self::new(uop)
258    }
259
260    /// Create an empty 0-element tensor with the given dtype and shape `[0]`.
261    pub fn empty_zero(dtype: DType) -> Self {
262        Self::empty(&[0], dtype)
263    }
264
265    /// Create a tensor filled with a constant value, broadcast to the given shape.
266    pub fn full(shape: &[usize], value: impl Into<ConstValue>, dtype: DType) -> Result<Self> {
267        let scalar = Self::const_(value, dtype);
268        if shape.is_empty() {
269            return Ok(scalar);
270        }
271        let expand_shape: Vec<isize> = shape.iter().map(|&d| d as isize).collect();
272        scalar.try_reshape(vec![1; shape.len()])?.try_expand(&expand_shape)
273    }
274
275    /// Create a zero-filled tensor with the given concrete shape.
276    pub fn zeros(shape: &[usize], dtype: DType) -> Result<Self> {
277        Self::full(shape, ConstValue::zero(dtype.base()), dtype)
278    }
279
280    /// Create a one-filled tensor with the given concrete shape.
281    pub fn ones(shape: &[usize], dtype: DType) -> Result<Self> {
282        Self::full(shape, ConstValue::one(dtype.base()), dtype)
283    }
284
285    /// Create a tensor filled with a constant value, using symbolic (dynamic) dimensions.
286    ///
287    /// Dimensions can be concrete (`SInt::Const`) or symbolic (`SInt::Symbolic`
288    /// from [`Variable::bind()`](crate::Variable::bind)).
289    ///
290    /// # Example
291    ///
292    /// ```ignore
293    /// use svod_tensor::{Tensor, Variable};
294    /// use svod_dtype::DType;
295    ///
296    /// let batch = Variable::new("batch", 1, 32);
297    /// let x = Tensor::full_dynamic(&[batch.bind(16)?.into(), 784.into()], 0.0, DType::Float32)?;
298    /// ```
299    pub fn full_dynamic(shape: &[SInt], value: impl Into<ConstValue>, dtype: DType) -> Result<Self> {
300        let const_uop = UOp::const_(dtype.clone(), value.into());
301        if shape.is_empty() {
302            return Ok(Self::new(const_uop));
303        }
304        // Reshape scalar to [1, 1, ...] then expand to target shape.
305        // Expand handles both concrete and symbolic (SInt::Symbolic) dims.
306        let ones: Shape = vec![SInt::Const(1); shape.len()].into();
307        let target: Shape = shape.to_vec().into();
308        let reshaped = const_uop.try_reshape(&ones).context(error::UOpSnafu)?;
309        let expanded = reshaped.try_expand(&target).context(error::UOpSnafu)?;
310        Ok(Self::new(expanded))
311    }
312
313    /// Create a zero-filled tensor with symbolic (dynamic) dimensions.
314    pub fn zeros_dynamic(shape: &[SInt], dtype: DType) -> Result<Self> {
315        Self::full_dynamic(shape, ConstValue::zero(dtype.base()), dtype)
316    }
317
318    /// Create a one-filled tensor with symbolic (dynamic) dimensions.
319    pub fn ones_dynamic(shape: &[SInt], dtype: DType) -> Result<Self> {
320        Self::full_dynamic(shape, ConstValue::one(dtype.base()), dtype)
321    }
322
323    /// Cumulative reduce along an axis using a sliding-window approach.
324    ///
325    /// Decomposes prefix-sum/prefix-max/prefix-prod into existing ops:
326    /// pad → pool (sliding windows) → reduce. Fully lazy, O(1) graph nodes.
327    fn _cumalu(&self, axis: isize, reduce: CumReduceOp) -> Result<Self> {
328        let shape = self.shape()?;
329        let ndim = shape.len();
330        let axis_idx = Self::normalize_axis(axis, ndim)?;
331        let n = shape[axis_idx]
332            .as_const()
333            .ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "_cumalu".to_string() })?;
334
335        if n <= 1 {
336            return Ok(self.clone());
337        }
338
339        // 1. Transpose target axis to last
340        let x = if axis_idx != ndim - 1 { self.try_transpose(axis_idx as isize, -1)? } else { self.clone() };
341
342        // 2. Pad left with (n-1) identity elements
343        let identity = reduce.identity_value(self.uop().dtype());
344        let mut padding = vec![(0isize, 0isize); ndim];
345        padding[ndim - 1] = ((n - 1) as isize, 0);
346        let x = x.try_pad_value(&padding, identity)?;
347
348        // 3. Pool with kernel=n, stride=1
349        let x = x.pool(&[n], &[1], &[1])?;
350
351        // 4. Reduce last dim
352        let x = match reduce {
353            CumReduceOp::Add => x.sum(-1isize)?,
354            CumReduceOp::Mul => x.prod(-1isize)?,
355            CumReduceOp::Max => x.max(-1isize)?,
356        };
357
358        // 5. Transpose back
359        if axis_idx != ndim - 1 { x.try_transpose(axis_idx as isize, -1) } else { Ok(x) }
360    }
361
362    /// Cumulative sum along an axis.
363    pub fn cumsum(&self, axis: isize) -> Result<Self> {
364        self._cumalu(axis, CumReduceOp::Add)
365    }
366
367    /// Cumulative product along an axis.
368    pub fn cumprod(&self, axis: isize) -> Result<Self> {
369        self._cumalu(axis, CumReduceOp::Mul)
370    }
371
372    /// Create 1D tensor with evenly spaced values and explicit dtype.
373    ///
374    /// Matches Tinygrad's `Tensor.arange()`: `full(step) → cumsum → + (start - step)`.
375    /// Accepts concrete `i64` or symbolic `Arc<UOp>` for start/stop/step.
376    /// If `stop` is None, treats `start` as stop and starts from 0.
377    #[builder]
378    pub fn arange_with_dtype(
379        start: Arc<UOp>,
380        stop: Option<Arc<UOp>>,
381        dtype: DType,
382        #[builder(default = UOp::const_(dtype.clone(), ConstValue::one(dtype.base())))] step: Arc<UOp>,
383    ) -> Result<Self> {
384        let (start, stop) = match stop {
385            Some(s) => (start, s),
386            None => (UOp::const_(dtype.clone(), ConstValue::zero(dtype.base())), start),
387        };
388
389        let step_tensor = if let Op::Const(ConstValueHash(ConstValue::Int(start))) = start.op()
390            && let Op::Const(ConstValueHash(ConstValue::Int(stop))) = stop.op()
391            && let Op::Const(ConstValueHash(s @ ConstValue::Int(step))) = step.op()
392        {
393            let diff = stop - start;
394            let ceildiv = ((diff as f64) / (*step as f64)).ceil() as i64;
395            if ceildiv <= 0 {
396                return Ok(Self::empty_zero(dtype));
397            }
398
399            Self::full(&[ceildiv as usize], *s, dtype.clone())?
400        } else {
401            let diff = stop.sub(&start);
402            let one = UOp::const_(dtype.clone(), ConstValue::one(dtype.base()));
403            let ceildiv = diff.add(&step.sub(&one)).idiv(&step);
404            let output_len_sint = SInt::from(ceildiv.clone());
405            let ones: Shape = vec![SInt::Const(1)].into();
406            let target: Shape = vec![output_len_sint].into();
407            let reshaped = step.try_reshape(&ones).unwrap();
408            Self::new(reshaped.try_expand(&target).unwrap())
409        };
410
411        let cumsum = step_tensor._cumalu(0, CumReduceOp::Add)?;
412        let offset = Self::new(start.sub(&step));
413        cumsum.try_add(&offset)?.cast(dtype)
414    }
415
416    /// Create 1D tensor with evenly spaced Int32 values.
417    pub fn arange(start: i64, stop: Option<i64>, step: Option<i64>) -> Result<Self> {
418        let dtype = DType::Int32;
419        Self::arange_with_dtype()
420            .start(UOp::const_(dtype.clone(), ConstValue::Int(start)))
421            .maybe_stop(stop.map(|s| UOp::const_(dtype.clone(), ConstValue::Int(s))))
422            .maybe_step(step.map(|s| UOp::const_(dtype.clone(), ConstValue::Int(s))))
423            .dtype(dtype)
424            .call()
425    }
426
427    /// Create 1D tensor with evenly spaced values (float parameters).
428    pub fn arange_f64(start: f64, stop: f64, step: f64, dtype: DType) -> Result<Self> {
429        if step == 0.0 {
430            return Err(Error::SymbolicShapeUnsupported { operation: "arange with step=0".to_string() });
431        }
432        let count = ((stop - start) / step).ceil() as i64;
433        if count <= 0 {
434            return Ok(Self::empty_zero(dtype));
435        }
436        let count = count as usize;
437        let step_tensor = Self::full(&[count], ConstValue::Float(step), dtype.clone())?;
438        let cumsum = step_tensor._cumalu(0, CumReduceOp::Add)?;
439        let offset = Self::const_(ConstValue::Float(start - step), dtype.clone());
440        cumsum.try_add(&offset)?.cast(dtype)
441    }
442
443    /// Create 1D tensor with `steps` evenly spaced values from `start` to `end` (inclusive).
444    pub fn linspace(start: f64, end: f64, steps: usize, dtype: DType) -> Result<Self> {
445        if steps == 0 {
446            return Ok(Self::empty_zero(dtype));
447        }
448        if steps == 1 {
449            return Self::full(&[1], start, dtype);
450        }
451        let t = Self::arange(steps as i64, None, None)?;
452        let scale = Self::const_((end - start) / (steps as f64 - 1.0), DType::Float64);
453        let offset = Tensor::const_(start, DType::Float64);
454        t.cast(DType::Float64)?.try_mul(&scale)?.try_add(&offset)?.cast(dtype)
455    }
456
457    // === Constant Constructors ===
458
459    /// Create a scalar constant tensor.
460    ///
461    /// Creates a 0-dimensional tensor containing a single constant value.
462    /// The constant is embedded directly in the IR and does not allocate
463    /// a buffer until realized (if needed).
464    ///
465    /// # Arguments
466    /// * `value` - The constant value (will be converted to ConstValue)
467    /// * `dtype` - The data type for the tensor
468    ///
469    /// # Examples
470    /// ```ignore
471    /// // Float constant
472    /// let pi = Tensor::const_(3.14159, DType::Float32);
473    ///
474    /// // Integer constant
475    /// let forty_two = Tensor::const_(42i64, DType::Int64);
476    /// ```
477    pub fn const_<T: Into<ConstValue>>(value: T, dtype: DType) -> Self {
478        let const_val = value.into();
479        let uop = UOp::const_(dtype, const_val);
480        Self::new(uop)
481    }
482
483    /// Create a scalar constant tensor with dtype auto-inferred from value.
484    ///
485    /// Convenience method that infers dtype from the Rust type.
486    ///
487    /// # Examples
488    /// ```ignore
489    /// let f = Tensor::from_const(3.14f32);  // DType::Float32
490    /// let i = Tensor::from_const(42i32);    // DType::Int32
491    /// let b = Tensor::from_const(true);     // DType::Bool
492    /// ```
493    pub fn from_const<T: Into<ConstValue> + HasDType>(value: T) -> Self {
494        let dtype = T::DTYPE;
495        Self::const_(value, dtype)
496    }
497
498    /// Get device specification from underlying UOp graph.
499    ///
500    /// Returns the device where this tensor's data resides.
501    /// For lazy tensors (not yet realized), returns the target device.
502    /// Defaults to CPU if no device is found in the graph.
503    ///
504    /// # Examples
505    /// ```ignore
506    /// let cpu_tensor = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
507    /// assert_eq!(cpu_tensor.device(), DeviceSpec::Cpu);
508    /// ```
509    pub fn device(&self) -> DeviceSpec {
510        self.uop().device_spec().unwrap_or(DeviceSpec::Cpu)
511    }
512
513    /// Move tensor to a different device.
514    ///
515    /// Creates a lazy COPY operation. Data is not transferred until `realize()`.
516    /// If already on target device, returns a clone (no-op).
517    ///
518    /// # Examples
519    /// ```ignore
520    /// let cpu_tensor = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
521    /// let mut gpu_tensor = cpu_tensor.to(DeviceSpec::Cuda { device_id: 0 });
522    /// gpu_tensor.realize()?;  // Actually transfers data
523    /// ```
524    pub fn to(&self, device: DeviceSpec) -> Self {
525        if self.device() == device {
526            return self.clone();
527        }
528
529        let copy_uop = self.uop().copy_to_device(device);
530        Self::new(copy_uop)
531    }
532
533    /// Cast tensor to a different dtype.
534    ///
535    /// # Examples
536    /// ```ignore
537    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
538    /// let t_int = t.cast(DType::Int32)?;
539    /// ```
540    pub fn cast(&self, dtype: svod_dtype::DType) -> Result<Self> {
541        let casted = self.uop().cast(dtype);
542        Ok(Self::new(casted))
543    }
544
545    /// Build and apply a custom UOp kernel over this tensor and additional inputs.
546    ///
547    /// The closure receives PARAM placeholders (as UOps) corresponding to
548    /// `[self, others...]` and must return the kernel body UOp (typically a SINK).
549    /// Returns tensors wrapped with AFTER(CALL) dependencies in argument order.
550    pub fn custom_kernel<F>(&self, others: &[&Tensor], fxn: F) -> Result<Vec<Tensor>>
551    where
552        F: FnOnce(Vec<Arc<UOp>>) -> Arc<UOp>,
553    {
554        self.custom_kernel_with(others, CallInfo::default(), fxn)
555    }
556
557    /// `custom_kernel` with explicit CALL metadata.
558    pub fn custom_kernel_with<F>(&self, others: &[&Tensor], info: CallInfo, fxn: F) -> Result<Vec<Tensor>>
559    where
560        F: FnOnce(Vec<Arc<UOp>>) -> Arc<UOp>,
561    {
562        let mut srcs: Vec<Arc<UOp>> = Vec::with_capacity(1 + others.len());
563        srcs.push(self.uop());
564        srcs.extend(others.iter().map(|t| t.uop()));
565
566        let outputs = UOp::custom_kernel(srcs, fxn, info).context(UOpSnafu)?;
567        Ok(outputs.into_iter().map(Self::from_lazy).collect())
568    }
569
570    /// Bitcast tensor to a different dtype, reinterpreting bits.
571    ///
572    /// For equal-itemsize dtypes (e.g. `f32 ↔ i32`) this is the pure
573    /// IR-level reinterpretation. For different-itemsize dtypes (e.g.
574    /// `u32 → u16` or `u32 → u64`) the last axis is split or combined via
575    /// shifts + reshape, matching Tinygrad's `tensor.py::bitcast`. The total
576    /// byte count is preserved; the last axis grows (`src_size > dst_size`)
577    /// or shrinks (`src_size < dst_size`) by `rate = max(...)/min(...)`.
578    ///
579    /// Requires:
580    /// - source and destination are both scalar (vector dtypes unsupported);
581    /// - `(shape[-1] * src_size)` divides evenly by `dst_size`;
582    /// - the last shape dim is concrete (not symbolic).
583    pub fn bitcast(&self, dtype: svod_dtype::DType) -> Result<Self> {
584        let src_dt = self.uop().dtype();
585        let src_scalar = src_dt.scalar().ok_or_else(|| Error::SymbolicShapeUnsupported {
586            operation: "bitcast: non-scalar source dtype".to_string(),
587        })?;
588        let dst_scalar = dtype.scalar().ok_or_else(|| Error::SymbolicShapeUnsupported {
589            operation: "bitcast: non-scalar destination dtype".to_string(),
590        })?;
591        let src_size = src_scalar.bytes();
592        let dst_size = dst_scalar.bytes();
593
594        if src_size == dst_size {
595            return Ok(Self::new(self.uop().bitcast(dtype)));
596        }
597
598        let shape = self.shape()?;
599        let last_dim = shape.last().and_then(|s| s.as_const()).ok_or_else(|| Error::SymbolicShapeUnsupported {
600            operation: "bitcast with size change on symbolic last dim".to_string(),
601        })?;
602        if last_dim * src_size % dst_size != 0 {
603            return Err(Error::ReshapeSizeMismatch {
604                operation: format!(
605                    "bitcast {src_scalar:?}({src_size}B) → {dst_scalar:?}({dst_size}B): \
606                     last dim {last_dim} × {src_size} not divisible by {dst_size}"
607                ),
608            });
609        }
610
611        let src_uint = DType::Scalar(uint_for_bytes(src_size));
612        let dst_uint = DType::Scalar(uint_for_bytes(dst_size));
613
614        // Reinterpret as the source-sized uint first (always equal-size, falls
615        // into the identity path above).
616        let tmp = if src_dt == src_uint { self.clone() } else { Self::new(self.uop().bitcast(src_uint.clone())) };
617
618        let result = if dst_size > src_size {
619            // Combine `rate` source words into one dst word: shift each by
620            // `8*i*src_size`, OR them, squeeze the trailing axis.
621            let rate = dst_size / src_size;
622            let mut new_shape: Vec<isize> = svod_ir::shape::to_vec_isize(&shape).context(UOpSnafu)?;
623            let last_idx = new_shape.len() - 1;
624            new_shape[last_idx] = (last_dim / rate) as isize;
625            new_shape.push(rate as isize);
626            let reshaped = tmp.try_reshape(&new_shape)?;
627
628            let mut acc: Option<Tensor> = None;
629            for i in 0..rate {
630                // Slice the trailing axis to `(i, i+1)` (preserves rank).
631                let mut shrink_ranges: Vec<Option<(isize, isize)>> =
632                    std::iter::repeat_n(None, new_shape.len() - 1).collect();
633                shrink_ranges.push(Some((i as isize, (i + 1) as isize)));
634                let slice = reshaped.try_shrink(shrink_ranges)?;
635                let widened = slice.cast(dst_uint.clone())?;
636                let shift_amount = 8 * i * src_size;
637                let term = if shift_amount == 0 {
638                    widened
639                } else {
640                    let shift_t = Tensor::full(
641                        &svod_ir::shape::to_vec_usize(&widened.shape()?).context(UOpSnafu)?,
642                        ConstValue::UInt(shift_amount as u64),
643                        dst_uint.clone(),
644                    )?;
645                    widened.try_shl(&shift_t)?
646                };
647                acc = Some(match acc {
648                    None => term,
649                    Some(a) => a.try_bitor(&term)?,
650                });
651            }
652            let summed = acc.expect("rate >= 1");
653            // Squeeze the trailing axis (now size 1).
654            summed.try_squeeze(Some(-1))?
655        } else {
656            // Split each source word into `rate` dst words via right shifts,
657            // stack along a new trailing axis, then flatten the last two.
658            let rate = src_size / dst_size;
659            let mut shifted: Vec<Tensor> = Vec::with_capacity(rate);
660            for i in 0..rate {
661                let shift_amount = 8 * i * dst_size;
662                let s = if shift_amount == 0 {
663                    tmp.clone()
664                } else {
665                    let shift_t = Tensor::full(
666                        &svod_ir::shape::to_vec_usize(&tmp.shape()?).context(UOpSnafu)?,
667                        ConstValue::UInt(shift_amount as u64),
668                        src_uint.clone(),
669                    )?;
670                    tmp.try_shr(&shift_t)?
671                };
672                shifted.push(s);
673            }
674            let refs: Vec<&Tensor> = shifted.iter().collect();
675            let stacked = Tensor::stack(&refs, -1)?;
676            // Collapse trailing two axes (... × last × rate) → (... × last*rate).
677            let stacked_shape = stacked.shape()?;
678            let nd = stacked_shape.len();
679            let mut new_shape: Vec<isize> = svod_ir::shape::to_vec_isize(&stacked_shape).context(UOpSnafu)?;
680            let trailing = new_shape[nd - 2] * new_shape[nd - 1];
681            new_shape.truncate(nd - 2);
682            new_shape.push(trailing);
683            let flat = stacked.try_reshape(&new_shape)?;
684            flat.cast(dst_uint.clone())?
685        };
686
687        // Final reinterpretation at equal size (e.g. u16 → f16).
688        if result.uop().dtype() == dtype { Ok(result) } else { Ok(Self::new(result.uop().bitcast(dtype))) }
689    }
690}
691
692fn uint_for_bytes(n: usize) -> svod_dtype::ScalarDType {
693    use svod_dtype::ScalarDType;
694    match n {
695        1 => ScalarDType::UInt8,
696        2 => ScalarDType::UInt16,
697        4 => ScalarDType::UInt32,
698        8 => ScalarDType::UInt64,
699        _ => panic!("uint_for_bytes: unsupported byte size {n}"),
700    }
701}
702
703#[allow(dead_code)]
704impl Tensor {
705    /// Assign a value tensor to this tensor in-place.
706    ///
707    /// Embeds the write as `AFTER(target, STORE(target, value))`.
708    ///
709    /// # Example
710    ///
711    /// ```ignore
712    /// let placeholder = Tensor::empty(&[2, 3], DType::Float32);
713    /// let real_data = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0])
714    ///     .try_reshape(&[2, 3]).unwrap();
715    /// placeholder.assign(&real_data);
716    /// ```
717    pub fn try_assign(&self, value: &Tensor) -> Result<()> {
718        let target_uop = self.uop();
719        if self.device().is_disk() {
720            return Err(Error::IrConstruction {
721                details: "assign to DISK tensors is not supported by Svod runtime".to_string(),
722            });
723        }
724
725        let target_shape = self.shape()?;
726        let value_shape = value.shape()?;
727        let value = if target_shape != value_shape { value.broadcast_to(&target_shape)? } else { value.clone() };
728        if self.device() != value.device() {
729            return Err(Error::IrConstruction {
730                details: format!("assign device mismatch {:?} != {:?}", self.device(), value.device()),
731            });
732        }
733
734        let target_dtype = target_uop.dtype();
735        let value_dtype = value.uop().dtype();
736        if target_dtype != value_dtype {
737            return Err(Error::TypeMismatch { expected: target_dtype, actual: value_dtype });
738        }
739
740        let value_uop = value.uop();
741        if Arc::ptr_eq(&target_uop, &value_uop) {
742            return Ok(());
743        }
744
745        let assign_effect = target_uop.after(smallvec![target_uop.store(value_uop)]);
746        let base = target_uop.base();
747        if matches!(base.op(), Op::Buffer { .. } | Op::After { .. })
748            && target_uop.id != base.id
749            && !target_uop.has_buffer_identity()
750        {
751            let identity = find_assign_identity(&target_uop, &base);
752            let assigned_identity = identity.after(smallvec![assign_effect]);
753            #[allow(clippy::mutable_key_type)]
754            let mut becomes_map = HashMap::new();
755            becomes_map.insert(UOpKey(identity), assigned_identity);
756            // Walk semantics required: replacement contains the original key
757            // (`After(Buffer, [...])` wraps `Buffer`). A re-traversing rewrite
758            // would loop or wrap the buffer multiple times.
759            tensor_registry::apply_map_to_tensors_walk(&becomes_map);
760        } else {
761            self.set_uop(assign_effect);
762        }
763        Ok(())
764    }
765
766    pub fn assign(&self, value: &Tensor) {
767        self.try_assign(value).expect("tensor assign failed");
768    }
769
770    /// Update the UOp for this tensor directly.
771    ///
772    /// This is used internally after realization to update the tensor's UOp
773    /// to point to the materialized buffer.
774    pub(crate) fn set_uop(&self, uop: Arc<UOp>) {
775        *self.entry.uop.write() = uop;
776    }
777
778    /// Ensure this tensor has contiguous memory layout.
779    ///
780    /// Creates a CONTIGUOUS UOp that forces materialization when realized.
781    /// Following Tinygrad's approach, calling `.contiguous().realize()` on
782    /// a pure constant tensor will create an actual buffer.
783    ///
784    /// # Examples
785    /// ```ignore
786    /// // Force a constant to be materialized
787    /// let mut c = Tensor::const_(5.0f32, DType::Float32).contiguous();
788    /// c.realize()?;
789    /// assert!(c.buffer().is_some());
790    /// ```
791    pub fn contiguous(&self) -> Self {
792        let uop = self.uop();
793        if matches!(uop.op(), svod_ir::Op::Contiguous { .. }) {
794            return self.clone();
795        }
796        let contiguous_uop = uop.contiguous();
797        Self::new(contiguous_uop)
798    }
799}
800
801impl Tensor {
802    /// Helper to broadcast a scalar constant to match this tensor's shape.
803    pub(crate) fn broadcast_scalar(&self, value: ConstValue) -> Result<Self> {
804        let shape = self.shape()?;
805        let scalar = Self::new(UOp::const_(self.uop().dtype(), value));
806        scalar.broadcast_to(&shape)
807    }
808
809    /// Broadcast a dtype-aware zero to match this tensor's shape.
810    pub fn zero(&self) -> Result<Self> {
811        let sdtype = self.uop().dtype().scalar().expect("scalar dtype");
812        self.broadcast_scalar(ConstValue::zero(sdtype))
813    }
814
815    /// Broadcast a dtype-aware one to match this tensor's shape.
816    pub fn one(&self) -> Result<Self> {
817        let sdtype = self.uop().dtype().scalar().expect("scalar dtype");
818        self.broadcast_scalar(ConstValue::one(sdtype))
819    }
820
821    /// Identity matrix of shape `[n, m]` with the given dtype.
822    pub fn eye(n: usize, m: usize, dtype: DType) -> Result<Self> {
823        let rows = Self::arange(n as i64, None, None)?.try_unsqueeze(-1)?;
824        let cols = Self::arange(m as i64, None, None)?;
825        rows.try_eq(&cols)?.cast(dtype)
826    }
827}
828
829#[bon]
830impl Tensor {
831    /// Cumulative sum with exclusive and reverse options.
832    #[builder]
833    pub fn cumsum_with(
834        &self,
835        axis: isize,
836        #[builder(default = false)] exclusive: bool,
837        #[builder(default = false)] reverse: bool,
838    ) -> Result<Self> {
839        let shape = self.shape()?;
840        let ndim = shape.len();
841        let axis_idx = Self::normalize_axis(axis, ndim)?;
842        let mut result = self.clone();
843        if reverse {
844            result = result.flip(&[axis_idx as isize])?;
845        }
846        if exclusive {
847            let dim_size = shape[axis_idx].as_const().unwrap() as isize;
848            let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); ndim];
849            pad_spec[axis_idx] = (1, 0);
850            result = result.try_pad(&pad_spec)?;
851            let mut shrink_spec: Vec<(isize, isize)> =
852                result.shape()?.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
853            shrink_spec[axis_idx] = (0, dim_size);
854            result = result.try_shrink(&shrink_spec)?;
855        }
856        result = result.cumsum(axis_idx as isize)?;
857        if reverse {
858            result = result.flip(&[axis_idx as isize])?;
859        }
860        Ok(result)
861    }
862
863    /// Cumulative product with exclusive and reverse options.
864    #[builder]
865    pub fn cumprod_with(
866        &self,
867        axis: isize,
868        #[builder(default = false)] exclusive: bool,
869        #[builder(default = false)] reverse: bool,
870    ) -> Result<Self> {
871        let shape = self.shape()?;
872        let ndim = shape.len();
873        let axis_idx = Self::normalize_axis(axis, ndim)?;
874        let mut result = self.clone();
875        if reverse {
876            result = result.flip(&[axis_idx as isize])?;
877        }
878        if exclusive {
879            let dim_size = shape[axis_idx].as_const().unwrap() as isize;
880            let mut pad_spec: Vec<(isize, isize)> = vec![(0, 0); ndim];
881            pad_spec[axis_idx] = (1, 0);
882            result = result.try_pad_value(&pad_spec, 1.0)?;
883            let mut shrink_spec: Vec<(isize, isize)> =
884                result.shape()?.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
885            shrink_spec[axis_idx] = (0, dim_size);
886            result = result.try_shrink(&shrink_spec)?;
887        }
888        result = result.cumprod(axis_idx as isize)?;
889        if reverse {
890            result = result.flip(&[axis_idx as isize])?;
891        }
892        Ok(result)
893    }
894}
895
896#[cfg(test)]
897mod test;