Skip to main content

svod_tensor/
variable.rs

1//! Symbolic variables for dynamic tensor dimensions.
2//!
3//! Variables allow tensor shapes to contain symbolic dimensions (e.g., batch size,
4//! sequence length) that are resolved to concrete values at execution time.
5//!
6//! # Tinygrad Alignment
7//!
8//! Matches Tinygrad's `Variable = UOp` where `Variable("i", 1, 10)` creates
9//! a `DEFINE_VAR` UOp and `.bind(val)` produces a `BIND(DEFINE_VAR, CONST)` UOp.
10//!
11//! # Example
12//!
13//! ```ignore
14//! use svod_tensor::{Variable, Tensor};
15//! use svod_dtype::DType;
16//!
17//! let batch = Variable::new("batch", 1, 32);
18//! let bound = batch.bind(16)?;
19//!
20//! let mut x = Tensor::full_dynamic(&[bound.as_sint(), 784.into()], 0.0, DType::Float32)?;
21//! x.realize()?;
22//! ```
23
24use std::sync::Arc;
25
26use svod_dtype::DType;
27use svod_ir::{ConstValue, Op, SInt, UOp};
28
29use crate::error::{Result, VariableOutOfRangeSnafu};
30use snafu::ensure;
31
32/// A symbolic variable for dynamic tensor dimensions.
33///
34/// Wraps a `DEFINE_VAR` UOp with known bounds `[min_val, max_val]`.
35/// Variables are created unbound, then bound to concrete values via [`bind()`](Self::bind).
36///
37/// The same `Variable` can be bound to different values for different executions,
38/// enabling dynamic batch sizes, sequence lengths, etc.
39#[derive(Clone)]
40pub struct Variable {
41    uop: Arc<UOp>,
42}
43
44impl Variable {
45    /// Create a new symbolic variable with inclusive bounds.
46    ///
47    /// # Arguments
48    ///
49    /// * `name` - Unique variable name (used as kernel parameter name)
50    /// * `min_val` - Minimum allowed value (inclusive)
51    /// * `max_val` - Maximum allowed value (inclusive)
52    ///
53    /// # Panics
54    ///
55    /// Panics if `min_val > max_val`.
56    pub fn new(name: &str, min_val: i64, max_val: i64) -> Self {
57        assert!(min_val <= max_val, "Variable '{name}': min_val ({min_val}) > max_val ({max_val})");
58        Self { uop: UOp::define_var(name.to_string(), min_val, max_val) }
59    }
60
61    /// Bind this variable to a concrete value.
62    ///
63    /// Returns a [`BoundVariable`] whose [`as_sint()`](BoundVariable::as_sint) can be
64    /// used as a tensor dimension.
65    ///
66    /// # Errors
67    ///
68    /// Returns [`VariableOutOfRange`](crate::error::Error::VariableOutOfRange) if
69    /// `val` is outside `[min_val, max_val]`.
70    pub fn bind(&self, val: i64) -> Result<BoundVariable> {
71        let (min, max) = self.bounds();
72        ensure!(val >= min && val <= max, VariableOutOfRangeSnafu { name: self.name().to_string(), val, min, max });
73        let val_uop = UOp::const_(DType::Index, ConstValue::Int(val));
74        let bind_uop = self.uop.bind(val_uop);
75        Ok(BoundVariable { var: self.clone(), value: val, uop: bind_uop })
76    }
77
78    /// Variable name.
79    pub fn name(&self) -> &str {
80        match self.uop.op() {
81            Op::DefineVar { name, .. } => name.as_str(),
82            _ => unreachable!("Variable always wraps DefineVar"),
83        }
84    }
85
86    /// Inclusive bounds `(min_val, max_val)`.
87    pub fn bounds(&self) -> (i64, i64) {
88        match self.uop.op() {
89            Op::DefineVar { min_val, max_val, .. } => (*min_val, *max_val),
90            _ => unreachable!("Variable always wraps DefineVar"),
91        }
92    }
93
94    /// Get the underlying `DEFINE_VAR` UOp as an `SInt`.
95    ///
96    /// This is useful for constructing shapes that use the variable's max value
97    /// for buffer allocation (unbound variable → allocate to max).
98    pub fn as_sint(&self) -> SInt {
99        SInt::Symbolic(self.uop.clone())
100    }
101
102    /// Get the underlying `DEFINE_VAR` UOp.
103    pub fn uop(&self) -> &Arc<UOp> {
104        &self.uop
105    }
106}
107
108impl std::fmt::Debug for Variable {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        let (min, max) = self.bounds();
111        write!(f, "Variable({:?}, {}, {})", self.name(), min, max)
112    }
113}
114
115/// A variable bound to a concrete value.
116///
117/// Created by [`Variable::bind()`]. Use [`as_sint()`](Self::as_sint) to get an `SInt`
118/// suitable as a tensor dimension.
119#[derive(Clone)]
120pub struct BoundVariable {
121    var: Variable,
122    value: i64,
123    uop: Arc<UOp>,
124}
125
126impl BoundVariable {
127    /// Get an `SInt` representing this bound variable.
128    ///
129    /// The returned `SInt::Symbolic` contains `BIND(DEFINE_VAR, CONST(value))`,
130    /// which flows through the existing shape infrastructure (reshape, permute,
131    /// expand, binary ops all handle `SInt::Symbolic`).
132    pub fn as_sint(&self) -> SInt {
133        SInt::Symbolic(self.uop.clone())
134    }
135
136    /// The bound concrete value.
137    pub fn value(&self) -> i64 {
138        self.value
139    }
140
141    /// The underlying variable definition.
142    pub fn variable(&self) -> &Variable {
143        &self.var
144    }
145
146    /// Decompose into the variable and its bound value.
147    pub fn unbind(self) -> (Variable, i64) {
148        (self.var, self.value)
149    }
150
151    /// Get `(name, value)` pair for use with `ExecutionPlan::execute_with_vars`.
152    pub fn as_var_val(&self) -> (&str, i64) {
153        (self.variable().name(), self.value())
154    }
155
156    /// Get the underlying `BIND` UOp.
157    pub fn uop(&self) -> &Arc<UOp> {
158        &self.uop
159    }
160}
161
162impl std::fmt::Debug for BoundVariable {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        write!(f, "BoundVariable({:?} = {})", self.var.name(), self.value)
165    }
166}
167
168// Allow using BoundVariable directly as SInt in shape expressions
169impl From<BoundVariable> for SInt {
170    fn from(bv: BoundVariable) -> SInt {
171        bv.as_sint()
172    }
173}
174
175impl From<&BoundVariable> for SInt {
176    fn from(bv: &BoundVariable) -> SInt {
177        bv.as_sint()
178    }
179}