svod_tensor/variable.rs
//! Symbolic variables for dynamic tensor dimensions.
//!
//! Variables allow tensor shapes to contain symbolic dimensions (e.g., batch size,
//! sequence length) that are resolved to concrete values at execution time.
//!
//! # Tinygrad Alignment
//!
//! Matches Tinygrad's `Variable = UOp` where `Variable("i", 1, 10)` creates
//! a `DEFINE_VAR` UOp and `.bind(val)` produces a `BIND(DEFINE_VAR, CONST)` UOp.
//!
//! # Example
//!
//! ```ignore
//! use svod_tensor::{Variable, Tensor};
//! use svod_dtype::DType;
//!
//! let batch = Variable::new("batch", 1, 32);
//! let bound = batch.bind(16)?;
//!
//! let mut x = Tensor::full_dynamic(&[bound.as_sint(), 784.into()], 0.0, DType::Float32)?;
//! x.realize()?;
//! ```
23
24use std::sync::Arc;
25
26use svod_dtype::DType;
27use svod_ir::{ConstValue, Op, SInt, UOp};
28
29use crate::error::{Result, VariableOutOfRangeSnafu};
30use snafu::ensure;
31
32/// A symbolic variable for dynamic tensor dimensions.
33///
34/// Wraps a `DEFINE_VAR` UOp with known bounds `[min_val, max_val]`.
35/// Variables are created unbound, then bound to concrete values via [`bind()`](Self::bind).
36///
37/// The same `Variable` can be bound to different values for different executions,
38/// enabling dynamic batch sizes, sequence lengths, etc.
39#[derive(Clone)]
40pub struct Variable {
41 uop: Arc<UOp>,
42}
43
44impl Variable {
45 /// Create a new symbolic variable with inclusive bounds.
46 ///
47 /// # Arguments
48 ///
49 /// * `name` - Unique variable name (used as kernel parameter name)
50 /// * `min_val` - Minimum allowed value (inclusive)
51 /// * `max_val` - Maximum allowed value (inclusive)
52 ///
53 /// # Panics
54 ///
55 /// Panics if `min_val > max_val`.
56 pub fn new(name: &str, min_val: i64, max_val: i64) -> Self {
57 assert!(min_val <= max_val, "Variable '{name}': min_val ({min_val}) > max_val ({max_val})");
58 Self { uop: UOp::define_var(name.to_string(), min_val, max_val) }
59 }
60
61 /// Bind this variable to a concrete value.
62 ///
63 /// Returns a [`BoundVariable`] whose [`as_sint()`](BoundVariable::as_sint) can be
64 /// used as a tensor dimension.
65 ///
66 /// # Errors
67 ///
68 /// Returns [`VariableOutOfRange`](crate::error::Error::VariableOutOfRange) if
69 /// `val` is outside `[min_val, max_val]`.
70 pub fn bind(&self, val: i64) -> Result<BoundVariable> {
71 let (min, max) = self.bounds();
72 ensure!(val >= min && val <= max, VariableOutOfRangeSnafu { name: self.name().to_string(), val, min, max });
73 let val_uop = UOp::const_(DType::Index, ConstValue::Int(val));
74 let bind_uop = self.uop.bind(val_uop);
75 Ok(BoundVariable { var: self.clone(), value: val, uop: bind_uop })
76 }
77
78 /// Variable name.
79 pub fn name(&self) -> &str {
80 match self.uop.op() {
81 Op::DefineVar { name, .. } => name.as_str(),
82 _ => unreachable!("Variable always wraps DefineVar"),
83 }
84 }
85
86 /// Inclusive bounds `(min_val, max_val)`.
87 pub fn bounds(&self) -> (i64, i64) {
88 match self.uop.op() {
89 Op::DefineVar { min_val, max_val, .. } => (*min_val, *max_val),
90 _ => unreachable!("Variable always wraps DefineVar"),
91 }
92 }
93
94 /// Get the underlying `DEFINE_VAR` UOp as an `SInt`.
95 ///
96 /// This is useful for constructing shapes that use the variable's max value
97 /// for buffer allocation (unbound variable → allocate to max).
98 pub fn as_sint(&self) -> SInt {
99 SInt::Symbolic(self.uop.clone())
100 }
101
102 /// Get the underlying `DEFINE_VAR` UOp.
103 pub fn uop(&self) -> &Arc<UOp> {
104 &self.uop
105 }
106}
107
108impl std::fmt::Debug for Variable {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 let (min, max) = self.bounds();
111 write!(f, "Variable({:?}, {}, {})", self.name(), min, max)
112 }
113}
114
115/// A variable bound to a concrete value.
116///
117/// Created by [`Variable::bind()`]. Use [`as_sint()`](Self::as_sint) to get an `SInt`
118/// suitable as a tensor dimension.
119#[derive(Clone)]
120pub struct BoundVariable {
121 var: Variable,
122 value: i64,
123 uop: Arc<UOp>,
124}
125
126impl BoundVariable {
127 /// Get an `SInt` representing this bound variable.
128 ///
129 /// The returned `SInt::Symbolic` contains `BIND(DEFINE_VAR, CONST(value))`,
130 /// which flows through the existing shape infrastructure (reshape, permute,
131 /// expand, binary ops all handle `SInt::Symbolic`).
132 pub fn as_sint(&self) -> SInt {
133 SInt::Symbolic(self.uop.clone())
134 }
135
136 /// The bound concrete value.
137 pub fn value(&self) -> i64 {
138 self.value
139 }
140
141 /// The underlying variable definition.
142 pub fn variable(&self) -> &Variable {
143 &self.var
144 }
145
146 /// Decompose into the variable and its bound value.
147 pub fn unbind(self) -> (Variable, i64) {
148 (self.var, self.value)
149 }
150
151 /// Get `(name, value)` pair for use with `ExecutionPlan::execute_with_vars`.
152 pub fn as_var_val(&self) -> (&str, i64) {
153 (self.variable().name(), self.value())
154 }
155
156 /// Get the underlying `BIND` UOp.
157 pub fn uop(&self) -> &Arc<UOp> {
158 &self.uop
159 }
160}
161
162impl std::fmt::Debug for BoundVariable {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 write!(f, "BoundVariable({:?} = {})", self.var.name(), self.value)
165 }
166}
167
168// Allow using BoundVariable directly as SInt in shape expressions
169impl From<BoundVariable> for SInt {
170 fn from(bv: BoundVariable) -> SInt {
171 bv.as_sint()
172 }
173}
174
175impl From<&BoundVariable> for SInt {
176 fn from(bv: &BoundVariable) -> SInt {
177 bv.as_sint()
178 }
179}