use crate::{Tensor, TensorElement};
use torsh_core::error::Result;

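/// PyTorch-style convenience accessors for [`Tensor`]: transpose shortcuts
/// (`T`/`mT`/`H`), layout queries, and scalar extraction.
///
/// A minimal usage sketch (marked `ignore` since it relies on the
/// `crate::creation` helpers used by the tests below):
///
/// ```ignore
/// let t = crate::creation::tensor_2d_arrays(&[[1.0f32, 2.0], [3.0, 4.0]])?;
/// let tt = t.T()?;                  // swap the last two dimensions
/// assert_eq!(tt.size(), vec![2, 2]);
/// assert_eq!(t.numel(), 4);         // total number of elements
/// ```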
pub trait TensorConvenience<T: TensorElement> {
    /// Transpose of the last two dimensions (PyTorch's `.T`).
    #[allow(non_snake_case)]
    fn T(&self) -> Result<Tensor<T>>;

    /// Matrix transpose of the last two dimensions (PyTorch's `.mT`).
    #[allow(non_snake_case)]
    fn mT(&self) -> Result<Tensor<T>>;

    /// Hermitian (conjugate) transpose (PyTorch's `.H`).
    #[allow(non_snake_case)]
    fn H(&self) -> Result<Tensor<T>>;

    /// Snake-case alias for [`Self::T`].
    fn t(&self) -> Result<Tensor<T>>;

    /// Snake-case alias for [`Self::mT`].
    fn m_t(&self) -> Result<Tensor<T>>;

    /// Snake-case alias for [`Self::H`].
    fn h(&self) -> Result<Tensor<T>>;

    /// Returns a copy detached from any autograd graph.
    fn detach(&self) -> Tensor<T>;

    /// Explicit deep-copy helper.
    fn clone_tensor(&self) -> Result<Tensor<T>>;

    /// Whether the tensor is laid out contiguously in memory.
    fn is_contiguous(&self) -> bool;

    /// Returns a contiguous copy, or a clone if already contiguous.
    fn contiguous(&self) -> Result<Tensor<T>>;

    /// Total number of elements.
    fn numel(&self) -> usize;

    /// Dimension sizes as a `Vec` (PyTorch's `.size()`).
    fn size(&self) -> Vec<usize>;

    /// Whether the tensor contains zero elements.
    fn is_empty(&self) -> bool;

    /// Whether the tensor is zero-dimensional.
    fn is_scalar(&self) -> bool;

    /// Extracts the single element; panics if `numel() != 1`.
    fn item(&self) -> T;

    /// Fallible variant of [`Self::item`] that squeezes size-1 dimensions first.
    fn to_scalar(&self) -> Result<T>;
}

impl<T: TensorElement + Copy + torsh_core::FloatElement> TensorConvenience<T> for Tensor<T> {
    #[allow(non_snake_case)]
    fn T(&self) -> Result<Tensor<T>> {
        // Swap the last two dimensions; 0-D and 1-D tensors are returned as-is.
        let ndim = self.shape().dims().len();
        if ndim >= 2 {
            self.transpose((ndim - 2) as i32, (ndim - 1) as i32)
        } else {
            Ok(self.clone())
        }
    }

    #[allow(non_snake_case)]
    fn mT(&self) -> Result<Tensor<T>> {
        self.T()
    }

    #[allow(non_snake_case)]
    fn H(&self) -> Result<Tensor<T>> {
        // For real element types the conjugate is a no-op, so the Hermitian
        // transpose reduces to a plain transpose.
        self.T()
    }

    fn t(&self) -> Result<Tensor<T>> {
        self.T()
    }

    fn m_t(&self) -> Result<Tensor<T>> {
        self.T()
    }

    fn h(&self) -> Result<Tensor<T>> {
        self.H()
    }

    fn detach(&self) -> Tensor<T> {
        // No autograd graph is tracked here, so detaching is just a clone.
        self.clone()
    }

    fn clone_tensor(&self) -> Result<Tensor<T>> {
        Ok(self.detach())
    }

    fn is_contiguous(&self) -> bool {
        // Stride information is not exposed through this API, so we
        // conservatively report every tensor as contiguous.
        true
    }

    fn contiguous(&self) -> Result<Tensor<T>> {
        if self.is_contiguous() {
            Ok(self.clone())
        } else {
            self.clone_tensor()
        }
    }

    fn numel(&self) -> usize {
        self.shape().dims().iter().product()
    }

    fn size(&self) -> Vec<usize> {
        self.shape().dims().to_vec()
    }

    fn is_empty(&self) -> bool {
        self.numel() == 0
    }

    fn is_scalar(&self) -> bool {
        self.shape().dims().is_empty()
    }

    fn item(&self) -> T {
        if self.numel() != 1 {
            panic!("Can only call item() on tensors with exactly one element");
        }
        let data = self
            .to_vec()
            .expect("tensor to vec conversion should succeed");
        data[0]
    }

    fn to_scalar(&self) -> Result<T> {
        let squeezed = self.squeeze_all()?;
        Ok(squeezed.item())
    }
}

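/// Shape-manipulation helpers: squeezing, flattening, and unflattening.
///
/// A minimal sketch of `flatten_from` (marked `ignore`; it assumes the
/// `crate::creation::zeros` helper used in the tests below):
///
/// ```ignore
/// let t = crate::creation::zeros::<f32>(&[2, 3, 4])?;
/// let flat = t.flatten_from(1)?; // keep dim 0, flatten the rest
/// assert_eq!(flat.size(), vec![2, 12]);
/// ```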
pub trait TensorShapeConvenience<T: TensorElement> {
    /// Inserts a size-1 dimension at `dim` (alias for `unsqueeze`).
    fn unsqueeze_at(&self, dim: i32) -> Result<Tensor<T>>;

    /// Removes all size-1 dimensions.
    fn squeeze_all(&self) -> Result<Tensor<T>>;

    /// Flattens the tensor to one dimension.
    fn flatten(&self) -> Result<Tensor<T>>;

    /// Flattens all dimensions from `start_dim` (inclusive) into one.
    fn flatten_from(&self, start_dim: i32) -> Result<Tensor<T>>;

    /// Splits dimension `dim` into the given `sizes`.
    fn unflatten(&self, dim: i32, sizes: &[usize]) -> Result<Tensor<T>>;
}

impl<T: TensorElement + Copy> TensorShapeConvenience<T> for Tensor<T> {
    fn unsqueeze_at(&self, dim: i32) -> Result<Tensor<T>> {
        self.unsqueeze(dim)
    }

    fn squeeze_all(&self) -> Result<Tensor<T>> {
        let mut result = self.clone();
        let shape_ref = self.shape();
        let dims = shape_ref.dims();

        // Squeeze from the highest dimension down so earlier indices
        // remain valid after each removal.
        for (i, &size) in dims.iter().enumerate().rev() {
            if size == 1 {
                result = result.squeeze(i as i32)?;
            }
        }

        Ok(result)
    }

    fn flatten(&self) -> Result<Tensor<T>> {
        let total_elements = self.numel();
        self.reshape(&[total_elements as i32])
    }

    fn flatten_from(&self, start_dim: i32) -> Result<Tensor<T>> {
        let shape_ref = self.shape();
        let shape = shape_ref.dims();
        let ndim = shape.len() as i32;
        // Normalize negative indices (e.g. -1 means the last dimension).
        let start_dim = if start_dim < 0 {
            ndim + start_dim
        } else {
            start_dim
        };

        if start_dim < 0 || start_dim >= ndim {
            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
                "Invalid start_dim {start_dim} for tensor with {ndim} dimensions"
            )));
        }

        // Keep the leading dimensions, then collapse the rest into one.
        let mut new_shape = Vec::new();
        for &dim in shape.iter().take(start_dim as usize) {
            new_shape.push(dim);
        }

        let flattened_size: usize = shape[start_dim as usize..].iter().product();
        new_shape.push(flattened_size);

        let new_shape_i32: Vec<i32> = new_shape.iter().map(|&x| x as i32).collect();
        self.reshape(&new_shape_i32)
    }

    fn unflatten(&self, dim: i32, sizes: &[usize]) -> Result<Tensor<T>> {
        let shape_ref = self.shape();
        let shape = shape_ref.dims();
        let ndim = shape.len() as i32;
        let dim = if dim < 0 { ndim + dim } else { dim };

        if dim < 0 || dim >= ndim {
            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
                "Invalid dim {dim} for tensor with {ndim} dimensions"
            )));
        }

        // The new sizes must multiply to exactly the size of the split dimension.
        let expected_size = shape[dim as usize];
        let actual_size: usize = sizes.iter().product();

        if expected_size != actual_size {
            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
                "Product of sizes ({actual_size}) does not match dimension size {expected_size}"
            )));
        }

        // Splice `sizes` in place of dimension `dim`.
        let mut new_shape = Vec::new();
        for &dim_size in shape.iter().take(dim as usize) {
            new_shape.push(dim_size);
        }
        new_shape.extend_from_slice(sizes);
        for &dim_size in shape.iter().skip(dim as usize + 1) {
            new_shape.push(dim_size);
        }

        let new_shape_i32: Vec<i32> = new_shape.iter().map(|&x| x as i32).collect();
        self.reshape(&new_shape_i32)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transpose_shortcuts() {
        let tensor = crate::creation::tensor_2d_arrays(&[[1.0f32, 2.0], [3.0, 4.0]])
            .expect("tensor creation failed");

        let transposed = tensor.T().expect("T() failed");
        assert_eq!(transposed.shape().dims(), &[2, 2]);

        let mt_transposed = tensor.mT().expect("mT() failed");
        assert_eq!(mt_transposed.shape().dims(), &[2, 2]);

        let hermitian = tensor.H().expect("H() failed");
        assert_eq!(hermitian.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_tensor_properties() {
        let tensor = crate::creation::tensor_2d_arrays(&[[1.0f32, 2.0], [3.0, 4.0]])
            .expect("tensor creation failed");

        assert_eq!(tensor.numel(), 4);
        assert_eq!(tensor.shape().dims(), &[2, 2]);
        assert!(!tensor.is_empty());
        assert!(!tensor.is_scalar());
        assert!(tensor.is_contiguous());

        let scalar = crate::creation::tensor_scalar(42.0f32).expect("scalar creation failed");
        assert!(scalar.is_scalar());
        // `item()` returns the value directly and panics on misuse.
        assert_eq!(scalar.item(), 42.0);
    }

    #[test]
    fn test_shape_convenience() {
        let tensor = crate::creation::zeros::<f32>(&[4])
            .expect("zeros creation failed")
            .reshape(&[2, 1, 2])
            .expect("reshape failed");

        let squeezed = tensor.squeeze_all().expect("squeeze_all failed");
        assert_eq!(squeezed.shape().dims(), &[2, 2]);

        let flattened = tensor.flatten().expect("flatten failed");
        assert_eq!(flattened.shape().dims(), &[4]);

        let flat_from_1 = tensor.flatten_from(1).expect("flatten_from failed");
        assert_eq!(flat_from_1.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_detach() {
        let tensor =
            crate::creation::tensor_1d(&[1.0f32, 2.0, 3.0]).expect("tensor creation failed");
        let detached = tensor.detach();

        assert_eq!(tensor.shape().dims(), detached.shape().dims());
        assert_eq!(
            tensor.data().expect("data retrieval failed"),
            detached.data().expect("detached data retrieval failed")
        );
    }

    #[test]
    fn test_fluent_api() {
        use crate::TensorFluentExt;
        let tensor =
            crate::creation::tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor creation failed");

        // ((x + 1) * 2) - 1; the first unwrap() converts the FluentTensor
        // back into a Result, the second unwraps that Result.
        let result = tensor
            .fluent()
            .add_scalar(1.0)
            .mul_scalar(2.0)
            .sub_scalar(1.0)
            .unwrap()
            .unwrap();

        let expected = vec![3.0, 5.0, 7.0, 9.0];
        let actual = result.to_vec().expect("to_vec failed");

        for (exp, act) in expected.iter().zip(actual.iter()) {
            assert!((exp - act).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn test_fluent_api_operations() {
        use crate::TensorFluentExt;
        let tensor1 =
            crate::creation::tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor1 creation failed");
        let tensor2 =
            crate::creation::tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor2 creation failed");

        // sum((t1 + t2) * 0.5) = (3 + 4 + 5 + 6) / 2 = 9
        let result = tensor1
            .fluent()
            .add(&tensor2)
            .mul_scalar(0.5)
            .sum()
            .unwrap()
            .unwrap();

        let actual = result.to_vec().expect("to_vec failed");
        assert!((actual[0] - 9.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_fluent_api_mathematical_operations() {
        use crate::TensorFluentExt;
        let tensor =
            crate::creation::tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor creation failed");

        let result = tensor
            .fluent()
            .relu()
            .pow(2.0)
            .sigmoid()
            .unwrap()
            .unwrap();

        // sigmoid of positive inputs lies strictly within (0, 1).
        let actual = result.to_vec().expect("to_vec failed");
        for val in actual.iter() {
            assert!(*val > 0.0 && *val < 1.0);
        }
    }
}

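/// Extension trait that lifts a [`Tensor`] into a [`FluentTensor`] for
/// method chaining.
///
/// A minimal sketch of a chain (marked `ignore`; it relies on the
/// `crate::creation` helpers used in the tests above):
///
/// ```ignore
/// use crate::TensorFluentExt;
///
/// let t = crate::creation::tensor_1d(&[1.0f32, 2.0, 3.0])?;
/// let doubled = t
///     .fluent()        // Tensor -> FluentTensor
///     .mul_scalar(2.0) // each step keeps the previous value on error
///     .unwrap()?;      // FluentTensor -> Result<Tensor> -> Tensor
/// ```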
pub trait TensorFluentExt<T: TensorElement> {
    /// Wraps the tensor in a chainable [`FluentTensor`].
    fn fluent(self) -> FluentTensor<T>;
}

/// Chainable wrapper around [`Tensor`] produced by [`TensorFluentExt::fluent`].
pub struct FluentTensor<T: TensorElement> {
    tensor: Tensor<T>,
}

impl<T: TensorElement> TensorFluentExt<T> for Tensor<T> {
    fn fluent(self) -> FluentTensor<T> {
        FluentTensor { tensor: self }
    }
}

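// Chainable arithmetic and shape operations. Each method applies the
// underlying fallible tensor op and, if it fails, silently keeps the
// previous value; call `unwrap()` (or `tensor()`) at the end of a chain
// to get the result back out.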
impl<
        T: TensorElement
            + Copy
            + std::ops::Add<Output = T>
            + std::ops::Sub<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>
            + num_traits::Zero,
    > FluentTensor<T>
{
    /// Unwraps the chain, returning the inner tensor directly.
    pub fn tensor(self) -> Tensor<T> {
        self.tensor
    }

    /// Unwraps the chain as a `Result` for use with `?`.
    pub fn unwrap(self) -> Result<Tensor<T>> {
        Ok(self.tensor)
    }

    pub fn add_scalar(mut self, scalar: T) -> Self {
        if let Ok(result) = self.tensor.add_scalar(scalar) {
            self.tensor = result;
        }
        self
    }

    pub fn mul_scalar(mut self, scalar: T) -> Self {
        if let Ok(result) = self.tensor.mul_scalar(scalar) {
            self.tensor = result;
        }
        self
    }

    pub fn sub_scalar(mut self, scalar: T) -> Self {
        if let Ok(result) = self.tensor.sub_scalar(scalar) {
            self.tensor = result;
        }
        self
    }

    pub fn div_scalar(mut self, scalar: T) -> Self {
        if let Ok(result) = self.tensor.div_scalar(scalar) {
            self.tensor = result;
        }
        self
    }

    pub fn add(mut self, other: &Tensor<T>) -> Self {
        if let Ok(result) = self.tensor.add_op(other) {
            self.tensor = result;
        }
        self
    }

    pub fn mul(mut self, other: &Tensor<T>) -> Self {
        if let Ok(result) = self.tensor.mul_op(other) {
            self.tensor = result;
        }
        self
    }

    pub fn sub(mut self, other: &Tensor<T>) -> Self {
        if let Ok(result) = self.tensor.sub(other) {
            self.tensor = result;
        }
        self
    }

    pub fn div(mut self, other: &Tensor<T>) -> Self {
        if let Ok(result) = self.tensor.div(other) {
            self.tensor = result;
        }
        self
    }

    pub fn reshape(mut self, shape: &[i32]) -> Self {
        if let Ok(result) = self.tensor.reshape(shape) {
            self.tensor = result;
        }
        self
    }

    pub fn transpose(mut self, dim0: i32, dim1: i32) -> Self {
        if let Ok(result) = self.tensor.transpose(dim0, dim1) {
            self.tensor = result;
        }
        self
    }

    pub fn t(mut self) -> Self {
        if let Ok(result) = self.tensor.t() {
            self.tensor = result;
        }
        self
    }

    pub fn sum(mut self) -> Self {
        if let Ok(result) = self.tensor.sum() {
            self.tensor = result;
        }
        self
    }

    pub fn sum_dim(mut self, dims: &[i32], keepdim: bool) -> Self {
        if let Ok(result) = self.tensor.sum_dim(dims, keepdim) {
            self.tensor = result;
        }
        self
    }

    pub fn squeeze(mut self, dim: i32) -> Self {
        if let Ok(result) = self.tensor.squeeze(dim) {
            self.tensor = result;
        }
        self
    }

    pub fn unsqueeze(mut self, dim: i32) -> Self {
        if let Ok(result) = self.tensor.unsqueeze(dim) {
            self.tensor = result;
        }
        self
    }
}

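// Floating-point math: activations and transcendental functions. These
// follow the same keep-on-error convention as the arithmetic ops above.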
impl<T: TensorElement + Copy + num_traits::Float> FluentTensor<T> {
    pub fn relu(mut self) -> Self {
        if let Ok(result) = self.tensor.relu() {
            self.tensor = result;
        }
        self
    }

    pub fn sigmoid(mut self) -> Self
    where
        T: torsh_core::dtype::FloatElement,
    {
        if let Ok(result) = self.tensor.sigmoid() {
            self.tensor = result;
        }
        self
    }

    pub fn tanh(mut self) -> Self
    where
        T: torsh_core::dtype::FloatElement,
    {
        if let Ok(result) = self.tensor.tanh() {
            self.tensor = result;
        }
        self
    }

    pub fn exp(mut self) -> Self
    where
        T: torsh_core::dtype::FloatElement,
    {
        if let Ok(result) = self.tensor.exp() {
            self.tensor = result;
        }
        self
    }

    pub fn log(mut self) -> Self
    where
        T: torsh_core::dtype::FloatElement,
    {
        if let Ok(result) = self.tensor.log() {
            self.tensor = result;
        }
        self
    }

    pub fn pow(mut self, exponent: T) -> Self
    where
        T: torsh_core::dtype::FloatElement + Into<f32>,
    {
        if let Ok(result) = self.tensor.pow(exponent) {
            self.tensor = result;
        }
        self
    }
}

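// Matrix multiplication sums over the contraction axis, hence the extra
// `Sum` bound on the element type.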
impl<T: TensorElement + Copy> FluentTensor<T>
where
    T: num_traits::Float + std::iter::Sum,
{
    pub fn matmul(mut self, other: &Tensor<T>) -> Self {
        if let Ok(result) = self.tensor.matmul(other) {
            self.tensor = result;
        }
        self
    }
}

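// Mean divides by the element count, which must be converted into `T`,
// hence the `FromPrimitive` bound.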
impl<
        T: TensorElement
            + Copy
            + num_traits::FromPrimitive
            + std::ops::Div<Output = T>
            + num_traits::Zero
            + num_traits::One,
    > FluentTensor<T>
{
    pub fn mean(mut self, dims: Option<&[usize]>, keepdim: bool) -> Self {
        if let Ok(result) = self.tensor.mean(dims, keepdim) {
            self.tensor = result;
        }
        self
    }
}