// tch_plus/nn/var_store.rs

//! Variable stores.
use super::Init;
use crate::tensor::Tensor;
use crate::wrappers::stream::ReadSeekAdapter;
use crate::{Device, Kind, TchError};
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::HashMap;
use std::io::{Read, Seek};
use std::ops::Div;
use std::sync::{Arc, Mutex, MutexGuard};

/// The separator is used to separate path elements in the tensor names.
const SEP: char = '.';

#[derive(Debug)]
pub struct Var {
    pub tensor: Tensor,
    pub group: usize,
}

// When the variable store is frozen, the trainable_variables vector
// still contains the same tensors; however, these tensors are set not
// to require gradients.
#[derive(Debug)]
pub struct Variables {
    pub named_variables: HashMap<String, Tensor>,
    pub trainable_variables: Vec<Var>,
}

/// A VarStore is used to store variables used by one or multiple layers.
/// It specifies a single device where all variables are stored.
#[derive(Debug)]
pub struct VarStore {
    pub variables_: Arc<Mutex<Variables>>,
    device: Device,
    kind: Kind,
}

/// A variable store with an associated path for variable naming.
#[derive(Debug, Clone)]
pub struct Path<'a> {
    path: Vec<String>,
    group: usize,
    var_store: &'a VarStore,
}

/// An Entry holds an entry corresponding to a given name in Path.
#[derive(Debug)]
pub struct Entry<'a> {
    name: &'a str,
    // This field holds the mutex lock on the variables.
    variables: MutexGuard<'a, Variables>,
    path: &'a Path<'a>,
}

impl VarStore {
    /// Creates a new var-store located on the specified device.
    pub fn new(device: Device) -> VarStore {
        let variables =
            Variables { named_variables: HashMap::new(), trainable_variables: Vec::new() };
        VarStore { variables_: Arc::new(Mutex::new(variables)), device, kind: Kind::Float }
    }

    /// Merges multiple var-stores into a single one.
    ///
    /// All the input var-stores must live on the same device. An optional
    /// prefix can be given for each var-store to disambiguate duplicate
    /// variable names; an error is returned if two variables end up with
    /// the same name.
    pub fn merge(var_stores: Vec<(VarStore, Option<&str>)>) -> Result<VarStore, TchError> {
        let mut new_var_store = VarStore::new(Device::Cpu);

        if var_stores.is_empty() {
            Ok(new_var_store)
        } else {
            let mut new_variables =
                Variables { named_variables: HashMap::new(), trainable_variables: Vec::new() };
            let device = var_stores[0].0.device();

            for (var_store, prefix) in var_stores {
                if var_store.device() != device {
                    return Err(TchError::Torch(format!(
                        "All VarStores must be on the same device, got {:?} and {:?}",
                        device,
                        var_store.device()
                    )));
                }
                for (var_name, var) in var_store.variables() {
                    let new_var_name = format!("{}{}", prefix.unwrap_or(""), var_name);
                    match new_variables.named_variables.entry(new_var_name) {
                        Occupied(v) => {
                            return Err(TchError::Torch(format!(
                                "Duplicate variable name found: {}. Provide a unique prefix to allow merge operation",
                                v.key(),
                            )));
                        }
                        Vacant(v) => {
                            v.insert(var);
                        }
                    }
                }
                for trainable_var in
                    var_store.variables_.lock().unwrap().trainable_variables.drain(..)
                {
                    new_variables.trainable_variables.push(trainable_var);
                }
            }
            new_var_store.variables_ = Arc::new(Mutex::new(new_variables));
            new_var_store.device = device;

            Ok(new_var_store)
        }
    }
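
    // A minimal usage sketch (not tested here): merging two var stores whose
    // variables share names, disambiguated by prefixes. The names `vs_a`,
    // `vs_b` and the prefixes are illustrative only.
    //
    //     let vs_a = VarStore::new(Device::Cpu);
    //     let vs_b = VarStore::new(Device::Cpu);
    //     let _a = vs_a.root().zeros("weight", &[4]);
    //     let _b = vs_b.root().zeros("weight", &[4]);
    //     // Without distinct prefixes the duplicate "weight" name would fail.
    //     let merged = VarStore::merge(vec![(vs_a, Some("a.")), (vs_b, Some("b."))])?;
    //     assert_eq!(merged.len(), 2);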

    /// Gets the device for this var-store.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Gets the default kind of new variables.
    pub fn kind(&self) -> Kind {
        self.kind
    }

    /// Returns the number of tensors currently stored on this var-store.
    pub fn len(&self) -> usize {
        let variables = self.variables_.lock().unwrap();
        variables.named_variables.len()
    }

    /// Returns true if no tensors are currently stored on this var-store.
    pub fn is_empty(&self) -> bool {
        let variables = self.variables_.lock().unwrap();
        variables.named_variables.is_empty()
    }

    /// Returns all the trainable variables for this var-store.
    pub fn trainable_variables(&self) -> Vec<Tensor> {
        let variables = self.variables_.lock().unwrap();
        variables.trainable_variables.iter().map(|v| v.tensor.shallow_clone()).collect()
    }

    /// Returns all variables along with their names.
    pub fn variables(&self) -> HashMap<String, Tensor> {
        let variables = self.variables_.lock().unwrap();
        variables
            .named_variables
            .iter()
            .map(|(name, v)| (name.clone(), v.shallow_clone()))
            .collect()
    }

    /// Gets the root path for this variable store.
    ///
    /// Variables are named and organized using paths. This function returns
    /// the top level path for the var store and can be combined with '/'
    /// to create sub-paths.
    pub fn root(&self) -> Path {
        Path { path: vec![], group: 0, var_store: self }
    }
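
    // A minimal usage sketch: paths build hierarchical, SEP-separated
    // variable names. `vs` is illustrative only.
    //
    //     let vs = VarStore::new(Device::Cpu);
    //     let conv = vs.root().sub("conv1");
    //     let _w = conv.zeros("weight", &[16, 3, 3, 3]);
    //     // The variable is registered as "conv1.weight".
    //     assert!(vs.variables().contains_key("conv1.weight"));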

    /// Saves the var-store variable values to a file.
    ///
    /// Weight values for all the tensors currently stored in the
    /// var-store are saved in the given file.
    ///
    /// If the given path ends with the suffix `.safetensors`, the file will
    /// be saved in safetensors format. Otherwise, libtorch C++ module format
    /// will be used. Note that saving in pickle format (`.pt` extension) is
    /// not supported by the C++ API of Torch.
    pub fn save<T: AsRef<std::path::Path>>(&self, path: T) -> Result<(), TchError> {
        let variables = self.variables_.lock().unwrap();
        let named_tensors = variables.named_variables.iter().collect::<Vec<_>>();
        match path.as_ref().extension().and_then(|x| x.to_str()) {
            Some("safetensors") => Tensor::write_safetensors(named_tensors.as_slice(), path),
            Some(_) | None => Tensor::save_multi(named_tensors.as_slice(), path),
        }
    }
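
    // A sketch of saving, assuming `vs` already holds variables; the file
    // extension selects the format. File names are illustrative only.
    //
    //     vs.save("model.safetensors")?; // safetensors format
    //     vs.save("model.ot")?;          // libtorch C++ module format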

    /// Saves the var-store variable values to a stream.
    ///
    /// Weight values for all the tensors currently stored in the
    /// var-store are saved in the given stream.
    pub fn save_to_stream<W: std::io::Write>(&self, stream: W) -> Result<(), TchError> {
        let variables = self.variables_.lock().unwrap();
        let named_tensors = variables.named_variables.iter().collect::<Vec<_>>();
        Tensor::save_multi_to_stream(named_tensors.as_slice(), stream)
    }

    fn named_tensors<T: AsRef<std::path::Path>>(
        &self,
        path: T,
    ) -> Result<HashMap<String, Tensor>, TchError> {
        let named_tensors = match path.as_ref().extension().and_then(|x| x.to_str()) {
            Some("bin") | Some("pt") => Tensor::loadz_multi_with_device(&path, self.device),
            Some("safetensors") => Tensor::read_safetensors(path),
            Some(_) | None => Tensor::load_multi_with_device(&path, self.device),
        };
        Ok(named_tensors?.into_iter().collect())
    }

    /// Copies the data from a source tensor to a destination tensor.
    ///
    /// The kind of the destination is updated to match the source before copying.
    fn copy_data_with_precision_update(src: &Tensor, dst: &mut Tensor) -> Result<(), TchError> {
        dst.set_data(&dst.to_kind(src.kind()));
        dst.f_copy_(src)
    }

    fn load_internal<T: AsRef<std::path::Path>>(&mut self, path: T) -> Result<(), TchError> {
        let named_tensors = self.named_tensors(&path)?;
        let mut variables = self.variables_.lock().unwrap();
        for (name, var) in variables.named_variables.iter_mut() {
            match named_tensors.get(name) {
                Some(src) => crate::no_grad(|| {
                    Self::copy_data_with_precision_update(src, var)
                        .map_err(|e| e.path_context(name))
                })?,
                None => {
                    return Err(TchError::TensorNameNotFound(
                        name.to_string(),
                        path.as_ref().to_string_lossy().into_owned(),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Loads the var-store variable values from a file.
    ///
    /// Weight values for all the tensors currently stored in the
    /// var-store are loaded from the given file. Note that the set of
    /// variables stored in the var-store is not changed, only the values
    /// for these tensors are modified.
    ///
    /// The format of the file is deduced from the file extension:
    /// - `.safetensors`: the file is assumed to be in safetensors format.
    /// - `.bin` or `.pt`: the file is assumed to be in pickle format.
    /// - Otherwise, the file is assumed to be in libtorch C++ module format.
    pub fn load<T: AsRef<std::path::Path>>(&mut self, path: T) -> Result<(), TchError> {
        if self.device != Device::Mps {
            self.load_internal(path)
        } else {
            // Current workaround to allow loading on the MPS device.
            // On new libtorch releases, check whether direct loading becomes
            // possible and revert.
            // See https://github.com/LaurentMazare/tch-rs/issues/609#issuecomment-1427071598.
            self.set_device(Device::Cpu);
            let or_error = self.load_internal(path);
            // Be careful not to return early, so that the device is set back
            // to Mps even on errors.
            self.set_device(Device::Mps);
            or_error
        }
    }
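
    // A sketch of loading (inside a function returning Result<_, TchError>);
    // the var store must already contain variables with matching names, e.g.
    // after building the model. The file name is illustrative only.
    //
    //     let mut vs = VarStore::new(Device::Cpu);
    //     let _w = vs.root().zeros("weight", &[4]);
    //     vs.load("model.safetensors")?;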

    /// Loads the var-store variable values from a stream.
    ///
    /// Weight values for all the tensors currently stored in the
    /// var-store are loaded from the given stream. Note that the set of
    /// variables stored in the var-store is not changed, only the values
    /// for these tensors are modified.
    pub fn load_from_stream<S: Read + Seek>(&mut self, stream: S) -> Result<(), TchError> {
        let adapter = ReadSeekAdapter::new(stream);
        let named_tensors = Tensor::load_multi_from_stream_with_device(adapter, self.device)?;
        let named_tensors: HashMap<_, _> = named_tensors.into_iter().collect();
        let mut variables = self.variables_.lock().unwrap();
        for (name, var) in variables.named_variables.iter_mut() {
            match named_tensors.get(name) {
                Some(src) => crate::no_grad(|| {
                    Self::copy_data_with_precision_update(src, var)
                        .map_err(|e| e.path_context(name))
                })?,
                None => {
                    return Err(TchError::TensorNameNotFound(
                        name.to_string(),
                        "source stream".to_string(),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Loads the var-store variable values from a file if it exists.
    ///
    /// Weight values are loaded from the given file for all tensors that
    /// exist both in the var-store and in the file. If a variable in the
    /// var-store is not present in the given file, it is skipped and its
    /// values are not updated. This method should be used when pre-trained
    /// weights are only available for part of the model.
    /// Note that the set of variables stored in the var-store is not changed,
    /// only the values for these tensors are modified.
    ///
    /// Returns a vector containing the names of the missing variables.
    pub fn load_partial<T: AsRef<std::path::Path>>(
        &mut self,
        path: T,
    ) -> Result<Vec<String>, TchError> {
        let named_tensors = self.named_tensors(&path)?;
        let mut variables = self.variables_.lock().unwrap();
        let mut missing_variables = Vec::new();
        for (name, var) in variables.named_variables.iter_mut() {
            match named_tensors.get(name) {
                Some(src) => crate::no_grad(|| {
                    Self::copy_data_with_precision_update(src, var)
                        .map_err(|e| e.path_context(name))
                })?,
                None => {
                    missing_variables.push(name.to_owned());
                }
            }
        }
        Ok(missing_variables)
    }
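
    // A sketch of partial loading, e.g. when a checkpoint only covers part
    // of the model; variables missing from the file keep their current
    // values. The file name is illustrative only.
    //
    //     let missing: Vec<String> = vs.load_partial("backbone.ot")?;
    //     for name in &missing {
    //         println!("not found in checkpoint: {name}");
    //     }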

    /// Freezes a var store.
    ///
    /// Gradients for the variables in this store are not tracked
    /// anymore.
    pub fn freeze(&mut self) {
        let variables = self.variables_.lock().unwrap();
        for variable in variables.trainable_variables.iter() {
            let _v = variable.tensor.set_requires_grad(false);
        }
    }

    /// Unfreezes a var store.
    ///
    /// Gradients for the variables in this store are tracked again.
    pub fn unfreeze(&mut self) {
        let variables = self.variables_.lock().unwrap();
        for variable in variables.trainable_variables.iter() {
            let _v = variable.tensor.set_requires_grad(true);
        }
    }
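
    // A sketch of a common fine-tuning pattern: stop tracking gradients for
    // the pre-trained weights, run or train something else, then re-enable.
    //
    //     vs.freeze();   // existing trainable variables stop requiring gradients
    //     // ... evaluate, or train variables added after the freeze ...
    //     vs.unfreeze(); // gradients are tracked again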

    /// Casts all variables in a var store to the target kind and sets the default kind
    /// for new variables.
    ///
    /// For floating-point conversion, the methods `half`, `bfloat16`, `float` and `double`
    /// should be preferred as they ensure that only float-like variables are converted
    /// to the target type.
    pub fn set_kind(&mut self, kind: Kind) {
        self.root().set_kind(kind);
        self.kind = kind;
    }

    /// Casts all float-like variables of a var store to half-precision (Half kind).
    pub fn half(&mut self) {
        self.root().half();
    }

    /// Casts all float-like variables of a var store to bfloat16-precision (BFloat16 kind).
    pub fn bfloat16(&mut self) {
        self.root().bfloat16();
    }

    /// Casts all float-like variables of a var store to single-precision (Float kind).
    pub fn float(&mut self) {
        self.root().float();
    }

    /// Casts all float-like variables of a var store to double-precision (Double kind).
    pub fn double(&mut self) {
        self.root().double();
    }

    /// Migrates a var store and all its tensors to a target device.
    pub fn set_device(&mut self, device: Device) {
        let mut variables = self.variables_.lock().unwrap();
        for (_, variable) in variables.named_variables.iter_mut() {
            variable.set_data(&variable.to_device(device));
        }
        self.device = device;
    }
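
    // A sketch of migrating every variable to a GPU when one is available,
    // assuming Device::cuda_if_available() is exposed as in tch-style APIs.
    //
    //     vs.set_device(Device::cuda_if_available());
    //     assert_eq!(vs.device(), Device::cuda_if_available());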

    /// Copies variable values from a source var store to this var store.
    ///
    /// All the variables in this var store have to exist with the same
    /// name in the source var store, otherwise an error is returned.
    pub fn copy(&mut self, src: &VarStore) -> Result<(), TchError> {
        let mut variables = self.variables_.lock().unwrap();
        let src_variables = src.variables_.lock().unwrap();
        let device = self.device;
        for name in variables.named_variables.keys() {
            if !src_variables.named_variables.contains_key(name) {
                return Err(TchError::TensorNameNotFound(
                    name.to_string(),
                    "src var-store".to_string(),
                ));
            }
        }
        for (name, var) in variables.named_variables.iter_mut() {
            let src_var = src_variables.named_variables.get(name).unwrap();
            crate::no_grad(|| var.f_copy_(&src_var.to_device(device)))?;
        }
        Ok(())
    }
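
    // A sketch of syncing weights between two var stores with identical
    // variable names, e.g. refreshing a target network from an online
    // network in reinforcement learning. `target` and `online` are
    // illustrative only.
    //
    //     // ... build the same model structure in both var stores ...
    //     target.copy(&online)?;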
}

impl<'a> Path<'a> {
    /// Gets the components of the path.
    pub fn components(&self) -> impl Iterator<Item = &str> {
        self.path.iter().map(String::as_str)
    }

    /// Gets a sub-path of the given path.
    pub fn sub<T: std::string::ToString>(&self, s: T) -> Path<'a> {
        let s = s.to_string();
        if s.chars().any(|x| x == SEP) {
            panic!("sub name cannot contain {SEP} {s}");
        }
        let mut path = self.path.clone();
        path.push(s);
        Path { path, group: self.group, var_store: self.var_store }
    }

    /// Returns a new path with the same name components but a different variable group.
    pub fn set_group(&self, group: usize) -> Path<'a> {
        Path { path: self.path.clone(), group, var_store: self.var_store }
    }

    /// Gets the device where the var-store variables are stored.
    pub fn device(&self) -> Device {
        self.var_store.device
    }

    /// Gets the default kind of new variables.
    pub fn kind(&self) -> Kind {
        self.var_store.kind
    }

    /// Returns the fully qualified, SEP-separated name for a variable in this path.
    pub fn path(&self, name: &str) -> String {
        if name.chars().any(|x| x == SEP) {
            panic!("variable name cannot contain {SEP} {name}");
        }
        if self.path.is_empty() {
            name.to_string()
        } else {
            format!("{}{}{}", self.path.join(&SEP.to_string()), SEP, name)
        }
    }

    /// Casts all variables in a var store sub-path to the target kind.
    ///
    /// Only the variables in the path sub-tree are cast to the target kind:
    /// other var store variables are unaffected. For floating-point conversion,
    /// the methods `half`, `bfloat16`, `float` and `double` should be preferred
    /// as they ensure that only float-like variables are converted to the target type.
    pub fn set_kind(&mut self, kind: Kind) {
        let path_root = self.path.join(SEP.to_string().as_str());
        let mut variables = self.var_store.variables_.lock().unwrap();
        for (variable_name, variable) in variables.named_variables.iter_mut() {
            if variable_name.starts_with(&path_root) {
                variable.set_data(&variable.to_kind(kind));
            }
        }
    }

    /// Casts all float-like variables in a var store sub-path to the target kind.
    ///
    /// Only the float-like variables in the path sub-tree are cast to the target kind:
    /// other var store variables are unaffected.
    fn set_float_kind(&mut self, kind: Kind) {
        let path_root = self.path.join(SEP.to_string().as_str());
        let mut variables = self.var_store.variables_.lock().unwrap();
        for (variable_name, variable) in variables.named_variables.iter_mut() {
            if variable_name.starts_with(&path_root) && variable.is_floating_point() {
                variable.set_data(&variable.to_kind(kind));
            }
        }
    }

    /// Casts all float-like variables in a var store sub-path to half-precision (Half kind).
    ///
    /// Only the variables in the path sub-tree are cast to half-precision:
    /// other var store variables are unaffected.
    pub fn half(&mut self) {
        self.set_float_kind(Kind::Half);
    }

    /// Casts all float-like variables in a var store sub-path to bfloat16-precision (BFloat16 kind).
    ///
    /// Only the variables in the path sub-tree are cast to bfloat16-precision:
    /// other var store variables are unaffected.
    pub fn bfloat16(&mut self) {
        self.set_float_kind(Kind::BFloat16);
    }

    /// Casts all float-like variables in a var store sub-path to single-precision (Float kind).
    ///
    /// Only the variables in the path sub-tree are cast to single-precision:
    /// other var store variables are unaffected.
    pub fn float(&mut self) {
        self.set_float_kind(Kind::Float);
    }

    /// Casts all float-like variables in a var store sub-path to double-precision (Double kind).
    ///
    /// Only the variables in the path sub-tree are cast to double-precision:
    /// other var store variables are unaffected.
    pub fn double(&mut self) {
        self.set_float_kind(Kind::Double);
    }

    /// Adds a tensor to the var store under the given name, renaming it if the
    /// name is already taken. Returns the registered tensor.
    pub fn add(&self, name: &str, tensor: Tensor, trainable: bool) -> Tensor {
        let path = self.path(name);
        let mut variables = self.var_store.variables_.lock().unwrap();
        let path = if variables.named_variables.contains_key(&path) {
            format!("{}__{}", path, variables.named_variables.len())
        } else {
            path
        };
        let tensor = if trainable { tensor.set_requires_grad(true) } else { tensor };
        if trainable {
            let var = Var { tensor: tensor.shallow_clone(), group: self.group };
            variables.trainable_variables.push(var);
        }
        variables.named_variables.insert(path, tensor.shallow_clone());
        tensor
    }
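
    // A sketch of registering an externally created tensor; `trainable`
    // controls whether it requires gradients and joins the trainable set.
    //
    //     let t = Tensor::rand(&[4, 4], (Kind::Float, vs.device()));
    //     let _proj = vs.root().add("proj", t, true);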

    fn get_or_add_with_lock(
        &self,
        name: &str,
        tensor: Tensor,
        trainable: bool,
        mut variables: MutexGuard<Variables>,
    ) -> Tensor {
        let path = self.path(name);
        if let Some(var) = variables.named_variables.get(&path) {
            return var.shallow_clone();
        }

        let tensor = if trainable { tensor.set_requires_grad(true) } else { tensor };
        if trainable {
            let var = Var { tensor: tensor.shallow_clone(), group: self.group };
            variables.trainable_variables.push(var);
        }
        variables.named_variables.insert(path, tensor.shallow_clone());
        tensor
    }

    /// Creates a new variable initialized with zeros.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable will not be trainable, so
    /// gradients will not be tracked.
    /// The variable uses a float tensor initialized with zeros.
    pub fn f_zeros_no_train(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        let z = Tensor::f_zeros(dims, (Kind::Float, self.device()))?;
        Ok(self.add(name, z, false))
    }

    /// Creates a new variable initialized with ones.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable will not be trainable, so
    /// gradients will not be tracked.
    /// The variable uses a float tensor initialized with ones.
    pub fn f_ones_no_train(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        let o = Tensor::f_ones(dims, (Kind::Float, self.device()))?;
        Ok(self.add(name, o, false))
    }

    /// Creates a new variable.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized as per the
    /// related argument.
    pub fn f_var(&self, name: &str, dims: &[i64], init: Init) -> Result<Tensor, TchError> {
        let v = super::f_init(init, dims, self.device(), self.kind())?;
        Ok(self.add(name, v, true))
    }
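
    // A sketch of the fallible variant with an explicit initialization
    // scheme; `root` is a Path obtained from VarStore::root().
    //
    //     let _w = root.f_var("w", &[256, 128], Init::Uniform { lo: -0.1, up: 0.1 })?;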

    /// Creates a new variable initialized with zeros.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized with zeros.
    pub fn f_zeros(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        self.f_var(name, dims, Init::Const(0.))
    }

    /// Creates a new variable initialized with ones.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized with ones.
    pub fn f_ones(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        self.f_var(name, dims, Init::Const(1.))
    }

    /// Creates a new variable initialized randomly with a normal distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// standard normal distribution.
    pub fn f_randn_standard(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        let init = Init::Randn { mean: 0., stdev: 1. };
        self.f_var(name, dims, init)
    }

    /// Creates a new variable initialized randomly with a normal distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// normal distribution with the specified mean and standard deviation.
    pub fn f_randn(
        &self,
        name: &str,
        dims: &[i64],
        mean: f64,
        stdev: f64,
    ) -> Result<Tensor, TchError> {
        self.f_var(name, dims, Init::Randn { mean, stdev })
    }

    /// Creates a new variable initialized randomly with a uniform distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// uniform distribution between the specified bounds.
    pub fn f_uniform(
        &self,
        name: &str,
        dims: &[i64],
        lo: f64,
        up: f64,
    ) -> Result<Tensor, TchError> {
        self.f_var(name, dims, Init::Uniform { lo, up })
    }

    /// Creates a new variable initialized randomly with Kaiming uniform.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// uniform distribution whose bounds follow Kaiming initialization.
    pub fn f_kaiming_uniform(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        self.f_var(name, dims, super::init::DEFAULT_KAIMING_UNIFORM)
    }

    /// Creates a new variable initialized randomly with Kaiming normal.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// normal distribution whose standard deviation follows Kaiming initialization.
    pub fn f_kaiming_normal(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        self.f_var(name, dims, super::init::DEFAULT_KAIMING_NORMAL)
    }

    /// Creates a new variable initialized randomly with an orthogonal matrix.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly with an orthogonal
    /// matrix as described in *Exact solutions to the nonlinear dynamics
    /// of learning in deep linear neural networks* - Saxe, A. et al. (2013).
    /// The input tensor must have at least 2 dimensions, and for tensors
    /// with more than 2 dimensions the trailing dimensions are flattened.
    pub fn f_orthogonal(&self, name: &str, dims: &[i64], gain: f64) -> Result<Tensor, TchError> {
        self.f_var(name, dims, Init::Orthogonal { gain })
    }

    /// Creates a new variable initialized by copying an existing tensor.
    ///
    /// The new variable is named according to the name parameter and
    /// has the same shape as the given tensor. The variable is trainable;
    /// its gradient will be tracked.
    /// The variable uses a float tensor initialized by copying the
    /// given tensor.
    pub fn f_var_copy(&self, name: &str, t: &Tensor) -> Result<Tensor, TchError> {
        let mut v = self.f_zeros(name, &t.size())?;
        crate::no_grad(|| v.f_copy_(t))?;
        Ok(v)
    }

    /// Creates a new variable initialized with zeros.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable will not be trainable, so
    /// gradients will not be tracked.
    /// The variable uses a float tensor initialized with zeros.
    pub fn zeros_no_train(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_zeros_no_train(name, dims).unwrap()
    }

    /// Creates a new variable initialized with ones.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable will not be trainable, so
    /// gradients will not be tracked.
    /// The variable uses a float tensor initialized with ones.
    pub fn ones_no_train(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_ones_no_train(name, dims).unwrap()
    }

    /// Creates a new variable.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized as per the
    /// related argument.
    pub fn var(&self, name: &str, dims: &[i64], init: Init) -> Tensor {
        self.f_var(name, dims, init).unwrap()
    }

    /// Creates a new variable initialized with zeros.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized with zeros.
    pub fn zeros(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_zeros(name, dims).unwrap()
    }

    /// Creates a new variable initialized with ones.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized with ones.
    pub fn ones(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_ones(name, dims).unwrap()
    }

    /// Creates a new variable initialized randomly with a normal distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// standard normal distribution.
    pub fn randn_standard(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_randn_standard(name, dims).unwrap()
    }

    /// Creates a new variable initialized randomly with a normal distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// normal distribution with the specified mean and standard deviation.
    pub fn randn(&self, name: &str, dims: &[i64], mean: f64, stdev: f64) -> Tensor {
        self.f_randn(name, dims, mean, stdev).unwrap()
    }

    /// Creates a new variable initialized randomly with a uniform distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// uniform distribution between the specified bounds.
    pub fn uniform(&self, name: &str, dims: &[i64], lo: f64, up: f64) -> Tensor {
        self.f_uniform(name, dims, lo, up).unwrap()
    }

    /// Creates a new variable initialized randomly with Kaiming uniform.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// uniform distribution whose bounds follow Kaiming initialization.
    pub fn kaiming_uniform(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_kaiming_uniform(name, dims).unwrap()
    }

    /// Creates a new variable initialized randomly with Kaiming normal.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// normal distribution whose standard deviation follows Kaiming initialization.
    pub fn kaiming_normal(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_kaiming_normal(name, dims).unwrap()
    }

    /// Creates a new variable initialized randomly with an orthogonal matrix.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable; its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly with an orthogonal
    /// matrix as described in *Exact solutions to the nonlinear dynamics
    /// of learning in deep linear neural networks* - Saxe, A. et al. (2013).
    /// The input tensor must have at least 2 dimensions, and for tensors
    /// with more than 2 dimensions the trailing dimensions are flattened.
    pub fn orthogonal(&self, name: &str, dims: &[i64], gain: f64) -> Tensor {
        self.f_orthogonal(name, dims, gain).unwrap()
    }

    /// Creates a new variable initialized by copying an existing tensor.
    ///
    /// The new variable is named according to the name parameter and
    /// has the same shape as the given tensor. The variable is trainable;
    /// its gradient will be tracked.
    /// The variable uses a float tensor initialized by copying the
    /// given tensor.
    pub fn var_copy(&self, name: &str, t: &Tensor) -> Tensor {
        self.f_var_copy(name, t).unwrap()
    }

    /// Gets the tensor corresponding to a given name if present.
    pub fn get(&self, name: &str) -> Option<Tensor> {
        let path = self.path(name);
        let variables = self.var_store.variables_.lock().unwrap();
        variables.named_variables.get(&path).map(|v| v.shallow_clone())
    }

    /// Gets the entry corresponding to a given name for in-place manipulation.
    pub fn entry<'b>(&'b self, name: &'b str) -> Entry<'b> {
        let variables = self.var_store.variables_.lock().unwrap();
        Entry { name, variables, path: self }
    }
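
    // A sketch of Entry usage: fetch-or-create semantics let the same code
    // path either build fresh variables or reuse previously created ones.
    //
    //     let w1 = vs.root().entry("weight").or_zeros(&[7]);
    //     // A second call returns the already registered tensor.
    //     let w2 = vs.root().entry("weight").or_zeros(&[7]);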
}

impl Entry<'_> {
    /// Returns the existing entry if present, otherwise creates a new variable.
    ///
    /// If this entry name matches the name of a variable stored in the
    /// var store, the corresponding tensor is returned. Otherwise a new
    /// variable is added to the var-store with the entry name and is
    /// initialized according to the init parameter.
    pub fn or_var(self, dims: &[i64], init: Init) -> Tensor {
        let v = super::init(init, dims, self.path.device());
        self.path.get_or_add_with_lock(self.name, v, true, self.variables)
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_var_copy(self, tensor: &Tensor) -> Tensor {
        let mut v = self.or_zeros(&tensor.size());
        crate::no_grad(|| v.copy_(tensor));
        v
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_kaiming_uniform(self, dims: &[i64]) -> Tensor {
        self.or_var(dims, super::init::DEFAULT_KAIMING_UNIFORM)
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_kaiming_normal(self, dims: &[i64]) -> Tensor {
        self.or_var(dims, super::init::DEFAULT_KAIMING_NORMAL)
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_orthogonal(self, dims: &[i64], gain: f64) -> Tensor {
        self.or_var(dims, Init::Orthogonal { gain })
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_ones(self, dims: &[i64]) -> Tensor {
        self.or_var(dims, Init::Const(1.))
    }

    /// Returns the existing entry if present, otherwise creates a new
    /// non-trainable variable.
    pub fn or_ones_no_train(self, dims: &[i64]) -> Tensor {
        let o = Tensor::ones(dims, (Kind::Float, self.path.device()));
        // `false`: a `_no_train` variable must not track gradients.
        self.path.get_or_add_with_lock(self.name, o, false, self.variables)
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_randn(self, dims: &[i64], mean: f64, stdev: f64) -> Tensor {
        self.or_var(dims, Init::Randn { mean, stdev })
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_randn_standard(self, dims: &[i64]) -> Tensor {
        let init = Init::Randn { mean: 0., stdev: 1. };
        self.or_var(dims, init)
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_uniform(self, dims: &[i64], lo: f64, up: f64) -> Tensor {
        self.or_var(dims, Init::Uniform { lo, up })
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_zeros(self, dims: &[i64]) -> Tensor {
        self.or_var(dims, Init::Const(0.))
    }

    /// Returns the existing entry if present, otherwise creates a new
    /// non-trainable variable.
    pub fn or_zeros_no_train(self, dims: &[i64]) -> Tensor {
        let z = Tensor::zeros(dims, (Kind::Float, self.path.device()));
        // `false`: a `_no_train` variable must not track gradients.
        self.path.get_or_add_with_lock(self.name, z, false, self.variables)
    }
}

impl<'a, T> Div<T> for &mut Path<'a>
where
    T: std::string::ToString,
{
    type Output = Path<'a>;

    fn div(self, rhs: T) -> Self::Output {
        self.sub(rhs.to_string())
    }
}

impl<'a, T> Div<T> for &Path<'a>
where
    T: std::string::ToString,
{
    type Output = Path<'a>;

    fn div(self, rhs: T) -> Self::Output {
        self.sub(rhs.to_string())
    }
}

impl<'a, T> Div<T> for Path<'a>
where
    T: std::string::ToString,
{
    type Output = Path<'a>;

    fn div(self, rhs: T) -> Self::Output {
        self.sub(rhs.to_string())
    }
}
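
// The Div impls above enable the `/` shorthand for building sub-paths.
// A minimal sketch:
//
//     let vs = VarStore::new(Device::Cpu);
//     let path = vs.root() / "encoder" / "layer1";
//     let _w = path.zeros("weight", &[4, 4]); // registered as "encoder.layer1.weight"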