tensor_rs/
tensor.rs

1//! 
2//! A general tensor type.
3//! 
4
// Default value type is f32.
// The rightmost dimension of a tensor changes fastest (row-major order).
7use std::rc::Rc;
8use std::cell::RefCell;
9//use std::ops::Index;
10use ::rand::prelude::StdRng;
11
12#[cfg(feature = "use-serde")]
13use serde::{Serialize, Deserialize};
14
15use std::fmt;
16
17
18use super::typed_tensor::TypedTensor;
19use crate::tensor_impl::gen_tensor::GenTensor;
20
/// 2-to-1: generates `pub fn name(&self, o: &Tensor) -> Tensor`, which
/// forwards to the `TypedTensor` method `name` (passing `o`'s storage) and
/// wraps the returned value in freshly allocated storage.
macro_rules! tensor_method {
    ($name:ident) => {
        pub fn $name(&self, o: &Tensor) -> Tensor {
            let lhs = self.v.borrow();
            let rhs = o.v.borrow();
            Tensor {
                v: Rc::new(RefCell::new(lhs.$name(&rhs))),
            }
        }
    };
}
31
/// 2-to-1option: generates `pub fn name(&self, o: &Tensor) -> Option<Tensor>`,
/// forwarding to the `TypedTensor` method `name`; `None` is passed through and
/// a `Some` value is wrapped in freshly allocated storage.
macro_rules! tensor_method_2_to_1option {
    ($name:ident) => {
        pub fn $name(&self, o: &Tensor) -> Option<Tensor> {
            let out = self.v.borrow().$name(&o.v.borrow())?;
            Some(Tensor {
                v: Rc::new(RefCell::new(out)),
            })
        }
    };
}
41
/// 1-to-other: generates `pub fn name(&self) -> Ret`, returning whatever the
/// `TypedTensor` method `name` returns, unchanged.
macro_rules! tensor_method_single_same_return {
    ($name:ident, $ret:ty) => {
        pub fn $name(&self) -> $ret {
            let inner = self.v.borrow();
            inner.$name()
        }
    };
}
50
/// 1-to-1: generates `pub fn name(&self) -> Tensor`, forwarding to the
/// `TypedTensor` method `name` and wrapping the result in fresh storage.
macro_rules! tensor_method_single_tensor_return {
    ($name:ident) => {
        pub fn $name(&self) -> Tensor {
            let out = self.v.borrow().$name();
            Tensor {
                v: Rc::new(RefCell::new(out)),
            }
        }
    };
}
61
/// 1-to-1option: generates `pub fn name(&self) -> Option<Tensor>`, forwarding
/// to the `TypedTensor` method `name`; `None` is passed through and a `Some`
/// value is wrapped in fresh storage.
macro_rules! tensor_method_1_option_tensor_return {
    ($name:ident) => {
        pub fn $name(&self) -> Option<Tensor> {
            let out = self.v.borrow().$name()?;
            Some(Tensor {
                v: Rc::new(RefCell::new(out)),
            })
        }
    };
}
73
/// 1-to-2option: generates `pub fn name(&self) -> Option<[Tensor; 2]>`,
/// forwarding to the `TypedTensor` method `name`; `None` is passed through
/// and each part of a `Some([a, b])` result gets its own fresh storage.
macro_rules! tensor_method_2_option_tensor_return {
    ($name:ident) => {
        pub fn $name(&self) -> Option<[Tensor; 2]> {
            let [first, second] = self.v.borrow().$name()?;
            Some([
                Tensor {
                    v: Rc::new(RefCell::new(first)),
                },
                Tensor {
                    v: Rc::new(RefCell::new(second)),
                },
            ])
        }
    };
}
87
/// 1-to-3option: generates `pub fn name(&self) -> Option<[Tensor; 3]>`,
/// forwarding to the `TypedTensor` method `name`; `None` is passed through
/// and each part of a `Some([a, b, c])` result gets its own fresh storage.
macro_rules! tensor_method_3_option_tensor_return {
    ($name:ident) => {
        pub fn $name(&self) -> Option<[Tensor; 3]> {
            let [first, second, third] = self.v.borrow().$name()?;
            Some([
                Tensor {
                    v: Rc::new(RefCell::new(first)),
                },
                Tensor {
                    v: Rc::new(RefCell::new(second)),
                },
                Tensor {
                    v: Rc::new(RefCell::new(third)),
                },
            ])
        }
    };
}
106
107
108
109pub struct Tensor {
110    v: Rc<RefCell<TypedTensor>>,
111}
112
113impl Default for Tensor {
114    fn default() -> Tensor {
115        Tensor {
116            v: Rc::new(RefCell::new(TypedTensor::new())),
117        }
118    }
119}
120
121impl Tensor {
122    pub fn new() -> Tensor {
123        Tensor {
124            v: Rc::new(RefCell::new(TypedTensor::new())),
125        }
126    }
127
128    pub fn data_copy(&self, o: &Tensor) {
129        self.v.borrow_mut().data_copy(&o.v.borrow());
130    }
131
132    pub fn swap(&self, o: &Tensor) {
133        self.v.swap(&o.v);
134    }
135
136    pub fn ref_copy(&self) -> Tensor {
137        Tensor {
138            v: self.v.clone(),
139        }
140    }
141
142    /// Right most is the continous indexing,
143    /// This method convert continuous index to index along each dimension.
144    pub fn index2dimpos(&self, index: usize) -> Vec::<usize> {
145        self.v.borrow().index2dimpos(index)
146    }
147    /// Right most is the continous indexing,
148    /// This method convert index along each dimension to continuous index.
149    pub fn dimpos2index(&self, dimpos: &[usize]) -> usize {
150        self.v.borrow().dimpos2index(dimpos)
151    }
152    
153    pub fn is_empty() -> bool {
154        unimplemented!();
155    }
156
157    pub fn size(&self) -> Vec<usize> {
158        self.v.borrow().size().clone()
159    }
160    tensor_method_single_same_return!(numel, usize);
161
162    pub fn get_scale_f32(&self) -> f32 {
163        self.v.borrow().get_scale_f32()
164    }
165    pub fn get_scale_f64(&self) -> f64 {
166        self.v.borrow().get_scale_f64()
167    }
168
169    tensor_method_single_tensor_return!(get_n);
170    tensor_method_single_tensor_return!(get_c);
171    tensor_method_single_tensor_return!(get_d);
172    tensor_method_single_tensor_return!(get_h);
173    tensor_method_single_tensor_return!(get_w);
174    tensor_method_single_tensor_return!(numel_tensor);
175
176    pub fn get_patch(&self, range: &[(usize, usize)], step: Option<&[usize]>) -> Tensor {
177        Tensor {
178            v: Rc::new(RefCell::new(self.v.borrow().get_patch(range, step)))
179        }
180    }
181    pub fn set_patch(&self, other: &Tensor,
182                     range: &[(usize, usize)], step: Option<&[usize]>) -> Tensor {
183        Tensor {
184            v: Rc::new(RefCell::new(self.v.borrow().set_patch(
185                &other.v.borrow(), range, step)))
186        }
187    }
188
189    pub fn same_shape(&self, o: &Tensor) -> bool {
190        let a = self.size();
191        let b = o.size();
192        a == b
193    }
194
195
196    pub fn from_vec_usize(input: &[usize], dim: &[usize]) -> Tensor {
197        let data: Vec<f32> = input.iter().map(|x| *x as f32).collect();
198        Self::from_vec_f32(&data, dim)
199    }
200    
201    /// Create a tensor from a Vec,
202    /// ```
203    /// # use tensor_rs::tensor::*;
204    /// let t1 = Tensor::from_vec_f32(&vec![0., 1., 2., 4.,], &vec![2,2]);
205    /// ```
206    pub fn from_vec_f32(input: &[f32], dim: &[usize]) -> Tensor {
207        let data = input.to_vec();
208        let idim = dim.to_vec();
209
210        Tensor {
211            //v: Rc::new(RefCell::new(TypedTensor::Typef32(GenTensor { d: data, dim: idim }))),
212            v: Rc::new(RefCell::new(TypedTensor::Typef32(GenTensor::new_raw(&data, &idim) ))),
213        }
214    }
215    /// return the internal buffer
216    /// May fail if the underlying data is f64
217    pub fn get_raw_f32(&self) -> Vec<f32> {
218        self.v.borrow().get_raw_f32()
219    }
220    pub fn from_vec_f64(input: &[f64], dim: &[usize]) -> Tensor {
221        let data = input.to_vec();
222        let idim = dim.to_vec();
223
224        Tensor {
225            //v: Rc::new(RefCell::new(TypedTensor::Typef32(GenTensor { d: data, dim: idim }))),
226            v: Rc::new(RefCell::new(TypedTensor::Typef64(GenTensor::new_raw(&data, &idim) ))),
227        }
228    }
229    /// return the internal buffer
230    /// May fail if the underlying data is f32
231    pub fn get_raw_f64(&self) -> Vec<f64> {
232        self.v.borrow().get_raw_f64()
233    }
234
235    /// try convert to Vec<u8>, value should be between 0, 255
236    pub fn get_u8(&self) -> Option<Vec<u8>> {
237        self.v.borrow().get_u8()
238    }
239
240    
241    pub fn from_record_f32(&self, row: usize, record: &[f32]) -> Result<(), &'static str> {
242        self.v.borrow_mut().from_record_f32(row, record)
243    }
244    pub fn from_record_f64(&self, row: usize, record: &[f64]) -> Result<(), &'static str> {
245        self.v.borrow_mut().from_record_f64(row, record)
246    }
247    pub fn get_f32(&self, o: &[usize]) -> f32 {
248        self.v.borrow().get_f32(o)
249    }
250    pub fn set_f32(&mut self, o: &[usize], v: f32) {
251        self.v.borrow_mut().set_f32(o, v);
252    }
253
254    pub fn get_f64(&self, o: &[usize]) -> f64 {
255        self.v.borrow().get_f64(o)
256    }
257    pub fn set_f64(&mut self, o: &[usize], v: f64) {
258        self.v.borrow_mut().set_f64(o, v);
259    }
260
261
262    
263    /// Returns a tensor of size size filled with fill_value.
264    pub fn fill(size: &[usize], fill_value: &Tensor) -> Tensor {
265        Tensor {
266            v: Rc::new(RefCell::new(TypedTensor::fill(size, &fill_value.v.borrow()))),
267        }
268    }
269    pub fn fill_f32(size: &[usize], fill_value: f32) -> Tensor {
270        Tensor {
271            v: Rc::new(RefCell::new(TypedTensor::fill_f32(size, fill_value))),
272        }
273    }
274    pub fn fill_f64(size: &[usize], fill_value: f64) -> Tensor {
275        Tensor {
276            v: Rc::new(RefCell::new(TypedTensor::fill_f64(size, fill_value))),
277        }
278    }
279    
280    // zeros
281    pub fn zeros(dim: &[usize]) -> Tensor {
282        Tensor {
283            #[cfg(feature = "use-f64")]
284            v: Rc::new(RefCell::new(TypedTensor::zeros_f64(dim))),
285            #[cfg(feature = "use-f32")]
286            v: Rc::new(RefCell::new(TypedTensor::zeros_f32(dim))),
287        }
288    }
289    // zeros_like
290    tensor_method_single_tensor_return!(zeros_like);
291    // ones
292    pub fn ones(dim: &[usize]) -> Tensor {
293        Tensor {
294            #[cfg(feature = "use-f64")]
295            v: Rc::new(RefCell::new(TypedTensor::ones_f64(dim))),
296            #[cfg(feature = "use-f32")]
297            v: Rc::new(RefCell::new(TypedTensor::ones_f32(dim))),
298        }
299    }
300    pub fn twos(dim: &[usize]) -> Tensor {
301        let a = Self::ones(dim);
302        a.add(&a)
303    }
304    pub fn int_n(dim: &[usize], n: isize) -> Tensor {
305	let abs_n = n.abs();
306	let mut a = Self::ones(dim);
307	let b = Self::ones(dim);
308	for _i in 0..(abs_n-1) {
309	    a = a.add(&b);
310	}
311	if n >= 0 {
312	    a
313	} else {
314	    a.neg()
315	}
316    }
317    // ones_like
318    tensor_method_single_tensor_return!(ones_like);
319    // range
320    pub fn range(start: f32, end: f32, step: Option<f32>) -> Tensor {
321        let real_step = if let Some(v) = step {
322            v
323        } else {
324            1.
325        };
326
327        let mut value = start;
328        let mut index = 0;
329        let mut data = Vec::new();
330        while value <= end {
331            value += real_step;
332            data.push(value);
333            index += 1;
334        }
335        
336        Tensor::from_vec_f32(&data, &[index])
337    }
338    // linspace
339    pub fn linspace(start: f32, end: f32, steps: usize) -> Tensor {
340        let real_step = (end-start)/(steps as f32);
341
342        let mut value = start;
343        let mut index = 0;
344        let mut data = Vec::new();
345        while value <= end {
346            value += real_step;
347            data.push(value);
348            index += 1;
349        }
350        
351        Tensor::from_vec_f32(&data, &[index])
352    }
353    // logspace
354    pub fn logspace(start: f32, end: f32, steps: usize, base: f32) -> Tensor {
355        let linspace_data = Tensor::linspace(start, end, steps);
356        let mut ret_data = Vec::new();
357        for i in 0..linspace_data.numel() {
358            ret_data.push(base.powf(linspace_data.get_f32(&[i])));
359        }
360        Tensor::from_vec_f32(&ret_data, &[ret_data.len()])
361    }
362    // eye
363    pub fn eye(n: usize, m: usize) -> Tensor {
364        let ret = Tensor::zeros(&[n, m]);
365        for i in 0..n.min(m) {
366            ret.v.borrow_mut().set_f32(&[i, i], 1.);
367        }
368        ret
369    }
370    // empty
371    pub fn empty(shape: &[usize]) -> Tensor {
372        for i in shape {
373            if *i == 0 {
374                println!("empty: shape with zeros in it.");
375            }
376        }
377        Tensor {
378            #[cfg(feature = "use-f64")]
379            v: Rc::new(RefCell::new(TypedTensor::zeros_f64(shape))),
380            #[cfg(feature = "use-f32")]
381            v: Rc::new(RefCell::new(TypedTensor::zeros_f32(shape))),
382        }
383    }
384
385    pub fn log10_like(&self) -> Tensor {
386	Tensor {
387	    v: Rc::new(RefCell::new(self.v.borrow().log10_like())),
388	}
389    }
390
391    pub fn log2_like(&self) -> Tensor {
392	Tensor {
393	    v: Rc::new(RefCell::new(self.v.borrow().log2_like())),
394	}
395    }
396
397    
398    // Indexing, Slicing, Joining, Mutating Ops
399    pub fn cat(&self, tensors: &[Tensor], dim: usize) -> Tensor {
400        let mut concrete_tensor = Vec::new();
401        
402        for i in tensors {
403            concrete_tensor.push(i.v.borrow().clone());
404        }
405        Tensor {
406            v: Rc::new(RefCell::new(self.v.borrow().cat(&concrete_tensor, dim))),
407        }
408    }
409    pub fn chunk(&self, chunks: usize, dim: usize) -> Vec<Tensor> {
410        let mut result = self.v.borrow().chunk(chunks, dim);
411        let mut ret = Vec::new();
412        for i in result.drain(..) {
413            ret.push(Tensor {
414                v: Rc::new(RefCell::new(i))
415            });
416        }
417        ret
418    }
419    pub fn gather(&self, dim: usize, index: &Tensor) -> Tensor {
420        Tensor {
421            v: Rc::new(RefCell::new(self.v.borrow().gather(dim, &index.v.borrow()))),
422        }
423    }
424    pub fn spread(&self, dim: usize, index: &Tensor, value: &Tensor) -> Tensor {
425        Tensor {
426            v: Rc::new(RefCell::new(self.v.borrow().spread(dim, &index.v.borrow(), &value.v.borrow()))),
427        }
428    }
429    pub fn index_select(&self, dim: usize, index: &Tensor) -> Tensor {
430        Tensor {
431            v: Rc::new(RefCell::new(self.v.borrow().index_select(dim, &index.v.borrow()))),
432        }
433    }
434    pub fn index_exclude(&self, dim: usize, index: &Tensor) -> Tensor {
435        Tensor {
436            v: Rc::new(RefCell::new(self.v.borrow().index_exclude(dim, &index.v.borrow()))),
437        }
438    }
439    pub fn masked_select() {
440        unimplemented!();
441    }
442    pub fn narrow() {
443        unimplemented!();
444    }
445    pub fn nonzero() {
446        unimplemented!();
447    }
448    pub fn reshape(&self, new_shape: &[usize]) -> Tensor {
449        Tensor {
450            v: Rc::new(RefCell::new(self.v.borrow().reshape(new_shape))),
451        }
452    }
453    pub fn split(&self, sections: &[usize], dim: usize) -> Vec<Tensor> {
454        let typts = self.v.borrow().split(sections, dim);
455        let mut ret = Vec::new();
456        for i in typts {
457            ret.push(Tensor {
458                v: Rc::new(RefCell::new(i)),
459            });
460        }
461        ret
462    }
463    pub fn squeeze(&self, dim: Option<usize>) -> Tensor {
464        Tensor {
465            v: Rc::new(RefCell::new(self.v.borrow().squeeze(dim))),
466        }
467    }
468    pub fn stack(&self, tensors: &[Tensor], dim: usize) -> Tensor {
469        let mut concrete_tensor = Vec::new();
470        
471        for i in tensors {
472            concrete_tensor.push(i.v.borrow().clone());
473        }
474        Tensor {
475            v: Rc::new(RefCell::new(self.v.borrow().stack(&concrete_tensor, dim))),
476        }
477    }
478    pub fn t(&self) -> Tensor {
479        Tensor {
480            v: Rc::new(RefCell::new(self.v.borrow().t()))
481        }
482    }
483    pub fn take(&self, index: &[usize]) -> Tensor {
484        Tensor {
485            v: Rc::new(RefCell::new(self.v.borrow().take(index)))
486        }
487    }
488    pub fn transpose() {
489        unimplemented!();
490    }
491    pub fn unbind() {
492        unimplemented!();
493    }
494
495    pub fn permute(&self, dim: &[usize]) -> Tensor {
496        Tensor {
497            v: Rc::new(RefCell::new(self.v.borrow().permute(dim))),
498        }
499    }
500    
501    /// Returns a new tensor with a dimension of size one inserted at the specified position.
502    /// 
503    /// The returned tensor shares the same underlying data with this tensor.
504    ///
505    /// 
506    pub fn unsqueeze(&self, dim: usize) -> Tensor {
507        Tensor {
508            v: Rc::new(RefCell::new(self.v.borrow().unsqueeze(dim))),
509        }
510    }
511    
512    //pub fn condition() {} // this is pytorch where
513    pub fn conditional_select(&self, x: &Tensor, y: &Tensor) -> Tensor {
514        Tensor {
515            v: Rc::new(RefCell::new(self.v.borrow().conditional_select(&x.v.borrow(), &y.v.borrow()))),
516        }
517    }
518    pub fn repeat(&self, dim: &[usize]) -> Tensor {
519        Tensor {
520            v: Rc::new(RefCell::new(self.v.borrow().repeat(dim))),
521        }
522    }
523
524    
525    pub fn to_f64(&mut self) {}
526    pub fn to_f32(&mut self) {}
527
528    // Pointwise Ops
529    tensor_method_single_tensor_return!(abs);
530    tensor_method_single_tensor_return!(acos);
531    tensor_method_single_tensor_return!(asin);
532    tensor_method_single_tensor_return!(atan);
533    tensor_method_single_tensor_return!(ceil);
534    // clamp
535    tensor_method_single_tensor_return!(cos);
536    tensor_method_single_tensor_return!(cosh);
537    tensor_method_single_tensor_return!(exp);
538    tensor_method_single_tensor_return!(expm1);
539    tensor_method_single_tensor_return!(floor);
540    tensor_method_single_tensor_return!(frac);
541    // lerp
542    pub fn lerp(&self, end: &Tensor, weight: &Tensor) -> Tensor {
543        self.add(&Tensor::fill(&self.size(), weight).mul(&end.sub(self)))
544    }
545    tensor_method_single_tensor_return!(log);
546    tensor_method_single_tensor_return!(log10);
547    tensor_method_single_tensor_return!(log1p);
548    tensor_method_single_tensor_return!(log1pexp);
549    tensor_method_single_tensor_return!(log2);
550    tensor_method_single_tensor_return!(neg);
551    // pow
552    pub fn pow_f32(&self, n: f32) -> Tensor {
553        Tensor {
554            v: Rc::new(RefCell::new(self.v.borrow().pow_f32(n))),
555        }
556    }
557    tensor_method_single_tensor_return!(reciprocal);
558    tensor_method_single_tensor_return!(round);
559    tensor_method_single_tensor_return!(rsqrt);
560    tensor_method_single_tensor_return!(sigmoid);
561    tensor_method_single_tensor_return!(sign);
562    tensor_method_single_tensor_return!(sin);
563    tensor_method_single_tensor_return!(sinh);
564    tensor_method_single_tensor_return!(sqrt);
565    tensor_method_single_tensor_return!(square);
566    tensor_method_single_tensor_return!(tan);
567    tensor_method_single_tensor_return!(tanh);
568    tensor_method_single_tensor_return!(trunc);
569
570    tensor_method!(add);
571    tensor_method!(sub);
572    tensor_method!(mul); // element-wise
573    tensor_method!(div);
574
575    tensor_method!(mm); //  matrix-multiplication
576    tensor_method!(matmul); // tensor-multiplication
577    pub fn outer(&self, o: &Tensor, avg: Option<bool>) -> Tensor {
578            Tensor {
579                v: Rc::new(RefCell::new(self.v.borrow().outer(&o.v.borrow(), avg))),
580            }
581        }
582
583    // reduction ops
584    pub fn argmax(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
585        Tensor {
586            v: Rc::new(RefCell::new(self.v.borrow().argmax(dim, keepdim))),
587        }
588    }
589    pub fn argmin(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
590        Tensor {
591            v: Rc::new(RefCell::new(self.v.borrow().argmin(dim, keepdim))),
592        }
593    }
594    //tensor_method_single_tensor_return!(dist);
595    pub fn logsumexp(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
596        Tensor {
597            v: Rc::new(RefCell::new(self.v.borrow().logsumexp(dim, keepdim))),
598        }
599    }
600    pub fn mean(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
601        Tensor {
602            v: Rc::new(RefCell::new(self.v.borrow().mean(dim, keepdim))),
603        }
604    }
605    pub fn prod(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
606        Tensor {
607            v: Rc::new(RefCell::new(self.v.borrow().prod(dim, keepdim))),
608        }
609    }
610    
611    //tensor_method_single_tensor_return!(median);
612    //tensor_method_single_tensor_return!(mode);
613    //tensor_method_single_tensor_return!(norm);
614    //tensor_method_single_tensor_return!(prod);
615    pub fn std(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
616        Tensor {
617            v: Rc::new(RefCell::new(self.v.borrow().std(dim, keepdim))),
618        }
619    }
620    //tensor_method_single_tensor_return!(std_mean);
621    pub fn sum(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
622        Tensor {
623            v: Rc::new(RefCell::new(self.v.borrow().sum(dim, keepdim))),
624        }
625    }
626    //tensor_method_single_tensor_return!(unique);
627    //tensor_method_single_tensor_return!(unique_consecutive);
628    pub fn var(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
629        Tensor {
630            v: Rc::new(RefCell::new(self.v.borrow().var(dim, keepdim))),
631        }
632    }
633    //tensor_method_single_tensor_return!(var_mean);
634    pub fn max(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
635        Tensor {
636            v: Rc::new(RefCell::new(self.v.borrow().max(dim, keepdim))),
637        }
638    }
639    pub fn min(&self, dim: Option<&[usize]>, keepdim: bool) -> Tensor {
640        Tensor {
641            v: Rc::new(RefCell::new(self.v.borrow().min(dim, keepdim))),
642        }
643    }
644
645    // linalg
646    /// mean and std are all scalars.
647    pub fn normalize(&self, mean: &Tensor, std: &Tensor) -> Tensor {
648        if self.size().len() != 2 {
649            panic!("fn normalize is for two-dimensional data.");
650        }
651        //let width = self.size()[1];
652        //if width != mean.len() {
653        //    panic!("input mean has a different size. {}, {}", width, mean.len());
654        //}
655        //if width != std.len() {
656        //    panic!("input std has a different size. {}, {}", width, std.len());
657        //}
658        
659        self.sub(mean).div(std)
660    }
661    tensor_method_single_tensor_return!(normalize_unit);
662
663    tensor_method_2_option_tensor_return!(lu);
664    tensor_method_2_to_1option!(lu_solve);
665    tensor_method_2_option_tensor_return!(qr);
666    tensor_method_2_option_tensor_return!(eigen);
667    tensor_method_1_option_tensor_return!(cholesky);
668    tensor_method_1_option_tensor_return!(det);
669    tensor_method_3_option_tensor_return!(svd);
670    tensor_method_1_option_tensor_return!(inv);
671    tensor_method_single_tensor_return!(pinv);
672    tensor_method_single_tensor_return!(tr);
673
674
675    // Comparison Ops
676    tensor_method!(all_close);
677    pub fn arg_sort(&self, dim: usize, descending: bool) -> Tensor {
678        Tensor {
679            v: Rc::new(RefCell::new(self.v.borrow().arg_sort(dim, descending))),
680        }
681    }
682    tensor_method!(eq_t);
683    pub fn equal(&self, o: &Tensor) -> bool {
684        self.v.borrow().equal(&o.v.borrow())
685    }
686    tensor_method!(ge);
687    tensor_method!(gt);
688    tensor_method!(le);
689    tensor_method!(lt);
690    tensor_method!(max_pair);
691    tensor_method!(min_pair);
692    tensor_method!(ne);
693
694    // rand
695    pub fn rand_usize(rng: &mut StdRng,
696                      dim: &[usize],
697                      left: usize, right: usize) -> Tensor {
698        Tensor {
699            v: Rc::new(RefCell::new(TypedTensor::rand_usize(rng, dim, left, right))),
700        }
701    }
702    pub fn normal_f64(rng: &mut StdRng,
703                  dim: &[usize],
704                  mean: f64, std: f64) -> Tensor {
705        Tensor {
706            v: Rc::new(RefCell::new(TypedTensor::normal_f64(rng, dim, mean, std))),
707        }
708    }
709    pub fn normal_f32(rng: &mut StdRng,
710                  dim: &[usize],
711                  mean: f32, std: f32) -> Tensor {
712        Tensor {
713            v: Rc::new(RefCell::new(TypedTensor::normal_f32(rng, dim, mean, std))),
714        }
715    }
716    pub fn uniform_f64(rng: &mut StdRng,
717                   dim: &[usize],
718                   from: f64, to: f64) -> Tensor {
719        Tensor {
720            v: Rc::new(RefCell::new(TypedTensor::uniform_f64(rng, dim, from, to)))
721        }
722    }
723    pub fn uniform_f32(rng: &mut StdRng,
724                   dim: &[usize],
725                   from: f32, to: f32) -> Tensor {
726        Tensor {
727            v: Rc::new(RefCell::new(TypedTensor::uniform_f32(rng, dim, from, to)))
728        }
729    }
730    
731
732    // conv ops
733    pub fn conv2d(&self, weight: &Tensor,
734                  stride: (usize, usize),
735                  padding: (usize, usize),
736                  dilation: (usize, usize),
737                  padding_mode: PaddingMode
738    ) -> Tensor {
739        Tensor {
740            v: Rc::new(RefCell::new(self.v.borrow().conv2d(&weight.v.borrow(), stride, padding, dilation, padding_mode))),
741        }
742    }
743    pub fn conv2d_grad(&self, weight: &Tensor,
744                       stride: (usize, usize),
745                       padding: (usize, usize),
746                       dilation: (usize, usize),
747                       padding_mode: PaddingMode,
748                       output_grad: &Tensor
749    ) -> (Tensor, Tensor) {
750        let (r1, r2) = self.v.borrow().conv2d_grad(&weight.v.borrow(), stride, padding, dilation, padding_mode, &output_grad.v.borrow());
751        (Tensor { v: Rc::new(RefCell::new(r1))},
752         Tensor { v: Rc::new(RefCell::new(r2))},
753        )
754    }
755
756    pub fn inner(&self) -> Rc<RefCell<TypedTensor>> {
757	self.v.clone()
758    }
759    pub fn set_inner(tt: TypedTensor) -> Tensor {
760	Tensor {
761	    v: Rc::new(RefCell::new(tt))
762	}
763    }
764}
765
766impl fmt::Display for Tensor {
767    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
768        write!(f, "{}", self.v.borrow())
769    }
770}
771impl fmt::Debug for Tensor {
772    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
773        write!(f, "({:?}, )", self.v.borrow())
774    }
775}
776impl PartialEq for Tensor {
777    fn eq(&self, other: &Self) -> bool {
778        self.v.eq(&other.v)
779    }
780}
781impl Eq for Tensor {}
782
783impl Clone for Tensor {
784    fn clone(&self) -> Self {
785        Tensor {
786            v: Rc::new(RefCell::new(self.v.borrow().clone())),
787        }
788    }
789}
790
791
792// index and slicing
793//pub struct TensorView {
794//    dim_index: usize,
795//}
796//
797//impl Index<usize> for Tensor {
798//    type Output = TensorView;
799//
800//    fn index(&self, dim_index: usize) -> &Self::Output {
801//        TensorView {
802//            dim_index: dim_index,
803//        }
804//    }
805//}
806
807
/// Padding strategy passed to [`Tensor::conv2d`] and [`Tensor::conv2d_grad`].
///
/// The variant names follow PyTorch's `padding_mode` options (`zeros`,
/// `reflect`, `replicate`, `circular`). NOTE(review): confirm the conv
/// implementation treats each mode the same way PyTorch does.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PaddingMode {
    Zeros,
    Reflect,
    Replicate,
    Circular,
}
816
817
818
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tensor_equal() {
        let a = Tensor::from_vec_f32(&[1., 2., 3.], &[3, 1]);
        let b = Tensor::from_vec_f32(&[1., 2., 3.], &[3, 1]);
        assert!(a.same_shape(&b));

        let a = Tensor::from_vec_f32(&[1., 2., 3.], &[1, 3]);
        let b = Tensor::from_vec_f32(&[1., 2., 3.], &[3, 1]);
        assert!(!a.same_shape(&b));
    }

    #[test]
    fn normalize() {
        let a = Tensor::from_vec_f32(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let b = a.normalize_unit();
        assert_eq!(
            b,
            Tensor::from_vec_f32(
                &[0.10482848, 0.20965695, 0.31448543, 0.4193139, 0.5241424, 0.62897086],
                &[3, 2]
            )
        );
    }

    // test for basic ops
    #[test]
    fn test_add() {
        let m1 = Tensor::from_vec_f32(&[1., 2., 3., 4.], &[2, 2]);
        let m2 = Tensor::from_vec_f32(&[1., 2., 3., 4.], &[2, 2]);
        let m3 = m1.add(&m2);
        assert_eq!(m3.get_f32(&[0, 0]), 2.);
        assert_eq!(m3.get_f32(&[1, 1]), 8.);

        // Broadcasting a length-2 vector across the rows of a 2x2 matrix.
        let m1 = Tensor::from_vec_f32(&[1., 2., 3., 4.], &[2, 2]);
        let m2 = Tensor::from_vec_f32(&[1., 2.], &[2]);
        let m3 = m1.add(&m2);
        assert_eq!(m3, Tensor::from_vec_f32(&[2., 4., 4., 6.], &[2, 2]));
    }

    #[test]
    fn test_mm() {
        let m1 = Tensor::from_vec_f32(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let m2 = Tensor::from_vec_f32(&[2., 3., 4., 5., 6., 7.], &[2, 3]);
        let result = m1.mm(&m2);
        assert_eq!(
            result,
            Tensor::from_vec_f32(&[12., 15., 18., 26., 33., 40., 40., 51., 62.], &[3, 3])
        );
    }

    #[test]
    fn test_matmul() {
        let m1 = Tensor::from_vec_f32(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let m2 = Tensor::from_vec_f32(&[2., 3., 4., 5., 6., 7.], &[2, 3]);
        let result = m1.matmul(&m2);
        assert_eq!(
            result,
            Tensor::from_vec_f32(&[12., 15., 18., 26., 33., 40., 40., 51., 62.], &[3, 3])
        );
    }

    #[test]
    fn test_outer() {
        let m1 = Tensor::from_vec_f32(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let m2 = Tensor::from_vec_f32(&[2., 3., 4., 5., 6., 7.], &[3, 2]);
        let result = m1.outer(&m2, None);
        assert_eq!(
            result,
            Tensor::from_vec_f32(
                &[2.0, 3.0, 4.0, 6.0, 12.0, 15.0, 16.0, 20.0, 30.0, 35.0, 36.0, 42.0],
                &[3, 2, 2]
            )
        );
    }
}