// syntaxdot_tch_ext/tensor.rs
use tch::{Kind, TchError, Tensor};
8
9pub trait SumDim {
10 fn f_sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Result<Tensor, TchError>;
12
13 fn sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Tensor;
15}
16
17impl SumDim for Tensor {
18 fn f_sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Result<Tensor, TchError> {
19 self.f_sum_dim_intlist(Some([dim].as_slice()), keep_dim, kind)
20 }
21
22 fn sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Tensor {
23 self.f_sum_dim(dim, keep_dim, kind).unwrap()
24 }
25}