syntaxdot_tch_ext/
lib.rs

1use std::ops::Div;
2use std::rc::Rc;
3
4use itertools::Itertools;
5use tch::nn::{Init, Path, VarStore};
6use tch::{TchError, Tensor};
7
8pub mod tensor;
9
10/// Trait that provides the root of a variable store.
11pub trait RootExt {
12    /// Get the root of a variable store.
13    ///
14    /// In contrast to the regular `root` method, `root_ext` allows
15    /// you to provide a function that maps a variable name to a
16    /// parameter group. This is particularly useful for use cases
17    /// where one wants to put parameters in separate groups, to
18    /// give each group its own hyper-parameters.
19    fn root_ext<F>(&self, parameter_group_fun: F) -> PathExt
20    where
21        F: 'static + Fn(&str) -> usize;
22}
23
24impl RootExt for VarStore {
25    fn root_ext<F>(&self, parameter_group_fun: F) -> PathExt
26    where
27        F: 'static + Fn(&str) -> usize,
28    {
29        PathExt {
30            inner: self.root(),
31            parameter_group_fun: Rc::new(parameter_group_fun),
32        }
33    }
34}
35
36pub struct PathExt<'a> {
37    inner: Path<'a>,
38    parameter_group_fun: Rc<dyn Fn(&str) -> usize>,
39}
40
41impl<'a> PathExt<'a> {
42    /// Create a tensor variable initialized with ones.
43    pub fn ones(&self, name: &str, dims: &[i64]) -> Tensor {
44        let group = self.name_group(name);
45        let path = self.inner.set_group(group);
46        path.ones(name, dims)
47    }
48
49    /// Get a sub-path of the current path.
50    pub fn sub<T: ToString>(&'a self, s: T) -> PathExt<'a> {
51        PathExt {
52            inner: self.inner.sub(s),
53            parameter_group_fun: self.parameter_group_fun.clone(),
54        }
55    }
56
57    /// Create a tensor variable initialized with the given initializer.
58    pub fn var(&self, name: &str, dims: &[i64], init: Init) -> Result<Tensor, TchError> {
59        let group = self.name_group(name);
60        let path = self.inner.set_group(group);
61        path.f_var(name, dims, init)
62    }
63
64    /// Create a tensor variable initialized with the values from another tensor.
65    pub fn var_copy(&self, name: &str, t: &Tensor) -> Tensor {
66        let group = self.name_group(name);
67        let path = self.inner.set_group(group);
68        path.var_copy(name, t)
69    }
70
71    /// Get the full name of `name` and return its group.
72    fn name_group(&self, name: &str) -> usize {
73        let fullname = format!("{}.{}", self.inner.components().join("."), name);
74        (self.parameter_group_fun)(&fullname)
75    }
76
77    /// Create a tensor variable initialized with zeros.
78    pub fn zeros(&self, name: &str, dims: &[i64]) -> Tensor {
79        let group = self.name_group(name);
80        let path = self.inner.set_group(group);
81        path.zeros(name, dims)
82    }
83}
84
85impl<'a, T> Div<T> for &'a mut PathExt<'a>
86where
87    T: std::string::ToString,
88{
89    type Output = PathExt<'a>;
90
91    fn div(self, rhs: T) -> Self::Output {
92        self.sub(rhs.to_string())
93    }
94}
95
96impl<'a, T> Div<T> for &'a PathExt<'a>
97where
98    T: std::string::ToString,
99{
100    type Output = PathExt<'a>;
101
102    fn div(self, rhs: T) -> Self::Output {
103        self.sub(rhs.to_string())
104    }
105}