syntaxdot_tch_ext/
tensor.rs

//! Convenience functions for `Tensor`.
//!
//! The `Tensor` API can be a bit unwieldy since it is partly
//! autogenerated. This module provides some additional methods
//! that are more convenient to use.
6
7use tch::{Kind, TchError, Tensor};
8
9pub trait SumDim {
10    /// Sum over a dimension (fallible).
11    fn f_sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Result<Tensor, TchError>;
12
13    /// Sum over a dimension.
14    fn sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Tensor;
15}
16
17impl SumDim for Tensor {
18    fn f_sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Result<Tensor, TchError> {
19        self.f_sum_dim_intlist(Some([dim].as_slice()), keep_dim, kind)
20    }
21
22    fn sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Tensor {
23        self.f_sum_dim(dim, keep_dim, kind).unwrap()
24    }
25}