1use std::rc::Rc;
8use std::cell::RefCell;
9use ::rand::prelude::StdRng;
11
12#[cfg(feature = "use-serde")]
13use serde::{Serialize, Deserialize};
14
15use std::fmt;
16
17
18use super::typed_tensor::TypedTensor;
19use crate::tensor_impl::gen_tensor::GenTensor;
20
/// Generates `pub fn $name(&self, o: &Tensor) -> Tensor`, which calls the
/// same-named method on the inner `TypedTensor` of `self` with `o`'s inner
/// value and wraps the result in a new, unshared `Tensor`.
///
/// Both operands are only borrowed immutably, so `x.$name(&x)` is fine.
macro_rules! tensor_method {
    ($name:ident) => {
        pub fn $name(&self, o: &Tensor) -> Tensor {
            let lhs = self.v.borrow();
            let rhs = o.v.borrow();
            let out = lhs.$name(&rhs);
            Tensor { v: Rc::new(RefCell::new(out)) }
        }
    };
}
31
/// Generates `pub fn $name(&self, o: &Tensor) -> Option<Tensor>`, forwarding
/// to the inner `TypedTensor` method of the same name and wrapping a `Some`
/// result in a new `Tensor`; `None` is passed through unchanged.
macro_rules! tensor_method_2_to_1option {
    ($name:ident) => {
        pub fn $name(&self, o: &Tensor) -> Option<Tensor> {
            let lhs = self.v.borrow();
            let rhs = o.v.borrow();
            let solved = lhs.$name(&rhs)?;
            Some(Tensor { v: Rc::new(RefCell::new(solved)) })
        }
    };
}
41
/// Generates `pub fn $name(&self) -> $ret` that returns the inner
/// `TypedTensor`'s `$name()` value unchanged (no `Tensor` wrapping).
macro_rules! tensor_method_single_same_return {
    ($name:ident, $ret:ty) => {
        pub fn $name(&self) -> $ret {
            let inner = self.v.borrow();
            inner.$name()
        }
    };
}
50
/// Generates `pub fn $name(&self) -> Tensor` that applies the inner
/// `TypedTensor`'s `$name()` and wraps the result in a new, unshared `Tensor`.
macro_rules! tensor_method_single_tensor_return {
    ($name:ident) => {
        pub fn $name(&self) -> Tensor {
            let out = self.v.borrow().$name();
            Tensor { v: Rc::new(RefCell::new(out)) }
        }
    };
}
61
/// Generates `pub fn $name(&self) -> Option<Tensor>`: the inner
/// `TypedTensor`'s `$name()` result is wrapped in a new `Tensor` when it is
/// `Some`, and `None` is propagated as-is.
macro_rules! tensor_method_1_option_tensor_return {
    ($name:ident) => {
        pub fn $name(&self) -> Option<Tensor> {
            let inner = self.v.borrow().$name()?;
            Some(Tensor { v: Rc::new(RefCell::new(inner)) })
        }
    };
}
73
/// Generates `pub fn $name(&self) -> Option<[Tensor; 2]>`: a `Some` pair
/// returned by the inner `TypedTensor`'s `$name()` is wrapped element-wise
/// into new `Tensor`s (order preserved); `None` is propagated.
macro_rules! tensor_method_2_option_tensor_return {
    ($name:ident) => {
        pub fn $name(&self) -> Option<[Tensor; 2]> {
            let [first, second] = self.v.borrow().$name()?;
            let to_tensor = |inner| Tensor { v: Rc::new(RefCell::new(inner)) };
            Some([to_tensor(first), to_tensor(second)])
        }
    };
}
87
/// Generates `pub fn $name(&self) -> Option<[Tensor; 3]>`: a `Some` triple
/// returned by the inner `TypedTensor`'s `$name()` is wrapped element-wise
/// into new `Tensor`s (order preserved); `None` is propagated.
macro_rules! tensor_method_3_option_tensor_return {
    ($name:ident) => {
        pub fn $name(&self) -> Option<[Tensor; 3]> {
            let [first, second, third] = self.v.borrow().$name()?;
            let to_tensor = |inner| Tensor { v: Rc::new(RefCell::new(inner)) };
            Some([to_tensor(first), to_tensor(second), to_tensor(third)])
        }
    };
}
106
107
108
109pub struct Tensor {
110 v: Rc<RefCell<TypedTensor>>,
111}
112
113impl Default for Tensor {
114 fn default() -> Tensor {
115 Tensor {
116 v: Rc::new(RefCell::new(TypedTensor::new())),
117 }
118 }
119}
120
121impl Tensor {
122 pub fn new() -> Tensor {
123 Tensor {
124 v: Rc::new(RefCell::new(TypedTensor::new())),
125 }
126 }
127
128 pub fn data_copy(&self, o: &Tensor) {
129 self.v.borrow_mut().data_copy(&o.v.borrow());
130 }
131
132 pub fn swap(&self, o: &Tensor) {
133 self.v.swap(&o.v);
134 }
135
136 pub fn ref_copy(&self) -> Tensor {
137 Tensor {
138 v: self.v.clone(),
139 }
140 }
141
142 pub fn index2dimpos(&self, index: usize) -> Vec::<usize> {
145 self.v.borrow().index2dimpos(index)
146 }
147 pub fn dimpos2index(&self, dimpos: &[usize]) -> usize {
150 self.v.borrow().dimpos2index(dimpos)
151 }
152
153 pub fn is_empty() -> bool {
154 unimplemented!();
155 }
156
157 pub fn size(&self) -> Vec<usize> {
158 self.v.borrow().size().clone()
159 }
160 tensor_method_single_same_return!(numel, usize);
161
162 pub fn get_scale_f32(&self) -> f32 {
163 self.v.borrow().get_scale_f32()
164 }
165 pub fn get_scale_f64(&self) -> f64 {
166 self.v.borrow().get_scale_f64()
167 }
168
169 tensor_method_single_tensor_return!(get_n);
170 tensor_method_single_tensor_return!(get_c);
171 tensor_method_single_tensor_return!(get_d);
172 tensor_method_single_tensor_return!(get_h);
173 tensor_method_single_tensor_return!(get_w);
174 tensor_method_single_tensor_return!(numel_tensor);
175
176 pub fn get_patch(&self, range: &[(usize, usize)], step: Option<&[usize]>) -> Tensor {
177 Tensor {
178 v: Rc::new(RefCell::new(self.v.borrow().get_patch(range, step)))
179 }
180 }
181 pub fn set_patch(&self, other: &Tensor,
182 range: &[(usize, usize)], step: Option<&[usize]>) -> Tensor {
183 Tensor {
184 v: Rc::new(RefCell::new(self.v.borrow().set_patch(
185 &other.v.borrow(), range, step)))
186 }
187 }
188
189 pub fn same_shape(&self, o: &Tensor) -> bool {
190 let a = self.size();
191 let b = o.size();
192 a == b
193 }
194
195
196 pub fn from_vec_usize(input: &[usize], dim: &[usize]) -> Tensor {
197 let data: Vec<f32> = input.iter().map(|x| *x as f32).collect();
198 Self::from_vec_f32(&data, dim)
199 }
200
201 pub fn from_vec_f32(input: &[f32], dim: &[usize]) -> Tensor {
207 let data = input.to_vec();
208 let idim = dim.to_vec();
209
210 Tensor {
211 v: Rc::new(RefCell::new(TypedTensor::Typef32(GenTensor::new_raw(&data, &idim) ))),
213 }
214 }
215 pub fn get_raw_f32(&self) -> Vec<f32> {
218 self.v.borrow().get_raw_f32()
219 }
220 pub fn from_vec_f64(input: &[f64], dim: &[usize]) -> Tensor {
221 let data = input.to_vec();
222 let idim = dim.to_vec();
223
224 Tensor {
225 v: Rc::new(RefCell::new(TypedTensor::Typef64(GenTensor::new_raw(&data, &idim) ))),
227 }
228 }
229 pub fn get_raw_f64(&self) -> Vec<f64> {
232 self.v.borrow().get_raw_f64()
233 }
234
235 pub fn get_u8(&self) -> Option<Vec<u8>> {
237 self.v.borrow().get_u8()
238 }
239
240
241 pub fn from_record_f32(&self, row: usize, record: &[f32]) -> Result<(), &'static str> {
242 self.v.borrow_mut().from_record_f32(row, record)
243 }
244 pub fn from_record_f64(&self, row: usize, record: &[f64]) -> Result<(), &'static str> {
245 self.v.borrow_mut().from_record_f64(row, record)
246 }
247 pub fn get_f32(&self, o: &[usize]) -> f32 {
248 self.v.borrow().get_f32(o)
249 }
250 pub fn set_f32(&mut self, o: &[usize], v: f32) {
251 self.v.borrow_mut().set_f32(o, v);
252 }
253
254 pub fn get_f64(&self, o: &[usize]) -> f64 {
255 self.v.borrow().get_f64(o)
256 }
257 pub fn set_f64(&mut self, o: &[usize], v: f64) {
258 self.v.borrow_mut().set_f64(o, v);
259 }
260
261
262
263 pub fn fill(size: &[usize], fill_value: &Tensor) -> Tensor {
265 Tensor {
266 v: Rc::new(RefCell::new(TypedTensor::fill(size, &fill_value.v.borrow()))),
267 }
268 }
269 pub fn fill_f32(size: &[usize], fill_value: f32) -> Tensor {
270 Tensor {
271 v: Rc::new(RefCell::new(TypedTensor::fill_f32(size, fill_value))),
272 }
273 }
274 pub fn fill_f64(size: &[usize], fill_value: f64) -> Tensor {
275 Tensor {
276 v: Rc::new(RefCell::new(TypedTensor::fill_f64(size, fill_value))),
277 }
278 }
279
280 pub fn zeros(dim: &[usize]) -> Tensor {
282 Tensor {
283 #[cfg(feature = "use-f64")]
284 v: Rc::new(RefCell::new(TypedTensor::zeros_f64(dim))),
285 #[cfg(feature = "use-f32")]
286 v: Rc::new(RefCell::new(TypedTensor::zeros_f32(dim))),
287 }
288 }
289 tensor_method_single_tensor_return!(zeros_like);
291 pub fn ones(dim: &[usize]) -> Tensor {
293 Tensor {
294 #[cfg(feature = "use-f64")]
295 v: Rc::new(RefCell::new(TypedTensor::ones_f64(dim))),
296 #[cfg(feature = "use-f32")]
297 v: Rc::new(RefCell::new(TypedTensor::ones_f32(dim))),
298 }
299 }
300 pub fn twos(dim: &[usize]) -> Tensor {
301 let a = Self::ones(dim);
302 a.add(&a)
303 }
304 pub fn int_n(dim: &[usize], n: isize) -> Tensor {
305 let abs_n = n.abs();
306 let mut a = Self::ones(dim);
307 let b = Self::ones(dim);
308 for _i in 0..(abs_n-1) {
309 a = a.add(&b);
310 }
311 if n >= 0 {
312 a
313 } else {
314 a.neg()
315 }
316 }
317 tensor_method_single_tensor_return!(ones_like);
319 pub fn range(start: f32, end: f32, step: Option<f32>) -> Tensor {
321 let real_step = if let Some(v) = step {
322 v
323 } else {
324 1.
325 };
326
327 let mut value = start;
328 let mut index = 0;
329 let mut data = Vec::new();
330 while value <= end {
331 value += real_step;
332 data.push(value);
333 index += 1;
334 }
335
336 Tensor::from_vec_f32(&data, &[index])
337 }
338 pub fn linspace(start: f32, end: f32, steps: usize) -> Tensor {
340 let real_step = (end-start)/(steps as f32);
341
342 let mut value = start;
343 let mut index = 0;
344 let mut data = Vec::new();
345 while value <= end {
346 value += real_step;
347 data.push(value);
348 index += 1;
349 }
350
351 Tensor::from_vec_f32(&data, &[index])
352 }
353 pub fn logspace(start: f32, end: f32, steps: usize, base: f32) -> Tensor {
355 let linspace_data = Tensor::linspace(start, end, steps);
356 let mut ret_data = Vec::new();
357 for i in 0..linspace_data.numel() {
358 ret_data.push(base.powf(linspace_data.get_f32(&[i])));
359 }
360 Tensor::from_vec_f32(&ret_data, &[ret_data.len()])
361 }
362 pub fn eye(n: usize, m: usize) -> Tensor {
364 let ret = Tensor::zeros(&[n, m]);
365 for i in 0..n.min(m) {
366 ret.v.borrow_mut().set_f32(&[i, i], 1.);
367 }
368 ret
369 }
370 pub fn empty(shape: &[usize]) -> Tensor {
372 for i in shape {
373 if *i == 0 {
374 println!("empty: shape with zeros in it.");
375 }
376 }
377 Tensor {
378 #[cfg(feature = "use-f64")]
379 v: Rc::new(RefCell::new(TypedTensor::zeros_f64(shape))),
380 #[cfg(feature = "use-f32")]
381 v: Rc::new(RefCell::new(TypedTensor::zeros_f32(shape))),
382 }
383 }
384
385 pub fn log10_like(&self) -> Tensor {
386 Tensor {
387 v: Rc::new(RefCell::new(self.v.borrow().log10_like())),
388 }
389 }
390
391 pub fn log2_like(&self) -> Tensor {
392 Tensor {
393 v: Rc::new(RefCell::new(self.v.borrow().log2_like())),
394 }
395 }
396
397
398 pub fn cat(&self, tensors: &[Tensor], dim: usize) -> Tensor {
400 let mut concrete_tensor = Vec::new();
401
402 for i in tensors {
403 concrete_tensor.push(i.v.borrow().clone());
404 }
405 Tensor {
406 v: Rc::new(RefCell::new(self.v.borrow().cat(&concrete_tensor, dim))),
407 }
408 }
409 pub fn chunk(&self, chunks: usize, dim: usize) -> Vec<Tensor> {
410 let mut result = self.v.borrow().chunk(chunks, dim);
411 let mut ret = Vec::new();
412 for i in result.drain(..) {
413 ret.push(Tensor {
414 v: Rc::new(RefCell::new(i))
415 });
416 }
417 ret
418 }
419 pub fn gather(&self, dim: usize, index: &Tensor) -> Tensor {
420 Tensor {
421 v: Rc::new(RefCell::new(self.v.borrow().gather(dim, &index.v.borrow()))),
422 }
423 }
424 pub fn spread(&self, dim: usize, index: &Tensor, value: &Tensor) -> Tensor {
425 Tensor {
426 v: Rc::new(RefCell::new(self.v.borrow().spread(dim, &index.v.borrow(), &value.v.borrow()))),
427 }
428 }
429 pub fn index_select(&self, dim: usize, index: &Tensor) -> Tensor {
430 Tensor {
431 v: Rc::new(RefCell::new(self.v.borrow().index_select(dim, &index.v.borrow()))),
432 }
433 }
434 pub fn index_exclude(&self, dim: usize, index: &Tensor) -> Tensor {
435 Tensor {
436 v: Rc::new(RefCell::new(self.v.borrow().index_exclude(dim, &index.v.borrow()))),
437 }
438 }
439 pub fn masked_select() {
440 unimplemented!();
441 }
442 pub fn narrow() {
443 unimplemented!();
444 }
445 pub fn nonzero() {
446 unimplemented!();
447 }
448 pub fn reshape(&self, new_shape: &[usize]) -> Tensor {
449 Tensor {
450 v: Rc::new(RefCell::new(self.v.borrow().reshape(new_shape))),
451 }
452 }
453 pub fn split(&self, sections: &[usize], dim: usize) -> Vec<Tensor> {
454 let typts = self.v.borrow().split(sections, dim);
455 let mut ret = Vec::new();
456 for i in typts {
457 ret.push(Tensor {
458 v: Rc::new(RefCell::new(i)),
459 });
460 }
461 ret
462 }
463 pub fn squeeze(&self, dim: Option<usize>) -> Tensor {
464 Tensor {
465 v: Rc::new(RefCell::new(self.v.borrow().squeeze(dim))),
466 }
467 }
468 pub fn stack(&self, tensors: &[Tensor], dim: usize) -> Tensor {
469 let mut concrete_tensor = Vec::new();
470
471 for i in tensors {
472 concrete_tensor.push(i.v.borrow().clone());
473 }
474 Tensor {
475 v: Rc::new(RefCell::new(self.v.borrow().stack(&concrete_tensor, dim))),
476 }
477 }
478 pub fn t(&self) -> Tensor {
479 Tensor {
480 v: Rc::new(RefCell::new(self.v.borrow().t()))
481 }
482 }
483 pub fn take(&self, index: &[usize]) -> Tensor {
484 Tensor {
485 v: Rc::new(RefCell::new(self.v.borrow().take(index)))
486 }
487 }
488 pub fn transpose() {
489 unimplemented!();
490 }
491 pub fn unbind() {
492 unimplemented!();
493 }
494
495 pub fn permute(&self, dim: &[usize]) -> Tensor {
496 Tensor {
497 v: Rc::new(RefCell::new(self.v.borrow().permute(dim))),
498 }
499 }
500
501 pub fn unsqueeze(&self, dim: usize) -> Tensor {
507 Tensor {
508 v: Rc::new(RefCell::new(self.v.borrow().unsqueeze(dim))),
509 }
510 }
511
512 pub fn conditional_select(&self, x: &Tensor, y: &Tensor) -> Tensor {
514 Tensor {
515 v: Rc::new(RefCell::new(self.v.borrow().conditional_select(&x.v.borrow(), &y.v.borrow()))),
516 }
517 }
518 pub fn repeat(&self, dim: &[usize]) -> Tensor {
519 Tensor {
520 v: Rc::new(RefCell::new(self.v.borrow().repeat(dim))),
521 }
522 }
523
524
525 pub fn to_f64(&mut self) {}
526 pub fn to_f32(&mut self) {}
527
528 tensor_method_single_tensor_return!(abs);
530 tensor_method_single_tensor_return!(acos);
531 tensor_method_single_tensor_return!(asin);
532 tensor_method_single_tensor_return!(atan);
533 tensor_method_single_tensor_return!(ceil);
534 tensor_method_single_tensor_return!(cos);
536 tensor_method_single_tensor_return!(cosh);
537 tensor_method_single_tensor_return!(exp);
538 tensor_method_single_tensor_return!(expm1);
539 tensor_method_single_tensor_return!(floor);
540 tensor_method_single_tensor_return!(frac);
541 pub fn lerp(&self, end: &Tensor, weight: &Tensor) -> Tensor {
543 self.add(&Tensor::fill(&self.size(), weight).mul(&end.sub(self)))
544 }
545 tensor_method_single_tensor_return!(log);
546 tensor_method_single_tensor_return!(log10);
547 tensor_method_single_tensor_return!(log1p);
548 tensor_method_single_tensor_return!(log1pexp);
549 tensor_method_single_tensor_return!(log2);
550 tensor_method_single_tensor_return!(neg);
551 pub fn pow_f32(&self, n: f32) -> Tensor {
553 Tensor {
554 v: Rc::new(RefCell::new(self.v.borrow().pow_f32(n))),
555 }
556 }
557 tensor_method_single_tensor_return!(reciprocal);
558 tensor_method_single_tensor_return!(round);
559 tensor_method_single_tensor_return!(rsqrt);
560 tensor_method_single_tensor_return!(sigmoid);
561 tensor_method_single_tensor_return!(sign);
562 tensor_method_single_tensor_return!(sin);
563 tensor_method_single_tensor_return!(sinh);
564 tensor_method_single_tensor_return!(sqrt);
565 tensor_method_single_tensor_return!(square);
566 tensor_method_single_tensor_return!(tan);
567 tensor_method_single_tensor_return!(tanh);
568 tensor_method_single_tensor_return!(trunc);
569
570 tensor_method!(add);
571 tensor_method!(sub);
572 tensor_method!(mul); tensor_method!(div);
574
575 tensor_method!(mm); tensor_method!(matmul); pub fn outer(&self, o: &Tensor, avg: Option<bool>) -> Tensor {
578 Tensor {
579 v: Rc::new(RefCell::new(self.v.borrow().outer(&o.v.borrow(), avg))),
580 }
581 }
582
583 pub fn argmax(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
585 Tensor {
586 v: Rc::new(RefCell::new(self.v.borrow().argmax(dim, keepdim))),
587 }
588 }
589 pub fn argmin(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
590 Tensor {
591 v: Rc::new(RefCell::new(self.v.borrow().argmin(dim, keepdim))),
592 }
593 }
594 pub fn logsumexp(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
596 Tensor {
597 v: Rc::new(RefCell::new(self.v.borrow().logsumexp(dim, keepdim))),
598 }
599 }
600 pub fn mean(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
601 Tensor {
602 v: Rc::new(RefCell::new(self.v.borrow().mean(dim, keepdim))),
603 }
604 }
605 pub fn prod(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
606 Tensor {
607 v: Rc::new(RefCell::new(self.v.borrow().prod(dim, keepdim))),
608 }
609 }
610
611 pub fn std(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
616 Tensor {
617 v: Rc::new(RefCell::new(self.v.borrow().std(dim, keepdim))),
618 }
619 }
620 pub fn sum(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
622 Tensor {
623 v: Rc::new(RefCell::new(self.v.borrow().sum(dim, keepdim))),
624 }
625 }
626 pub fn var(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
629 Tensor {
630 v: Rc::new(RefCell::new(self.v.borrow().var(dim, keepdim))),
631 }
632 }
633 pub fn max(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
635 Tensor {
636 v: Rc::new(RefCell::new(self.v.borrow().max(dim, keepdim))),
637 }
638 }
639 pub fn min(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
640 Tensor {
641 v: Rc::new(RefCell::new(self.v.borrow().min(dim, keepdim))),
642 }
643 }
644
645 pub fn normalize(&self, mean: &Tensor, std: &Tensor) -> Tensor {
648 if self.size().len() != 2 {
649 panic!("fn normalize is for two-dimensional data.");
650 }
651 self.sub(mean).div(std)
660 }
661 tensor_method_single_tensor_return!(normalize_unit);
662
663 tensor_method_2_option_tensor_return!(lu);
664 tensor_method_2_to_1option!(lu_solve);
665 tensor_method_2_option_tensor_return!(qr);
666 tensor_method_2_option_tensor_return!(eigen);
667 tensor_method_1_option_tensor_return!(cholesky);
668 tensor_method_1_option_tensor_return!(det);
669 tensor_method_3_option_tensor_return!(svd);
670 tensor_method_1_option_tensor_return!(inv);
671 tensor_method_single_tensor_return!(pinv);
672 tensor_method_single_tensor_return!(tr);
673
674
675 tensor_method!(all_close);
677 pub fn arg_sort(&self, dim: usize, descending: bool) -> Tensor {
678 Tensor {
679 v: Rc::new(RefCell::new(self.v.borrow().arg_sort(dim, descending))),
680 }
681 }
682 tensor_method!(eq_t);
683 pub fn equal(&self, o: &Tensor) -> bool {
684 self.v.borrow().equal(&o.v.borrow())
685 }
686 tensor_method!(ge);
687 tensor_method!(gt);
688 tensor_method!(le);
689 tensor_method!(lt);
690 tensor_method!(max_pair);
691 tensor_method!(min_pair);
692 tensor_method!(ne);
693
694 pub fn rand_usize(rng: &mut StdRng,
696 dim: &[usize],
697 left: usize, right: usize) -> Tensor {
698 Tensor {
699 v: Rc::new(RefCell::new(TypedTensor::rand_usize(rng, dim, left, right))),
700 }
701 }
702 pub fn normal_f64(rng: &mut StdRng,
703 dim: &[usize],
704 mean: f64, std: f64) -> Tensor {
705 Tensor {
706 v: Rc::new(RefCell::new(TypedTensor::normal_f64(rng, dim, mean, std))),
707 }
708 }
709 pub fn normal_f32(rng: &mut StdRng,
710 dim: &[usize],
711 mean: f32, std: f32) -> Tensor {
712 Tensor {
713 v: Rc::new(RefCell::new(TypedTensor::normal_f32(rng, dim, mean, std))),
714 }
715 }
716 pub fn uniform_f64(rng: &mut StdRng,
717 dim: &[usize],
718 from: f64, to: f64) -> Tensor {
719 Tensor {
720 v: Rc::new(RefCell::new(TypedTensor::uniform_f64(rng, dim, from, to)))
721 }
722 }
723 pub fn uniform_f32(rng: &mut StdRng,
724 dim: &[usize],
725 from: f32, to: f32) -> Tensor {
726 Tensor {
727 v: Rc::new(RefCell::new(TypedTensor::uniform_f32(rng, dim, from, to)))
728 }
729 }
730
731
732 pub fn conv2d(&self, weight: &Tensor,
734 stride: (usize, usize),
735 padding: (usize, usize),
736 dilation: (usize, usize),
737 padding_mode: PaddingMode
738 ) -> Tensor {
739 Tensor {
740 v: Rc::new(RefCell::new(self.v.borrow().conv2d(&weight.v.borrow(), stride, padding, dilation, padding_mode))),
741 }
742 }
743 pub fn conv2d_grad(&self, weight: &Tensor,
744 stride: (usize, usize),
745 padding: (usize, usize),
746 dilation: (usize, usize),
747 padding_mode: PaddingMode,
748 output_grad: &Tensor
749 ) -> (Tensor, Tensor) {
750 let (r1, r2) = self.v.borrow().conv2d_grad(&weight.v.borrow(), stride, padding, dilation, padding_mode, &output_grad.v.borrow());
751 (Tensor { v: Rc::new(RefCell::new(r1))},
752 Tensor { v: Rc::new(RefCell::new(r2))},
753 )
754 }
755
756 pub fn inner(&self) -> Rc<RefCell<TypedTensor>> {
757 self.v.clone()
758 }
759 pub fn set_inner(tt: TypedTensor) -> Tensor {
760 Tensor {
761 v: Rc::new(RefCell::new(tt))
762 }
763 }
764}
765
766impl fmt::Display for Tensor {
767 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
768 write!(f, "{}", self.v.borrow())
769 }
770}
771impl fmt::Debug for Tensor {
772 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
773 write!(f, "({:?}, )", self.v.borrow())
774 }
775}
776impl PartialEq for Tensor {
777 fn eq(&self, other: &Self) -> bool {
778 self.v.eq(&other.v)
779 }
780}
781impl Eq for Tensor {}
782
783impl Clone for Tensor {
784 fn clone(&self) -> Self {
785 Tensor {
786 v: Rc::new(RefCell::new(self.v.borrow().clone())),
787 }
788 }
789}
790
791
/// Border-handling strategy passed to `Tensor::conv2d` and
/// `Tensor::conv2d_grad`.
///
/// The variant names match PyTorch's `Conv2d` `padding_mode` options.
/// NOTE(review): presumably the semantics match too; confirm in
/// `TypedTensor::conv2d`.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PaddingMode {
    Zeros,
    Reflect,
    Replicate,
    Circular,
}
816
817
818
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tensor_equal() {
        let a = Tensor::from_vec_f32(&[1., 2., 3.], &[3, 1]);
        let b = Tensor::from_vec_f32(&[1., 2., 3.], &[3, 1]);
        assert!(a.same_shape(&b));

        let a = Tensor::from_vec_f32(&[1., 2., 3.], &[1, 3]);
        let b = Tensor::from_vec_f32(&[1., 2., 3.], &[3, 1]);
        assert!(!a.same_shape(&b));
    }

    #[test]
    fn ref_copy_shares_and_clone_copies() {
        let mut a = Tensor::from_vec_f32(&[1., 2.], &[2]);
        let shared = a.ref_copy();
        let deep = a.clone();
        a.set_f32(&[0], 9.);
        // `ref_copy` points at the same storage; `clone` took a snapshot.
        assert_eq!(shared.get_f32(&[0]), 9.);
        assert_eq!(deep.get_f32(&[0]), 1.);
    }

    #[test]
    fn normalize() {
        let a = Tensor::from_vec_f32(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let b = a.normalize_unit();
        let expected = Tensor::from_vec_f32(
            &[0.10482848, 0.20965695, 0.31448543, 0.4193139, 0.5241424, 0.62897086],
            &[3, 2],
        );
        assert_eq!(b, expected);
    }

    #[test]
    fn test_add() {
        let m1 = Tensor::from_vec_f32(&[1., 2., 3., 4.], &[2, 2]);
        let m2 = Tensor::from_vec_f32(&[1., 2., 3., 4.], &[2, 2]);
        let m3 = m1.add(&m2);
        assert_eq!(m3.get_f32(&[0, 0]), 2.);
        assert_eq!(m3.get_f32(&[1, 1]), 8.);

        // Broadcasting a vector across the rows of a matrix.
        let m1 = Tensor::from_vec_f32(&[1., 2., 3., 4.], &[2, 2]);
        let m2 = Tensor::from_vec_f32(&[1., 2.], &[2]);
        let m3 = m1.add(&m2);
        assert_eq!(m3, Tensor::from_vec_f32(&[2., 4., 4., 6.], &[2, 2]));
    }

    #[test]
    fn test_mm() {
        let m1 = Tensor::from_vec_f32(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let m2 = Tensor::from_vec_f32(&[2., 3., 4., 5., 6., 7.], &[2, 3]);
        let result = m1.mm(&m2);
        assert_eq!(
            result,
            Tensor::from_vec_f32(&[12., 15., 18., 26., 33., 40., 40., 51., 62.], &[3, 3])
        );
    }

    #[test]
    fn test_matmul() {
        let m1 = Tensor::from_vec_f32(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let m2 = Tensor::from_vec_f32(&[2., 3., 4., 5., 6., 7.], &[2, 3]);
        let result = m1.matmul(&m2);
        assert_eq!(
            result,
            Tensor::from_vec_f32(&[12., 15., 18., 26., 33., 40., 40., 51., 62.], &[3, 3])
        );
    }

    #[test]
    fn test_outer() {
        let m1 = Tensor::from_vec_f32(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let m2 = Tensor::from_vec_f32(&[2., 3., 4., 5., 6., 7.], &[3, 2]);
        let result = m1.outer(&m2, None);
        assert_eq!(
            result,
            Tensor::from_vec_f32(
                &[2.0, 3.0, 4.0, 6.0, 12.0, 15.0, 16.0, 20.0, 30.0, 35.0, 36.0, 42.0],
                &[3, 2, 2]
            )
        );
    }
}