// torsh_tensor/tensor_comprehension.rs
//! Tensor Comprehensions - Declarative Tensor Creation
//!
//! This module provides powerful macro-based tensor comprehensions similar to Python's list
//! comprehensions and NumPy's array creation syntax. It enables concise, readable tensor
//! creation with various patterns and conditions.
//!
//! # Features
//!
//! - **List-style comprehensions**: Create tensors with for-loop style syntax
//! - **Conditional filtering**: Include conditions to filter elements
//! - **Multi-dimensional**: Support for creating multi-dimensional tensors
//! - **Generator expressions**: Lazy generation of tensor elements
//! - **Type inference**: Automatic type inference from expressions
//!
//! # Examples
//!
//! Note: range bounds are written `start, end` (comma-separated), because a
//! `macro_rules!` `expr` fragment cannot be followed by `..`.
//!
//! ```rust
//! use torsh_tensor::tensor_comp;
//!
//! // Simple range: [0, 1, 2, 3, 4]
//! // let t = tensor_comp![x; x in 0, 5];
//!
//! // With transformation: [0, 2, 4, 6, 8]
//! // let t = tensor_comp![x * 2; x in 0, 5];
//!
//! // With condition: [0, 2, 4]
//! // let t = tensor_comp![x; x in 0, 5, if x % 2 == 0];
//!
//! // 2D comprehension: [[0, 1], [2, 3], [4, 5]]
//! // let t = tensor_comp![[i * 2 + j; j in 0, 2]; i in 0, 3];
//! ```

use torsh_core::{device::DeviceType, dtype::TensorElement, error::Result};

use crate::Tensor;

37/// Builder for tensor comprehensions
38pub struct TensorComprehension<T: TensorElement> {
39    elements: Vec<T>,
40    shape: Vec<usize>,
41    device: DeviceType,
42}
43
44impl<T: TensorElement + Copy> TensorComprehension<T> {
45    /// Create a new tensor comprehension builder
46    pub fn new() -> Self {
47        Self {
48            elements: Vec::new(),
49            shape: Vec::new(),
50            device: DeviceType::Cpu,
51        }
52    }
53
54    /// Set the device for the tensor
55    pub fn device(mut self, device: DeviceType) -> Self {
56        self.device = device;
57        self
58    }
59
60    /// Add elements from an iterator
61    pub fn from_iter<I>(mut self, iter: I) -> Self
62    where
63        I: IntoIterator<Item = T>,
64    {
65        self.elements = iter.into_iter().collect();
66        if self.shape.is_empty() {
67            self.shape = vec![self.elements.len()];
68        }
69        self
70    }
71
72    /// Add elements with explicit shape
73    pub fn from_iter_with_shape<I>(mut self, iter: I, shape: Vec<usize>) -> Self
74    where
75        I: IntoIterator<Item = T>,
76    {
77        self.elements = iter.into_iter().collect();
78        self.shape = shape;
79        self
80    }
81
82    /// Build the tensor
83    pub fn build(self) -> Result<Tensor<T>> {
84        Tensor::from_data(self.elements, self.shape, self.device)
85    }
86}
87
88impl<T: TensorElement + Copy> Default for TensorComprehension<T> {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94/// Helper function for creating tensors from ranges
95pub fn range_tensor<T>(start: T, end: T, step: T, device: DeviceType) -> Result<Tensor<T>>
96where
97    T: TensorElement + Copy + std::ops::Add<Output = T> + std::cmp::PartialOrd + num_traits::Zero,
98{
99    let mut elements = Vec::new();
100    let mut current = start;
101
102    if step == <T as torsh_core::TensorElement>::zero() {
103        return Err(torsh_core::error::TorshError::InvalidArgument(
104            "Step cannot be zero".to_string(),
105        ));
106    }
107
108    // Handle both positive and negative steps
109    let ascending = start < end;
110    while (ascending && current < end) || (!ascending && current > end) {
111        elements.push(current);
112        current = current + step;
113    }
114
115    let len = elements.len();
116    Tensor::from_data(elements, vec![len], device)
117}
118
119/// Helper function for creating linspace tensors
120pub fn linspace_range<T>(
121    start: f64,
122    end: f64,
123    steps: usize,
124    device: DeviceType,
125) -> Result<Tensor<T>>
126where
127    T: TensorElement + Copy + num_traits::FromPrimitive,
128{
129    if steps == 0 {
130        return Err(torsh_core::error::TorshError::InvalidArgument(
131            "Steps must be greater than 0".to_string(),
132        ));
133    }
134
135    let step = if steps == 1 {
136        0.0
137    } else {
138        (end - start) / (steps - 1) as f64
139    };
140
141    let elements: Vec<T> = (0..steps)
142        .map(|i| {
143            let val = start + step * i as f64;
144            <T as torsh_core::TensorElement>::from_f64(val)
145                .unwrap_or_else(|| <T as torsh_core::TensorElement>::zero())
146        })
147        .collect();
148
149    Tensor::from_data(elements, vec![steps], device)
150}
151
152/// Helper function for creating logspace tensors
153pub fn logspace<T>(
154    start: f64,
155    end: f64,
156    steps: usize,
157    base: f64,
158    device: DeviceType,
159) -> Result<Tensor<T>>
160where
161    T: TensorElement + Copy + num_traits::FromPrimitive,
162{
163    if steps == 0 {
164        return Err(torsh_core::error::TorshError::InvalidArgument(
165            "Steps must be greater than 0".to_string(),
166        ));
167    }
168
169    let step = if steps == 1 {
170        0.0
171    } else {
172        (end - start) / (steps - 1) as f64
173    };
174
175    let elements: Vec<T> = (0..steps)
176        .map(|i| {
177            let exponent = start + step * i as f64;
178            let val = base.powf(exponent);
179            <T as torsh_core::TensorElement>::from_f64(val)
180                .unwrap_or_else(|| <T as torsh_core::TensorElement>::zero())
181        })
182        .collect();
183
184    Tensor::from_data(elements, vec![steps], device)
185}
186
187/// Helper function for creating meshgrid-style tensors
188pub fn meshgrid<T>(x: &Tensor<T>, y: &Tensor<T>) -> Result<(Tensor<T>, Tensor<T>)>
189where
190    T: TensorElement + Copy,
191{
192    let x_data = x.to_vec()?;
193    let y_data = y.to_vec()?;
194
195    let nx = x_data.len();
196    let ny = y_data.len();
197
198    // X grid: repeat each x value ny times
199    let mut x_grid = Vec::with_capacity(nx * ny);
200    for &x_val in &x_data {
201        for _ in 0..ny {
202            x_grid.push(x_val);
203        }
204    }
205
206    // Y grid: tile y values nx times
207    let mut y_grid = Vec::with_capacity(nx * ny);
208    for _ in 0..nx {
209        for &y_val in &y_data {
210            y_grid.push(y_val);
211        }
212    }
213
214    let x_tensor = Tensor::from_data(x_grid, vec![nx, ny], x.device)?;
215    let y_tensor = Tensor::from_data(y_grid, vec![nx, ny], y.device)?;
216
217    Ok((x_tensor, y_tensor))
218}
219
/// Macro for tensor comprehensions
///
/// Range bounds are written `start, end` (comma-separated) rather than
/// `start..end`, because a `macro_rules!` `expr` fragment may only be
/// followed by `=>`, `,`, or `;` — a trailing `..` is rejected.
///
/// NOTE(review): the expansions assume `DeviceType` is re-exported at the
/// crate root (`$crate::DeviceType`) — verify against `lib.rs`.
#[macro_export]
macro_rules! tensor_comp {
    // Simple range: tensor_comp![x; x in start, end]
    ($expr:expr; $var:ident in $start:expr, $end:expr) => {{
        let elements: Vec<_> = ($start..$end).map(|$var| $expr).collect();
        // Capture the length first: `elements` is moved into `from_data`,
        // so `vec![elements.len()]` as a later argument would be a
        // use-after-move (arguments evaluate left to right) and not compile.
        let len = elements.len();
        $crate::Tensor::from_data(elements, vec![len], $crate::DeviceType::Cpu)
    }};

    // Range with step: tensor_comp![x; x in start, end, step s]
    ($expr:expr; $var:ident in $start:expr, $end:expr, step $step:expr) => {{
        let mut elements = Vec::new();
        let mut $var = $start;
        while $var < $end {
            elements.push($expr);
            $var = $var + $step;
        }
        // Same move-ordering concern as the simple-range arm.
        let len = elements.len();
        $crate::Tensor::from_data(elements, vec![len], $crate::DeviceType::Cpu)
    }};

    // With condition: tensor_comp![x; x in start, end, if condition]
    ($expr:expr; $var:ident in $start:expr, $end:expr, if $cond:expr) => {{
        let elements: Vec<_> = ($start..$end)
            .filter(|&$var| $cond)
            .map(|$var| $expr)
            .collect();
        // Same move-ordering concern as the simple-range arm.
        let len = elements.len();
        $crate::Tensor::from_data(elements, vec![len], $crate::DeviceType::Cpu)
    }};

    // 2D comprehension: tensor_comp![[expr; j in 0, n]; i in 0, m]
    ([$expr:expr; $inner_var:ident in $inner_start:expr, $inner_end:expr]; $outer_var:ident in $outer_start:expr, $outer_end:expr) => {{
        let mut all_elements = Vec::new();
        let rows = $outer_end - $outer_start;
        let cols = $inner_end - $inner_start;

        for $outer_var in $outer_start..$outer_end {
            for $inner_var in $inner_start..$inner_end {
                all_elements.push($expr);
            }
        }
        $crate::Tensor::from_data(all_elements, vec![rows, cols], $crate::DeviceType::Cpu)
    }};
}

/// Macro for creating tensors with repeated values
#[macro_export]
macro_rules! tensor_repeat {
    // Repeat with shape: tensor_repeat![value; [d1, d2, ...]]
    //
    // This arm must come FIRST: a bracketed shape like `[2, 3]` is itself a
    // valid `expr` (an array literal), so the scalar-count arm below would
    // otherwise capture it and expand to `vec![value; [2, 3]]`, which does
    // not compile.
    ($value:expr; [$($dim:expr),+]) => {{
        let shape = vec![$($dim),+];
        let size: usize = shape.iter().product();
        let elements = vec![$value; size];
        $crate::Tensor::from_data(elements, shape, $crate::DeviceType::Cpu)
    }};

    // Repeat value n times: tensor_repeat![value; n]
    ($value:expr; $count:expr) => {{
        let elements = vec![$value; $count];
        $crate::Tensor::from_data(elements, vec![$count], $crate::DeviceType::Cpu)
    }};
}

/// Macro for creating identity-like tensors
///
/// All arms produce `f32` tensors.
#[macro_export]
macro_rules! tensor_eye {
    // Identity matrix: tensor_eye![n]
    // Recurse through `$crate::` so the expansion resolves even when the
    // caller invoked the macro by path without `use torsh_tensor::tensor_eye;`
    // (a bare `tensor_eye!` in the expansion resolves at the call site).
    ($n:expr) => {{
        $crate::tensor_eye![$n, $n]
    }};

    // Rectangular identity: tensor_eye![m, n]
    ($m:expr, $n:expr) => {{
        let mut elements = vec![0.0f32; $m * $n];
        let min_dim = std::cmp::min($m, $n);
        for i in 0..min_dim {
            // Row-major index of the (i, i) diagonal entry.
            elements[i * $n + i] = 1.0;
        }
        $crate::Tensor::from_data(elements, vec![$m, $n], $crate::DeviceType::Cpu)
    }};

    // With offset: tensor_eye![m, n, offset k]
    // Positive k shifts the diagonal above the main one, negative k below,
    // matching NumPy's `eye(..., k=...)` convention.
    ($m:expr, $n:expr, offset $k:expr) => {{
        let mut elements = vec![0.0f32; $m * $n];
        if $k >= 0 {
            let k = $k as usize;
            for i in 0..$m {
                let j = i + k;
                if j < $n {
                    elements[i * $n + j] = 1.0;
                }
            }
        } else {
            let k = (-$k) as usize;
            for j in 0..$n {
                let i = j + k;
                if i < $m {
                    elements[i * $n + j] = 1.0;
                }
            }
        }
        $crate::Tensor::from_data(elements, vec![$m, $n], $crate::DeviceType::Cpu)
    }};
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::*;

    #[test]
    fn test_tensor_comprehension_builder() {
        let built = TensorComprehension::new()
            .from_iter(0..5)
            .build()
            .expect("builder should produce valid result");

        let values = built.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(values, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_range_tensor() {
        let tensor = range_tensor(0, 10, 2, DeviceType::Cpu).expect("range_tensor should succeed");
        assert_eq!(
            tensor.to_vec().expect("to_vec conversion should succeed"),
            vec![0, 2, 4, 6, 8]
        );
    }

    #[test]
    fn test_linspace() {
        let tensor: Tensor<f32> =
            linspace_range(0.0, 10.0, 5, DeviceType::Cpu).expect("linspace should succeed");
        let values = tensor.to_vec().expect("to_vec conversion should succeed");

        let expected = [0.0f32, 2.5, 5.0, 7.5, 10.0];
        for (got, want) in values.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_logspace() {
        let tensor: Tensor<f32> =
            logspace(0.0, 2.0, 3, 10.0, DeviceType::Cpu).expect("logspace should succeed");
        let values = tensor.to_vec().expect("to_vec conversion should succeed");

        // Powers of 10 with per-element tolerances scaled to magnitude.
        let expected = [(1.0f32, 1e-6f32), (10.0, 1e-5), (100.0, 1e-4)];
        for (got, (want, tol)) in values.iter().zip(expected.iter()) {
            assert!((got - want).abs() < *tol);
        }
    }

    #[test]
    fn test_meshgrid() {
        let x = tensor_1d(&[1.0f32, 2.0, 3.0]).expect("tensor_1d creation should succeed");
        let y = tensor_1d(&[4.0f32, 5.0]).expect("tensor_1d creation should succeed");

        let (x_grid, y_grid) = meshgrid(&x, &y).expect("meshgrid should succeed");

        assert_eq!(x_grid.shape().dims(), &[3, 2]);
        assert_eq!(y_grid.shape().dims(), &[3, 2]);

        assert_eq!(
            x_grid.to_vec().expect("to_vec conversion should succeed"),
            vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
        );
        assert_eq!(
            y_grid.to_vec().expect("to_vec conversion should succeed"),
            vec![4.0, 5.0, 4.0, 5.0, 4.0, 5.0]
        );
    }

    #[test]
    fn test_tensor_comprehension_with_device() {
        let built = TensorComprehension::new()
            .device(DeviceType::Cpu)
            .from_iter(0..3)
            .build()
            .expect("builder should produce valid result");

        assert_eq!(built.device, DeviceType::Cpu);
    }

    #[test]
    fn test_linspace_single_step() {
        let tensor: Tensor<f32> =
            linspace_range(5.0, 5.0, 1, DeviceType::Cpu).expect("linspace should succeed");
        let values = tensor.to_vec().expect("to_vec conversion should succeed");

        assert_eq!(values.len(), 1);
        assert!((values[0] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_range_tensor_zero_step_error() {
        assert!(range_tensor(0, 10, 0, DeviceType::Cpu).is_err());
    }

    #[test]
    fn test_linspace_zero_steps_error() {
        let outcome: Result<Tensor<f32>> = linspace_range(0.0, 10.0, 0, DeviceType::Cpu);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_meshgrid_different_sizes() {
        let x = tensor_1d(&[1.0f32, 2.0]).expect("tensor_1d creation should succeed");
        let y = tensor_1d(&[3.0f32, 4.0, 5.0]).expect("tensor_1d creation should succeed");

        let (x_grid, y_grid) = meshgrid(&x, &y).expect("meshgrid should succeed");

        assert_eq!(x_grid.shape().dims(), &[2, 3]);
        assert_eq!(y_grid.shape().dims(), &[2, 3]);
    }
}