svod_tensor/conditional.rs
//! Conditional and selection operations for tensors.
//!
//! This module provides element-wise conditional operations — `where_`, `maximum`,
//! `minimum`, `clamp`, and `clip` — that many ML operations are built on.
5
6use bon::bon;
7use snafu::ResultExt;
8use svod_ir::UOp;
9
10use crate::{Result, Tensor, error::UOpSnafu};
11
12#[bon]
13impl Tensor {
14 /// Element-wise conditional selection: `condition ? self : other`.
15 ///
16 /// For each element, returns `self[i]` if `condition[i]` is true, else `other[i]`.
17 ///
18 /// # Arguments
19 /// * `condition` - Boolean tensor (dtype should be Bool or will be treated as boolean)
20 /// * `other` - Alternative value tensor
21 ///
22 /// # Shape Requirements
23 /// All three tensors (self, condition, other) must be broadcastable to the same shape.
24 ///
25 /// # Examples
26 /// ```ignore
27 /// let x = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
28 /// let condition = &x.gt(&Tensor::from_slice(&[2.0f32]))?; // [false, false, true, true]
29 /// let zeros = Tensor::from_slice(&[0.0f32]);
30 ///
31 /// // Replace values > 2.0 with the original value, else 0
32 /// let result = x.where_(condition, &zeros)?;
33 /// // result = [0.0, 0.0, 3.0, 4.0]
34 /// ```
35 pub fn where_(&self, condition: &Tensor, other: &Tensor) -> Result<Self> {
36 use svod_ir::shape::{align_shapes_left, broadcast_shapes};
37
38 let cond_shape = condition.shape()?;
39 let self_shape = self.shape()?;
40 let other_shape = other.shape()?;
41
42 // Broadcast all three to a common shape
43 let aligned = align_shapes_left(&[cond_shape.clone(), self_shape.clone(), other_shape.clone()]);
44 let target = broadcast_shapes(&aligned).context(UOpSnafu)?;
45
46 let cond_bc = condition.broadcast_to(&target)?;
47 let self_bc = self.broadcast_to(&target)?;
48 let other_bc = other.broadcast_to(&target)?;
49
50 let result = UOp::try_where(cond_bc.uop(), self_bc.uop(), other_bc.uop()).context(UOpSnafu)?;
51 Ok(Self::new(result))
52 }
53
54 /// Element-wise maximum: `max(self, other)`.
55 ///
56 /// Returns the element-wise maximum of two tensors.
57 /// This is NOT a reduction - it returns a tensor of the same shape.
58 ///
59 /// # Shape Requirements
60 /// Both tensors must be broadcastable to the same shape.
61 ///
62 /// # Examples
63 /// ```ignore
64 /// let a = Tensor::from_slice(&[1.0f32, 5.0, 3.0]);
65 /// let b = Tensor::from_slice(&[2.0f32, 3.0, 4.0]);
66 /// let result = a.maximum(&b)?;
67 /// // result = [2.0, 5.0, 4.0]
68 /// ```
69 pub fn maximum(&self, other: &Tensor) -> Result<Self> {
70 let (lhs, rhs) = self.broadcast_for_binop(other)?;
71 let result = lhs.uop().try_max(&rhs.uop()).context(UOpSnafu)?;
72 Ok(Self::new(result))
73 }
74
75 /// Element-wise minimum: `min(self, other)`.
76 ///
77 /// Returns the element-wise minimum of two tensors.
78 /// This is NOT a reduction - it returns a tensor of the same shape.
79 ///
80 /// # Shape Requirements
81 /// Both tensors must be broadcastable to the same shape.
82 ///
83 /// # Examples
84 /// ```ignore
85 /// let a = Tensor::from_slice(&[1.0f32, 5.0, 3.0]);
86 /// let b = Tensor::from_slice(&[2.0f32, 3.0, 4.0]);
87 /// let result = a.minimum(&b)?;
88 /// // result = [1.0, 3.0, 3.0]
89 /// ```
90 pub fn minimum(&self, other: &Tensor) -> Result<Self> {
91 // Minimum is not a primitive, we implement it as: -max(-a, -b)
92 // Or equivalently: where(a < b, a, b)
93 let condition = self.try_lt(other)?;
94 self.where_(&condition, other)
95 }
96
97 /// Clamp values to a range: `max(min_val, min(self, max_val))`.
98 ///
99 /// Constrains all elements to be within [min_val, max_val].
100 ///
101 /// # Examples
102 /// ```ignore
103 /// let x = Tensor::from_slice(&[-1.0f32, 0.0, 1.0, 2.0, 3.0]);
104 /// let min = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0]);
105 /// let max = Tensor::from_slice(&[2.0f32, 2.0, 2.0, 2.0, 2.0]);
106 ///
107 /// // Clamp to [0, 2]
108 /// let result = x.clamp().min(&min).max(&max).call()?;
109 /// // result = [0.0, 0.0, 1.0, 2.0, 2.0]
110 ///
111 /// // Clamp only lower bound
112 /// let result = x.clamp().min(&min).call()?;
113 /// // result = [0.0, 0.0, 1.0, 2.0, 3.0]
114 ///
115 /// // Clamp only upper bound
116 /// let result = x.clamp().max(&max).call()?;
117 /// // result = [-1.0, 0.0, 1.0, 2.0, 2.0]
118 /// ```
119 #[builder]
120 pub fn clamp(&self, min: Option<&Tensor>, max: Option<&Tensor>) -> Result<Self> {
121 let mut result = self.clone();
122
123 if let Some(min_val) = min {
124 result = result.maximum(min_val)?;
125 }
126
127 if let Some(max_val) = max {
128 result = result.minimum(max_val)?;
129 }
130
131 Ok(result)
132 }
133
134 /// Alias for `clamp` (matches NumPy/PyTorch naming).
135 ///
136 /// # Examples
137 /// ```ignore
138 /// let x = Tensor::from_slice(&[-1.0f32, 0.0, 1.0, 2.0, 3.0]);
139 /// let min = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0]);
140 /// let max = Tensor::from_slice(&[2.0f32, 2.0, 2.0, 2.0, 2.0]);
141 ///
142 /// // Clip to [0, 2]
143 /// let result = x.clip().min(&min).max(&max).call()?;
144 /// ```
145 #[builder]
146 pub fn clip(&self, min: Option<&Tensor>, max: Option<&Tensor>) -> Result<Self> {
147 self.clamp().maybe_min(min).maybe_max(max).call()
148 }
149}