tensor_rs/tensor_impl/gen_tensor/
elemwise.rs

1use super::GenTensor;
2use crate::tensor_trait::elemwise::ElemwiseTensorOp;
3
4impl<T> ElemwiseTensorOp for GenTensor<T> where T: num_traits::Float {
5    type TensorType = GenTensor<T>;
6    type ElementType = T;
7
8    // Pointwise Ops
9    // abs
10    fn abs(&self) -> GenTensor<T> {
11        self._pointwise(|x| {
12            x.abs()
13        })
14    }
15    // acos
16    fn acos(&self) -> GenTensor<T> {
17        self._pointwise(|x| {
18            x.acos()
19        })
20    }
21    // add, there is one.
22    // addcdiv
23    // addcmul
24    // angle
25    // asin
26    fn asin(&self) -> GenTensor<T> {
27        self._pointwise(|x| {
28            x.asin()
29        })
30    }
31    // atan
32    fn atan(&self) -> GenTensor<T> {
33        self._pointwise(|x| {
34            x.atan()
35        })
36    }
37    // atan2
38    // bitwise_not
39    // bitwise_and
40    // bitwise_or
41    // bitwise_xor
42    // ceil
43    fn ceil(&self) -> GenTensor<T> {
44        self._pointwise(|x| {
45            x.ceil()
46        })
47    }
48    // clamp
49    fn clamp(&self, min: T, max: T) -> GenTensor<T> {
50        let mut ret = GenTensor::new_move(Vec::with_capacity(self.get_data().len()),
51                                          self.get_size().to_vec());
52
53        for i in self.get_data() {
54            let value;
55            if *i < min {
56                value = min;
57            } else if *i <= max {
58                value = *i;
59            } else {
60                value = max;
61            }
62            ret.get_data_mut().push(value);
63        }
64        ret
65    }
66    // conj
67    // cos
68    fn cos(&self) -> GenTensor<T> {
69        self._pointwise(|x| {
70            x.cos()
71        })
72    }
73    // cosh
74    fn cosh(&self) -> GenTensor<T> {
75        self._pointwise(|x| {
76            x.cosh()
77        })
78    }
79    // div, there is one.
80    // digamma
81    //fn digamma(&self) -> GenTensor<T> {
82    //    self._pointwise(|x| {
83    //        x.digamma()
84    //    })
85    //}
86    // erf
87    // erfc
88    // erfinv
89    // exp
90    fn exp(&self) -> GenTensor<T> {
91        self._pointwise(|x| {
92            x.exp()
93        })
94    }
95    // expm1
96    fn expm1(&self) -> GenTensor<T> {
97        self._pointwise(|x| {
98            x.exp_m1()
99        })
100    }
101    // floor
102    fn floor(&self) -> GenTensor<T> {
103        self._pointwise(|x| {
104            x.floor()
105        })
106    }
107    // floor_divide
108    // fmod
109    // frac
110    fn frac(&self) -> GenTensor<T> {
111        self._pointwise(|x| {
112            x.fract()
113        })
114    }
115    // imag
116    // lerp, this is on Tensor.
117    // lgamma
118    // log
119    fn log(&self) -> GenTensor<T> {
120        self._pointwise(|x| {
121            x.ln()
122        })
123    }
124    // log10
125    fn log10(&self) -> GenTensor<T> {
126        self._pointwise(|x| {
127            x.log10()
128        })
129    }
130    // log1p
131    fn log1p(&self) -> GenTensor<T> {
132        self._pointwise(|x| {
133            x.ln_1p()
134        })
135    }
136
137    /// Better log(1 + exp(x))
138    /// see <https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf>
139    fn log1pexp(&self) -> GenTensor<T> {
140        let mut ret = GenTensor::new_move(Vec::with_capacity(self.get_data().len()),
141                                          self.get_size().to_vec());
142        for i in self.get_data() {
143            if i <= &T::from(-37).expect("") {
144                ret.get_data_mut().push(i.exp());
145            } else if i > &T::from(-37).expect("") && i <= &T::from(18).expect("") {
146                ret.get_data_mut().push(i.exp().ln_1p());
147            } else if i > &T::from(-18).expect("") && i <= &T::from(33.3).expect("") {
148                ret.get_data_mut().push(*i + i.mul(T::from(-1).expect("")).exp());
149            } else {
150                ret.get_data_mut().push(*i);
151            }
152        }
153        ret
154    }
155    
156    // log2
157    fn log2(&self) -> GenTensor<T> {
158        self._pointwise(|x| {
159            x.log2()
160        })
161    }
162    // logical_and
163    // logical_not
164    // logical_or
165    // logical_xor
166    // mul, there is one
167    // mvlgamma
168    // neg
169    fn neg(&self) -> GenTensor<T> {
170        let mut ret = GenTensor::new_move(Vec::with_capacity(self.get_data().len()),
171                                          self.get_size().to_vec());
172
173        for i in self.get_data() {
174            ret.get_data_mut().push(i.mul(T::zero() - T::one()));
175        }
176        ret
177    }
178    
179    // polygamma
180    // pow
181    fn pow(&self, n: T) -> GenTensor<T> {
182        self._pointwise(|x| {
183            x.powf(n)
184        })
185    }
186    // real
187    // reciprocal
188    fn reciprocal(&self) -> GenTensor<T> {
189        self._pointwise(|x| {
190            x.recip()
191        })
192    }
193    // remainder
194    // round
195    fn round(&self) -> GenTensor<T> {
196        self._pointwise(|x| {
197            x.round()
198        })
199    }
200    // rsqrt
201    fn rsqrt(&self) -> GenTensor<T> {
202        self._pointwise(|x| {
203            x.sqrt()/(*x) // TODO ?
204        })
205    }
206    
207    fn sigmoid(&self) -> GenTensor<T> {
208        let mut ret = GenTensor::new_move(self.get_data().to_vec(),
209                                          self.get_size().to_vec());
210
211        for i in 0..self.get_data().len() {
212            if self.get_data()[i] > T::zero() {
213                ret.get_data_mut()[i] = T::one()/(T::one() + self.get_data()[i].neg().exp());
214            }
215            else {
216                ret.get_data_mut()[i] = self.get_data()[i].exp()/(T::one() + self.get_data()[i].exp());
217            }
218        }
219        ret
220    }
221
222    // sign
223    fn sign(&self) -> GenTensor<T> {
224        self._pointwise(|x| {
225            if *x == T::zero() {
226                T::zero()
227            } else if *x > T::zero() {
228                T::one()
229            } else {
230                T::zero() - T::one()
231            }
232        })
233    }
234    // sin
235    fn sin(&self) -> GenTensor<T> {
236        self._pointwise(|x| {
237            x.sin()
238        })
239    }
240    // sinh
241    fn sinh(&self) -> GenTensor<T> {
242        self._pointwise(|x| {
243            x.sinh()
244        })
245    }
246    // sqrt
247    fn sqrt(&self) -> GenTensor<T> {
248        self._pointwise(|x| {
249            x.sqrt()
250        })
251    }
252    // square
253    fn square(&self) -> GenTensor<T> {
254        self._pointwise(|x| {
255            (*x)*(*x)
256        })
257    }
258    // tan
259    fn tan(&self) -> GenTensor<T> {
260        self._pointwise(|x| {
261            x.tan()
262        })
263    }
264    // tanh
265    fn tanh(&self) -> GenTensor<T> {
266        self._pointwise(|x| {
267            x.tanh()
268        })
269    }
270    // true_divide
271    // trunc
272    fn trunc(&self) -> GenTensor<T> {
273        self._pointwise(|x| {
274            x.trunc()
275        })
276    }
277}
278