1use std::ops::Div;
2use std::rc::Rc;
3
4use itertools::Itertools;
5use tch::nn::{Init, Path, VarStore};
6use tch::{TchError, Tensor};
7
8pub mod tensor;
9
10pub trait RootExt {
12 fn root_ext<F>(&self, parameter_group_fun: F) -> PathExt
20 where
21 F: 'static + Fn(&str) -> usize;
22}
23
24impl RootExt for VarStore {
25 fn root_ext<F>(&self, parameter_group_fun: F) -> PathExt
26 where
27 F: 'static + Fn(&str) -> usize,
28 {
29 PathExt {
30 inner: self.root(),
31 parameter_group_fun: Rc::new(parameter_group_fun),
32 }
33 }
34}
35
36pub struct PathExt<'a> {
37 inner: Path<'a>,
38 parameter_group_fun: Rc<dyn Fn(&str) -> usize>,
39}
40
41impl<'a> PathExt<'a> {
42 pub fn ones(&self, name: &str, dims: &[i64]) -> Tensor {
44 let group = self.name_group(name);
45 let path = self.inner.set_group(group);
46 path.ones(name, dims)
47 }
48
49 pub fn sub<T: ToString>(&'a self, s: T) -> PathExt<'a> {
51 PathExt {
52 inner: self.inner.sub(s),
53 parameter_group_fun: self.parameter_group_fun.clone(),
54 }
55 }
56
57 pub fn var(&self, name: &str, dims: &[i64], init: Init) -> Result<Tensor, TchError> {
59 let group = self.name_group(name);
60 let path = self.inner.set_group(group);
61 path.f_var(name, dims, init)
62 }
63
64 pub fn var_copy(&self, name: &str, t: &Tensor) -> Tensor {
66 let group = self.name_group(name);
67 let path = self.inner.set_group(group);
68 path.var_copy(name, t)
69 }
70
71 fn name_group(&self, name: &str) -> usize {
73 let fullname = format!("{}.{}", self.inner.components().join("."), name);
74 (self.parameter_group_fun)(&fullname)
75 }
76
77 pub fn zeros(&self, name: &str, dims: &[i64]) -> Tensor {
79 let group = self.name_group(name);
80 let path = self.inner.set_group(group);
81 path.zeros(name, dims)
82 }
83}
84
85impl<'a, T> Div<T> for &'a mut PathExt<'a>
86where
87 T: std::string::ToString,
88{
89 type Output = PathExt<'a>;
90
91 fn div(self, rhs: T) -> Self::Output {
92 self.sub(rhs.to_string())
93 }
94}
95
96impl<'a, T> Div<T> for &'a PathExt<'a>
97where
98 T: std::string::ToString,
99{
100 type Output = PathExt<'a>;
101
102 fn div(self, rhs: T) -> Self::Output {
103 self.sub(rhs.to_string())
104 }
105}