1use torsh_core::{device::DeviceType, dtype::TensorElement, error::Result};
34
35use crate::Tensor;
36
37pub struct TensorComprehension<T: TensorElement> {
39 elements: Vec<T>,
40 shape: Vec<usize>,
41 device: DeviceType,
42}
43
44impl<T: TensorElement + Copy> TensorComprehension<T> {
45 pub fn new() -> Self {
47 Self {
48 elements: Vec::new(),
49 shape: Vec::new(),
50 device: DeviceType::Cpu,
51 }
52 }
53
54 pub fn device(mut self, device: DeviceType) -> Self {
56 self.device = device;
57 self
58 }
59
60 pub fn from_iter<I>(mut self, iter: I) -> Self
62 where
63 I: IntoIterator<Item = T>,
64 {
65 self.elements = iter.into_iter().collect();
66 if self.shape.is_empty() {
67 self.shape = vec![self.elements.len()];
68 }
69 self
70 }
71
72 pub fn from_iter_with_shape<I>(mut self, iter: I, shape: Vec<usize>) -> Self
74 where
75 I: IntoIterator<Item = T>,
76 {
77 self.elements = iter.into_iter().collect();
78 self.shape = shape;
79 self
80 }
81
82 pub fn build(self) -> Result<Tensor<T>> {
84 Tensor::from_data(self.elements, self.shape, self.device)
85 }
86}
87
88impl<T: TensorElement + Copy> Default for TensorComprehension<T> {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94pub fn range_tensor<T>(start: T, end: T, step: T, device: DeviceType) -> Result<Tensor<T>>
96where
97 T: TensorElement + Copy + std::ops::Add<Output = T> + std::cmp::PartialOrd + num_traits::Zero,
98{
99 let mut elements = Vec::new();
100 let mut current = start;
101
102 if step == <T as torsh_core::TensorElement>::zero() {
103 return Err(torsh_core::error::TorshError::InvalidArgument(
104 "Step cannot be zero".to_string(),
105 ));
106 }
107
108 let ascending = start < end;
110 while (ascending && current < end) || (!ascending && current > end) {
111 elements.push(current);
112 current = current + step;
113 }
114
115 let len = elements.len();
116 Tensor::from_data(elements, vec![len], device)
117}
118
119pub fn linspace_range<T>(
121 start: f64,
122 end: f64,
123 steps: usize,
124 device: DeviceType,
125) -> Result<Tensor<T>>
126where
127 T: TensorElement + Copy + num_traits::FromPrimitive,
128{
129 if steps == 0 {
130 return Err(torsh_core::error::TorshError::InvalidArgument(
131 "Steps must be greater than 0".to_string(),
132 ));
133 }
134
135 let step = if steps == 1 {
136 0.0
137 } else {
138 (end - start) / (steps - 1) as f64
139 };
140
141 let elements: Vec<T> = (0..steps)
142 .map(|i| {
143 let val = start + step * i as f64;
144 <T as torsh_core::TensorElement>::from_f64(val)
145 .unwrap_or_else(|| <T as torsh_core::TensorElement>::zero())
146 })
147 .collect();
148
149 Tensor::from_data(elements, vec![steps], device)
150}
151
152pub fn logspace<T>(
154 start: f64,
155 end: f64,
156 steps: usize,
157 base: f64,
158 device: DeviceType,
159) -> Result<Tensor<T>>
160where
161 T: TensorElement + Copy + num_traits::FromPrimitive,
162{
163 if steps == 0 {
164 return Err(torsh_core::error::TorshError::InvalidArgument(
165 "Steps must be greater than 0".to_string(),
166 ));
167 }
168
169 let step = if steps == 1 {
170 0.0
171 } else {
172 (end - start) / (steps - 1) as f64
173 };
174
175 let elements: Vec<T> = (0..steps)
176 .map(|i| {
177 let exponent = start + step * i as f64;
178 let val = base.powf(exponent);
179 <T as torsh_core::TensorElement>::from_f64(val)
180 .unwrap_or_else(|| <T as torsh_core::TensorElement>::zero())
181 })
182 .collect();
183
184 Tensor::from_data(elements, vec![steps], device)
185}
186
187pub fn meshgrid<T>(x: &Tensor<T>, y: &Tensor<T>) -> Result<(Tensor<T>, Tensor<T>)>
189where
190 T: TensorElement + Copy,
191{
192 let x_data = x.to_vec()?;
193 let y_data = y.to_vec()?;
194
195 let nx = x_data.len();
196 let ny = y_data.len();
197
198 let mut x_grid = Vec::with_capacity(nx * ny);
200 for &x_val in &x_data {
201 for _ in 0..ny {
202 x_grid.push(x_val);
203 }
204 }
205
206 let mut y_grid = Vec::with_capacity(nx * ny);
208 for _ in 0..nx {
209 for &y_val in &y_data {
210 y_grid.push(y_val);
211 }
212 }
213
214 let x_tensor = Tensor::from_data(x_grid, vec![nx, ny], x.device)?;
215 let y_tensor = Tensor::from_data(y_grid, vec![nx, ny], y.device)?;
216
217 Ok((x_tensor, y_tensor))
218}
219
/// Builds a CPU tensor from a comprehension-style expression.
///
/// Every form evaluates to the `Result` returned by `Tensor::from_data`:
///
/// * `tensor_comp!(expr; x in start, end)` gives a 1-D tensor, with `x` over
///   `start..end`.
/// * `tensor_comp!(expr; x in start, end, step s)` gives a 1-D tensor, with
///   `x = start, start + s, ...` while `x < end`. `s` must move `x` towards
///   `end`, otherwise the loop never finishes.
/// * `tensor_comp!(expr; x in start, end, if cond)` gives a 1-D tensor of the
///   values of `x` for which `cond` holds.
/// * `tensor_comp!([expr; j in js, je]; i in is, ie)` gives a 2-D
///   `[rows, cols]` tensor, with `i` over rows and `j` over columns.
#[macro_export]
macro_rules! tensor_comp {
    // The nested form must come first. The `$expr:expr` arms below commit to
    // parsing `[expr; j in ...]` as an array expression and fail with a parse
    // error instead of falling through to this arm.
    ([$expr:expr; $inner_var:ident in $inner_start:expr, $inner_end:expr]; $outer_var:ident in $outer_start:expr, $outer_end:expr) => {{
        // Evaluate each bound exactly once. Counting the ranges (instead of
        // subtracting the bounds) keeps the shape `usize` for any integer type,
        // and gives 0 for an empty or reversed range instead of underflowing.
        let outer_range = $outer_start..$outer_end;
        let inner_range = $inner_start..$inner_end;
        let rows = outer_range.clone().count();
        let cols = inner_range.clone().count();
        let mut all_elements = Vec::with_capacity(rows * cols);
        for $outer_var in outer_range {
            for $inner_var in inner_range.clone() {
                all_elements.push($expr);
            }
        }
        $crate::Tensor::from_data(all_elements, vec![rows, cols], $crate::DeviceType::Cpu)
    }};

    ($expr:expr; $var:ident in $start:expr, $end:expr) => {{
        let elements: Vec<_> = ($start..$end).map(|$var| $expr).collect();
        // Read the length before `elements` is moved into `from_data`.
        let len = elements.len();
        $crate::Tensor::from_data(elements, vec![len], $crate::DeviceType::Cpu)
    }};

    ($expr:expr; $var:ident in $start:expr, $end:expr, step $step:expr) => {{
        let mut elements = Vec::new();
        let mut $var = $start;
        while $var < $end {
            elements.push($expr);
            $var = $var + $step;
        }
        let len = elements.len();
        $crate::Tensor::from_data(elements, vec![len], $crate::DeviceType::Cpu)
    }};

    ($expr:expr; $var:ident in $start:expr, $end:expr, if $cond:expr) => {{
        let elements: Vec<_> = ($start..$end)
            .filter(|&$var| $cond)
            .map(|$var| $expr)
            .collect();
        let len = elements.len();
        $crate::Tensor::from_data(elements, vec![len], $crate::DeviceType::Cpu)
    }};
}
263
/// Builds a CPU tensor filled with copies of one value.
///
/// * `tensor_repeat!(value; [d0, d1, ...])` gives a tensor with that shape.
/// * `tensor_repeat!(value; count)` gives a 1-D tensor of `count` elements.
///
/// Both forms evaluate to the `Result` returned by `Tensor::from_data`.
#[macro_export]
macro_rules! tensor_repeat {
    // The shaped form must come first. `[d0, d1]` is itself a valid expression,
    // so the `$count:expr` arm would otherwise accept it and fail to type-check.
    ($value:expr; [$($dim:expr),+]) => {{
        let shape = vec![$($dim),+];
        let size: usize = shape.iter().product();
        let elements = vec![$value; size];
        $crate::Tensor::from_data(elements, shape, $crate::DeviceType::Cpu)
    }};

    ($value:expr; $count:expr) => {{
        // Evaluate the count once; it is needed for both the data and the shape.
        let count: usize = $count;
        let elements = vec![$value; count];
        $crate::Tensor::from_data(elements, vec![count], $crate::DeviceType::Cpu)
    }};
}
281
/// Builds an `f32` identity-like matrix on the CPU.
///
/// * `tensor_eye!(n)` gives an `n x n` identity.
/// * `tensor_eye!(m, n)` gives an `m x n` matrix with ones on the main diagonal.
/// * `tensor_eye!(m, n, offset k)` puts the ones on diagonal `k`: `k > 0` is
///   above the main diagonal, `k < 0` below it. `k` must be a signed integer.
///
/// Each form evaluates to the `Result` returned by `Tensor::from_data`.
#[macro_export]
macro_rules! tensor_eye {
    ($n:expr) => {{
        // Evaluate `$n` once, then delegate. The `$crate::` path keeps the
        // recursive call resolvable when the macro is invoked by path from
        // another crate.
        let n: usize = $n;
        $crate::tensor_eye![n, n]
    }};

    ($m:expr, $n:expr) => {{
        let rows: usize = $m;
        let cols: usize = $n;
        let mut elements = vec![0.0f32; rows * cols];
        for i in 0..rows.min(cols) {
            elements[i * cols + i] = 1.0;
        }
        $crate::Tensor::from_data(elements, vec![rows, cols], $crate::DeviceType::Cpu)
    }};

    ($m:expr, $n:expr, offset $k:expr) => {{
        let rows: usize = $m;
        let cols: usize = $n;
        let offset = $k;
        let mut elements = vec![0.0f32; rows * cols];
        if offset >= 0 {
            // Upper diagonal: element (i, i + shift).
            let shift = offset as usize;
            for i in 0..rows {
                let j = i + shift;
                if j < cols {
                    elements[i * cols + j] = 1.0;
                }
            }
        } else {
            // Lower diagonal: element (j + shift, j).
            let shift = (-offset) as usize;
            for j in 0..cols {
                let i = j + shift;
                if i < rows {
                    elements[i * cols + j] = 1.0;
                }
            }
        }
        $crate::Tensor::from_data(elements, vec![rows, cols], $crate::DeviceType::Cpu)
    }};
}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::*;

    #[test]
    fn test_tensor_comprehension_builder() {
        let built = TensorComprehension::new()
            .from_iter(0..5)
            .build()
            .expect("builder should produce valid result");

        let values = built.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(values, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_range_tensor() {
        let evens = range_tensor(0, 10, 2, DeviceType::Cpu).expect("range_tensor should succeed");
        let values = evens.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(values, vec![0, 2, 4, 6, 8]);
    }

    #[test]
    fn test_linspace() {
        let spaced: Tensor<f32> =
            linspace_range(0.0, 10.0, 5, DeviceType::Cpu).expect("linspace should succeed");
        let values = spaced.to_vec().expect("to_vec conversion should succeed");

        let expected = [0.0f32, 2.5, 5.0, 7.5, 10.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert!((values[idx] - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_logspace() {
        let powers: Tensor<f32> =
            logspace(0.0, 2.0, 3, 10.0, DeviceType::Cpu).expect("logspace should succeed");
        let values = powers.to_vec().expect("to_vec conversion should succeed");

        // Tolerance grows with magnitude because f32 precision is relative.
        let expected = [(1.0f32, 1e-6f32), (10.0, 1e-5), (100.0, 1e-4)];
        for (idx, &(want, tol)) in expected.iter().enumerate() {
            assert!((values[idx] - want).abs() < tol);
        }
    }

    #[test]
    fn test_meshgrid() {
        let xs = tensor_1d(&[1.0f32, 2.0, 3.0]).expect("tensor_1d creation should succeed");
        let ys = tensor_1d(&[4.0f32, 5.0]).expect("tensor_1d creation should succeed");

        let (grid_x, grid_y) = meshgrid(&xs, &ys).expect("meshgrid should succeed");

        for grid in &[&grid_x, &grid_y] {
            assert_eq!(grid.shape().dims(), &[3, 2]);
        }

        let flat_x = grid_x.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(flat_x, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);

        let flat_y = grid_y.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(flat_y, vec![4.0, 5.0, 4.0, 5.0, 4.0, 5.0]);
    }

    #[test]
    fn test_tensor_comprehension_with_device() {
        let built = TensorComprehension::new()
            .device(DeviceType::Cpu)
            .from_iter(0..3)
            .build()
            .expect("builder should produce valid result");

        assert_eq!(built.device, DeviceType::Cpu);
    }

    #[test]
    fn test_linspace_single_step() {
        let single: Tensor<f32> =
            linspace_range(5.0, 5.0, 1, DeviceType::Cpu).expect("linspace should succeed");
        let values = single.to_vec().expect("to_vec conversion should succeed");

        assert_eq!(values.len(), 1);
        assert!((values[0] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_range_tensor_zero_step_error() {
        let outcome = range_tensor(0, 10, 0, DeviceType::Cpu);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_linspace_zero_steps_error() {
        let outcome: Result<Tensor<f32>> = linspace_range(0.0, 10.0, 0, DeviceType::Cpu);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_meshgrid_different_sizes() {
        let xs = tensor_1d(&[1.0f32, 2.0]).expect("tensor_1d creation should succeed");
        let ys = tensor_1d(&[3.0f32, 4.0, 5.0]).expect("tensor_1d creation should succeed");

        let (grid_x, grid_y) = meshgrid(&xs, &ys).expect("meshgrid should succeed");

        for grid in &[&grid_x, &grid_y] {
            assert_eq!(grid.shape().dims(), &[2, 3]);
        }
    }
}