tch_plus/nn/var_store.rs
//! Variable stores.
use super::Init;
use crate::tensor::Tensor;
use crate::wrappers::stream::ReadSeekAdapter;
use crate::{Device, Kind, TchError};
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::HashMap;
use std::io::{Read, Seek};
use std::ops::Div;
use std::sync::{Arc, Mutex, MutexGuard};

/// The separator used between path elements in tensor names.
const SEP: char = '.';

#[derive(Debug)]
pub struct Var {
    pub tensor: Tensor,
    pub group: usize,
}

// When the variable store is frozen, the trainable_variables vector
// still contains the same tensors; however, these tensors are set not
// to require gradients.
#[derive(Debug)]
pub struct Variables {
    pub named_variables: HashMap<String, Tensor>,
    pub trainable_variables: Vec<Var>,
}

/// A VarStore is used to store variables used by one or multiple layers.
/// It specifies a single device where all variables are stored.
#[derive(Debug)]
pub struct VarStore {
    pub variables_: Arc<Mutex<Variables>>,
    device: Device,
    kind: Kind,
}

/// A variable store with an associated path for variable naming.
#[derive(Debug, Clone)]
pub struct Path<'a> {
    path: Vec<String>,
    group: usize,
    var_store: &'a VarStore,
}

/// An Entry gives access to the variable with a given name within a Path,
/// and can be used to create it if it does not exist yet.
#[derive(Debug)]
pub struct Entry<'a> {
    name: &'a str,
    // This field holds the mutex lock on the variables.
    variables: MutexGuard<'a, Variables>,
    path: &'a Path<'a>,
}
impl VarStore {
    /// Creates a new var-store located on the specified device.
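    ///
    /// # Example
    ///
    /// A minimal usage sketch; the external crate name `tch_plus` is assumed
    /// from this file's location.
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// // Create a store on CPU and register a trainable variable under the root path.
    /// let vs = VarStore::new(Device::Cpu);
    /// let _w = vs.root().zeros("weight", &[3, 4]);
    /// assert_eq!(vs.len(), 1);
    /// ```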
    pub fn new(device: Device) -> VarStore {
        let variables =
            Variables { named_variables: HashMap::new(), trainable_variables: Vec::new() };
        VarStore { variables_: Arc::new(Mutex::new(variables)), device, kind: Kind::Float }
    }

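    /// Merges several var stores into a single one.
    ///
    /// All the var stores must live on the same device. An optional prefix can
    /// be provided for each store to disambiguate the variable names; the merge
    /// fails if the resulting names collide.
    ///
    /// A minimal usage sketch:
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let vs1 = VarStore::new(Device::Cpu);
    /// let _w = vs1.root().zeros("weight", &[4]);
    /// let vs2 = VarStore::new(Device::Cpu);
    /// let _w = vs2.root().zeros("weight", &[4]);
    /// // Prefixes keep the two `weight` variables distinct in the merged store.
    /// let merged = VarStore::merge(vec![(vs1, Some("a.")), (vs2, Some("b."))]).unwrap();
    /// assert_eq!(merged.len(), 2);
    /// ```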
    pub fn merge(var_stores: Vec<(VarStore, Option<&str>)>) -> Result<VarStore, TchError> {
        let mut new_var_store = VarStore::new(Device::Cpu);

        if var_stores.is_empty() {
            Ok(new_var_store)
        } else {
            let mut new_variables =
                Variables { named_variables: HashMap::new(), trainable_variables: Vec::new() };
            let device = var_stores[0].0.device();

            for (var_store, prefix) in var_stores {
                if var_store.device() != device {
                    return Err(TchError::Torch(format!(
                        "All VarStores must be on the same device, got {:?} and {:?}",
                        device,
                        var_store.device()
                    )));
                }
                for (var_name, var) in var_store.variables() {
                    let new_var_name = format!("{}{}", prefix.unwrap_or(""), var_name);
                    match new_variables.named_variables.entry(new_var_name) {
                        Occupied(v) => {
                            return Err(TchError::Torch(format!(
                                "Duplicate variable name found: {}. Provide a unique prefix to allow merge operation",
                                v.key(),
                            )));
                        }
                        Vacant(v) => {
                            v.insert(var);
                        }
                    }
                }
                for trainable_var in
                    var_store.variables_.lock().unwrap().trainable_variables.drain(..)
                {
                    new_variables.trainable_variables.push(trainable_var);
                }
            }
            new_var_store.variables_ = Arc::new(Mutex::new(new_variables));
            new_var_store.device = device;

            Ok(new_var_store)
        }
    }

    /// Gets the device for this var-store.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Gets the default kind used when creating new variables.
    pub fn kind(&self) -> Kind {
        self.kind
    }

    /// Returns the number of tensors currently stored in this var-store.
    pub fn len(&self) -> usize {
        let variables = self.variables_.lock().unwrap();
        variables.named_variables.len()
    }

    /// Returns true if no tensors are currently stored in this var-store.
    pub fn is_empty(&self) -> bool {
        let variables = self.variables_.lock().unwrap();
        variables.named_variables.is_empty()
    }

    /// Returns all the trainable variables for this var-store.
    pub fn trainable_variables(&self) -> Vec<Tensor> {
        let variables = self.variables_.lock().unwrap();
        variables.trainable_variables.iter().map(|v| v.tensor.shallow_clone()).collect()
    }

    /// Returns all variables along with their names.
    pub fn variables(&self) -> HashMap<String, Tensor> {
        let variables = self.variables_.lock().unwrap();
        variables
            .named_variables
            .iter()
            .map(|(name, v)| (name.clone(), v.shallow_clone()))
            .collect()
    }

    /// Gets the root path for this variable store.
    ///
    /// Variables are named and organized using paths. This function returns
    /// the top level path for the var store and can be combined with '/'
    /// to create sub-paths.
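    ///
    /// # Example
    ///
    /// A minimal sketch; the crate name `tch_plus` is assumed from this
    /// file's location.
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let vs = VarStore::new(Device::Cpu);
    /// let root = vs.root();
    /// // The variable below is registered as "block1.weight".
    /// let _w = root.sub("block1").zeros("weight", &[4]);
    /// // The `/` operator is an alias for `sub`.
    /// assert!((&root / "block1").get("weight").is_some());
    /// ```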
    pub fn root(&self) -> Path {
        Path { path: vec![], group: 0, var_store: self }
    }

    /// Saves the var-store variable values to a file.
    ///
    /// Weight values for all the tensors currently stored in the
    /// var-store are saved in the given file.
    ///
    /// If the given path ends with the suffix `.safetensors`, the file will
    /// be saved in safetensors format. Otherwise, the libtorch C++ module
    /// format will be used. Note that saving in pickle format (`.pt` extension)
    /// is not supported by the C++ API of Torch.
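    ///
    /// # Example
    ///
    /// A minimal sketch; the file names are hypothetical.
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let vs = VarStore::new(Device::Cpu);
    /// let _w = vs.root().zeros("weight", &[3, 4]);
    /// vs.save("model.safetensors").unwrap(); // safetensors format
    /// vs.save("model.ot").unwrap(); // libtorch C++ module format
    /// ```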
    pub fn save<T: AsRef<std::path::Path>>(&self, path: T) -> Result<(), TchError> {
        let variables = self.variables_.lock().unwrap();
        let named_tensors = variables.named_variables.iter().collect::<Vec<_>>();
        match path.as_ref().extension().and_then(|x| x.to_str()) {
            Some("safetensors") => Tensor::write_safetensors(named_tensors.as_slice(), path),
            Some(_) | None => Tensor::save_multi(named_tensors.as_slice(), path),
        }
    }

    /// Saves the var-store variable values to a stream.
    ///
    /// Weight values for all the tensors currently stored in the
    /// var-store are saved in the given stream.
    pub fn save_to_stream<W: std::io::Write>(&self, stream: W) -> Result<(), TchError> {
        let variables = self.variables_.lock().unwrap();
        let named_tensors = variables.named_variables.iter().collect::<Vec<_>>();
        Tensor::save_multi_to_stream(named_tensors.as_slice(), stream)
    }

    fn named_tensors<T: AsRef<std::path::Path>>(
        &self,
        path: T,
    ) -> Result<HashMap<String, Tensor>, TchError> {
        let named_tensors = match path.as_ref().extension().and_then(|x| x.to_str()) {
            Some("bin") | Some("pt") => Tensor::loadz_multi_with_device(&path, self.device),
            Some("safetensors") => Tensor::read_safetensors(path),
            Some(_) | None => Tensor::load_multi_with_device(&path, self.device),
        };
        Ok(named_tensors?.into_iter().collect())
    }

    /// Copies the data from the source tensor to the destination tensor.
    ///
    /// The kind (precision) of the destination is updated to match the
    /// source before copying.
    fn copy_data_with_precision_update(src: &Tensor, dst: &mut Tensor) -> Result<(), TchError> {
        dst.set_data(&dst.to_kind(src.kind()));
        dst.f_copy_(src)
    }

    fn load_internal<T: AsRef<std::path::Path>>(&mut self, path: T) -> Result<(), TchError> {
        let named_tensors = self.named_tensors(&path)?;
        let mut variables = self.variables_.lock().unwrap();
        for (name, var) in variables.named_variables.iter_mut() {
            match named_tensors.get(name) {
                Some(src) => crate::no_grad(|| {
                    Self::copy_data_with_precision_update(src, var)
                        .map_err(|e| e.path_context(name))
                })?,
                None => {
                    return Err(TchError::TensorNameNotFound(
                        name.to_string(),
                        path.as_ref().to_string_lossy().into_owned(),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Loads the var-store variable values from a file.
    ///
    /// Weight values for all the tensors currently stored in the
    /// var-store are loaded from the given file. Note that the set of
    /// variables stored in the var-store is not changed, only the values
    /// for these tensors are modified.
    ///
    /// The format of the file is deduced from the file extension:
    /// - `.safetensors`: the file is assumed to be in safetensors format.
    /// - `.bin` or `.pt`: the file is assumed to be in pickle format.
    /// - otherwise, the file is assumed to be in libtorch C++ module format.
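    ///
    /// # Example
    ///
    /// A minimal sketch; the checkpoint file name is hypothetical and the
    /// crate name `tch_plus` is assumed from this file's location.
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let mut vs = VarStore::new(Device::Cpu);
    /// let _w = vs.root().zeros("weight", &[3, 4]);
    /// // Variable names and shapes in the file must match the store.
    /// vs.load("model.safetensors").unwrap();
    /// ```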
    pub fn load<T: AsRef<std::path::Path>>(&mut self, path: T) -> Result<(), TchError> {
        if self.device != Device::Mps {
            self.load_internal(path)
        } else {
            // Current workaround to allow loading in MPS device.
            // On new libtorch releases check if direct loading becomes possible and revert.
            // See https://github.com/LaurentMazare/tch-rs/issues/609#issuecomment-1427071598.
            self.set_device(Device::Cpu);
            let or_error = self.load_internal(path);
            // Be cautious not to early exit so as to ensure that the device is set back to Mps
            // even on errors.
            self.set_device(Device::Mps);
            or_error
        }
    }

    /// Loads the var-store variable values from a stream.
    ///
    /// Weight values for all the tensors currently stored in the
    /// var-store are loaded from the given stream. Note that the set of
    /// variables stored in the var-store is not changed, only the values
    /// for these tensors are modified.
    pub fn load_from_stream<S: Read + Seek>(&mut self, stream: S) -> Result<(), TchError> {
        let adapter = ReadSeekAdapter::new(stream);
        let named_tensors = Tensor::load_multi_from_stream_with_device(adapter, self.device)?;
        let named_tensors: HashMap<_, _> = named_tensors.into_iter().collect();
        let mut variables = self.variables_.lock().unwrap();
        for (name, var) in variables.named_variables.iter_mut() {
            match named_tensors.get(name) {
                Some(src) => crate::no_grad(|| {
                    Self::copy_data_with_precision_update(src, var)
                        .map_err(|e| e.path_context(name))
                })?,
                None => {
                    return Err(TchError::TensorNameNotFound(
                        name.to_string(),
                        "source stream".to_string(),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Loads the var-store variable values from a file if it exists.
    ///
    /// Weight values for the tensors currently stored in both the var-store and
    /// the given file are loaded from the given file. If a variable in the var
    /// store is not present in the given file, it is skipped and its values are
    /// not updated. This method should be used if pre-trained weights are only
    /// available for parts of the model.
    /// Note that the set of variables stored in the var-store is not changed,
    /// only the values for these tensors are modified.
    ///
    /// Returns a vector containing the names of the missing variables.
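    ///
    /// # Example
    ///
    /// A sketch of partially loading pre-trained weights; the file name is
    /// hypothetical.
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let mut vs = VarStore::new(Device::Cpu);
    /// let _w = vs.root().sub("encoder").zeros("weight", &[4]);
    /// let missing = vs.load_partial("pretrained.safetensors").unwrap();
    /// for name in missing {
    ///     println!("kept at its initial value: {name}");
    /// }
    /// ```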
    pub fn load_partial<T: AsRef<std::path::Path>>(
        &mut self,
        path: T,
    ) -> Result<Vec<String>, TchError> {
        let named_tensors = self.named_tensors(&path)?;
        let mut variables = self.variables_.lock().unwrap();
        let mut missing_variables = Vec::new();
        for (name, var) in variables.named_variables.iter_mut() {
            match named_tensors.get(name) {
                Some(src) => crate::no_grad(|| {
                    Self::copy_data_with_precision_update(src, var)
                        .map_err(|e| e.path_context(name))
                })?,
                None => {
                    missing_variables.push(name.to_owned());
                }
            }
        }
        Ok(missing_variables)
    }

    /// Freezes a var store.
    ///
    /// Gradients for the variables in this store are not tracked
    /// anymore.
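    ///
    /// # Example
    ///
    /// A minimal sketch (crate name `tch_plus` assumed):
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let mut vs = VarStore::new(Device::Cpu);
    /// let w = vs.root().zeros("weight", &[4]);
    /// vs.freeze();
    /// assert!(!w.requires_grad());
    /// ```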
    pub fn freeze(&mut self) {
        let variables = self.variables_.lock().unwrap();
        for variable in variables.trainable_variables.iter() {
            let _v = variable.tensor.set_requires_grad(false);
        }
    }

    /// Unfreezes a var store.
    ///
    /// Gradients for the variables in this store are tracked again.
    pub fn unfreeze(&mut self) {
        let variables = self.variables_.lock().unwrap();
        for variable in variables.trainable_variables.iter() {
            let _v = variable.tensor.set_requires_grad(true);
        }
    }

    /// Casts all variables in a var store to the target kind and sets the default kind
    /// for new variables.
    ///
    /// For floating-point conversion, the methods `half`, `bfloat16`, `float` and `double`
    /// should be preferred as they ensure only float-like variables will be converted
    /// to the target type.
    pub fn set_kind(&mut self, kind: Kind) {
        self.root().set_kind(kind);
        self.kind = kind;
    }

    /// Casts all float-like variables of a var store to half-precision (Half kind).
    pub fn half(&mut self) {
        self.root().half();
    }

    /// Casts all float-like variables of a var store to bfloat16-precision (BFloat16 kind).
    pub fn bfloat16(&mut self) {
        self.root().bfloat16();
    }

    /// Casts all float-like variables of a var store to single-precision (Float kind).
    pub fn float(&mut self) {
        self.root().float();
    }

    /// Casts all float-like variables of a var store to double-precision (Double kind).
    pub fn double(&mut self) {
        self.root().double();
    }

    /// Migrates a var store and all its tensors to a target device.
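    ///
    /// # Example
    ///
    /// A minimal sketch assuming a CUDA device is available:
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let mut vs = VarStore::new(Device::Cpu);
    /// let _w = vs.root().zeros("weight", &[4]);
    /// // Moves every stored tensor; newly created variables also land on CUDA.
    /// vs.set_device(Device::Cuda(0));
    /// ```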
    pub fn set_device(&mut self, device: Device) {
        let mut variables = self.variables_.lock().unwrap();
        for (_, variable) in variables.named_variables.iter_mut() {
            variable.set_data(&variable.to_device(device));
        }
        self.device = device
    }

    /// Copies variable values from a source var store to this var store.
    ///
    /// All the variables in this var store have to exist with the same
    /// name in the source var store, otherwise an error is returned.
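    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let src = VarStore::new(Device::Cpu);
    /// let _w = src.root().ones("weight", &[4]);
    /// let mut dst = VarStore::new(Device::Cpu);
    /// let _w = dst.root().zeros("weight", &[4]);
    /// // After the copy, `weight` in `dst` holds ones.
    /// dst.copy(&src).unwrap();
    /// ```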
    pub fn copy(&mut self, src: &VarStore) -> Result<(), TchError> {
        let mut variables = self.variables_.lock().unwrap();
        let src_variables = src.variables_.lock().unwrap();
        let device = self.device;
        for name in variables.named_variables.keys() {
            if !src_variables.named_variables.contains_key(name) {
                return Err(TchError::TensorNameNotFound(
                    name.to_string(),
                    "src var-store".to_string(),
                ));
            }
        }
        for (name, var) in variables.named_variables.iter_mut() {
            let src_var = src_variables.named_variables.get(name).unwrap();
            crate::no_grad(|| var.f_copy_(&src_var.to_device(device)))?;
        }
        Ok(())
    }
}

impl<'a> Path<'a> {
    /// Gets the components of the path.
    pub fn components(&self) -> impl Iterator<Item = &str> {
        self.path.iter().map(String::as_str)
    }

    /// Gets a sub-path of the given path.
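    ///
    /// # Example
    ///
    /// A minimal sketch (crate name `tch_plus` assumed):
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let vs = VarStore::new(Device::Cpu);
    /// let linear = vs.root().sub("encoder").sub("linear1");
    /// // This variable is registered as "encoder.linear1.bias".
    /// let _b = linear.zeros("bias", &[16]);
    /// ```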
    pub fn sub<T: std::string::ToString>(&self, s: T) -> Path<'a> {
        let s = s.to_string();
        if s.chars().any(|x| x == SEP) {
            panic!("sub name cannot contain {SEP} {s}");
        }
        let mut path = self.path.clone();
        path.push(s);
        Path { path, group: self.group, var_store: self.var_store }
    }

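    /// Sets the group for variables created under this path.
    ///
    /// Optimizers can use the group index to apply different hyper-parameters,
    /// such as per-group learning rates, to different subsets of the trainable
    /// variables. A minimal sketch of how a group could be assigned:
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let vs = VarStore::new(Device::Cpu);
    /// let root = vs.root();
    /// // Variables created under `head` are registered in group 1.
    /// let head = root.sub("head").set_group(1);
    /// let _w = head.zeros("weight", &[2, 2]);
    /// ```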
    pub fn set_group(&self, group: usize) -> Path<'a> {
        Path { path: self.path.clone(), group, var_store: self.var_store }
    }

    /// Gets the device where the var-store variables are stored.
    pub fn device(&self) -> Device {
        self.var_store.device
    }

    /// Gets the default kind used when creating new variables.
    pub fn kind(&self) -> Kind {
        self.var_store.kind
    }

    /// Gets the full variable name obtained by prefixing `name` with this path.
    pub fn path(&self, name: &str) -> String {
        if name.chars().any(|x| x == SEP) {
            panic!("variable name cannot contain {SEP} {name}");
        }
        if self.path.is_empty() {
            name.to_string()
        } else {
            format!("{}{}{}", self.path.join(&SEP.to_string()), SEP, name)
        }
    }

    /// Casts all variables in a var store sub-path to the target kind.
    ///
    /// Only the variables in the path sub-tree are cast to the target kind:
    /// other var store variables are unaffected. For floating-point conversion, the
    /// methods `half`, `bfloat16`, `float` and `double` should be preferred as they
    /// ensure only float-like variables will be converted to the target type.
    pub fn set_kind(&mut self, kind: Kind) {
        let path_root = self.path.join(SEP.to_string().as_str());
        let mut variables = self.var_store.variables_.lock().unwrap();
        for (variable_name, variable) in variables.named_variables.iter_mut() {
            if variable_name.starts_with(&path_root) {
                variable.set_data(&variable.to_kind(kind));
            }
        }
    }

    /// Casts all float-like variables in a var store sub-path to the target kind.
    ///
    /// Only the float-like variables in the path sub-tree are cast to the target kind:
    /// other var store variables are unaffected.
    fn set_float_kind(&mut self, kind: Kind) {
        let path_root = self.path.join(SEP.to_string().as_str());
        let mut variables = self.var_store.variables_.lock().unwrap();
        for (variable_name, variable) in variables.named_variables.iter_mut() {
            if variable_name.starts_with(&path_root) && variable.is_floating_point() {
                variable.set_data(&variable.to_kind(kind));
            }
        }
    }

    /// Casts all float-like variables in a var store sub-path to half-precision (Half kind).
    ///
    /// Only the variables in the path sub-tree are cast to half-precision:
    /// other var store variables are unaffected.
    pub fn half(&mut self) {
        self.set_float_kind(Kind::Half);
    }

    /// Casts all float-like variables in a var store sub-path to bfloat16-precision (BFloat16 kind).
    ///
    /// Only the variables in the path sub-tree are cast to bfloat16-precision:
    /// other var store variables are unaffected.
    pub fn bfloat16(&mut self) {
        self.set_float_kind(Kind::BFloat16);
    }

    /// Casts all float-like variables in a var store sub-path to single-precision (Float kind).
    ///
    /// Only the variables in the path sub-tree are cast to single-precision:
    /// other var store variables are unaffected.
    pub fn float(&mut self) {
        self.set_float_kind(Kind::Float);
    }

    /// Casts all float-like variables in a var store sub-path to double-precision (Double kind).
    ///
    /// Only the variables in the path sub-tree are cast to double-precision:
    /// other var store variables are unaffected.
    pub fn double(&mut self) {
        self.set_float_kind(Kind::Double);
    }

    /// Adds a tensor to the var store, naming it with the current path combined
    /// with `name`. If that name is already in use, a numeric suffix is appended
    /// to make it unique.
    pub fn add(&self, name: &str, tensor: Tensor, trainable: bool) -> Tensor {
        let path = self.path(name);
        let mut variables = self.var_store.variables_.lock().unwrap();
        let path = if variables.named_variables.contains_key(&path) {
            format!("{}__{}", path, variables.named_variables.len())
        } else {
            path
        };
        let tensor = if trainable { tensor.set_requires_grad(true) } else { tensor };
        if trainable {
            let var = Var { tensor: tensor.shallow_clone(), group: self.group };
            variables.trainable_variables.push(var);
        }
        variables.named_variables.insert(path, tensor.shallow_clone());
        tensor
    }

    fn get_or_add_with_lock(
        &self,
        name: &str,
        tensor: Tensor,
        trainable: bool,
        mut variables: MutexGuard<Variables>,
    ) -> Tensor {
        let path = self.path(name);
        if let Some(var) = variables.named_variables.get(&path) {
            return var.shallow_clone();
        }

        let tensor = if trainable { tensor.set_requires_grad(true) } else { tensor };
        if trainable {
            let var = Var { tensor: tensor.shallow_clone(), group: self.group };
            variables.trainable_variables.push(var);
        }
        variables.named_variables.insert(path, tensor.shallow_clone());
        tensor
    }

    /// Creates a new variable initialized with zeros.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable will not be trainable so
    /// gradients will not be tracked.
    /// The variable uses a float tensor initialized with zeros.
    pub fn f_zeros_no_train(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        let z = Tensor::f_zeros(dims, (Kind::Float, self.device()))?;
        Ok(self.add(name, z, false))
    }

    /// Creates a new variable initialized with ones.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable will not be trainable so
    /// gradients will not be tracked.
    /// The variable uses a float tensor initialized with ones.
    pub fn f_ones_no_train(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        let o = Tensor::f_ones(dims, (Kind::Float, self.device()))?;
        Ok(self.add(name, o, false))
    }

    /// Creates a new variable.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized as per the
    /// related argument.
    pub fn f_var(&self, name: &str, dims: &[i64], init: Init) -> Result<Tensor, TchError> {
        let v = super::f_init(init, dims, self.device(), self.kind())?;
        Ok(self.add(name, v, true))
    }

    /// Creates a new variable initialized with zeros.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized with zeros.
    pub fn f_zeros(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        self.f_var(name, dims, Init::Const(0.))
    }

    /// Creates a new variable initialized with ones.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized with ones.
    pub fn f_ones(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        self.f_var(name, dims, Init::Const(1.))
    }

    /// Creates a new variable initialized randomly with normal distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// standard normal distribution.
    pub fn f_randn_standard(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        let init = Init::Randn { mean: 0., stdev: 1. };
        self.f_var(name, dims, init)
    }

    /// Creates a new variable initialized randomly with normal distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// normal distribution with the specified mean and standard deviation.
    pub fn f_randn(
        &self,
        name: &str,
        dims: &[i64],
        mean: f64,
        stdev: f64,
    ) -> Result<Tensor, TchError> {
        self.f_var(name, dims, Init::Randn { mean, stdev })
    }

    /// Creates a new variable initialized randomly with uniform distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// uniform distribution between the specified bounds.
    pub fn f_uniform(
        &self,
        name: &str,
        dims: &[i64],
        lo: f64,
        up: f64,
    ) -> Result<Tensor, TchError> {
        self.f_var(name, dims, Init::Uniform { lo, up })
    }

    /// Creates a new variable initialized randomly with kaiming uniform.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// uniform distribution whose bounds follow Kaiming initialization.
    pub fn f_kaiming_uniform(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        self.f_var(name, dims, super::init::DEFAULT_KAIMING_UNIFORM)
    }

    /// Creates a new variable initialized randomly with kaiming normal.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// normal distribution whose standard deviation follows Kaiming initialization.
    pub fn f_kaiming_normal(&self, name: &str, dims: &[i64]) -> Result<Tensor, TchError> {
        self.f_var(name, dims, super::init::DEFAULT_KAIMING_NORMAL)
    }

    /// Creates a new variable initialized randomly with an orthogonal matrix.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly with an orthogonal
    /// matrix as described in *Exact solutions to the nonlinear dynamics
    /// of learning in deep linear neural networks* - Saxe, A. et al. (2013).
    /// The input tensor must have at least 2 dimensions, and for tensors
    /// with more than 2 dimensions the trailing dimensions are flattened.
    pub fn f_orthogonal(&self, name: &str, dims: &[i64], gain: f64) -> Result<Tensor, TchError> {
        self.f_var(name, dims, Init::Orthogonal { gain })
    }

    /// Creates a new variable initialized by copying an existing tensor.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized by copying some
    /// given tensor.
    pub fn f_var_copy(&self, name: &str, t: &Tensor) -> Result<Tensor, TchError> {
        let mut v = self.f_zeros(name, &t.size())?;
        crate::no_grad(|| v.f_copy_(t))?;
        Ok(v)
    }

    /// Creates a new variable initialized with zeros.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable will not be trainable so
    /// gradients will not be tracked.
    /// The variable uses a float tensor initialized with zeros.
    pub fn zeros_no_train(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_zeros_no_train(name, dims).unwrap()
    }

    /// Creates a new variable initialized with ones.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable will not be trainable so
    /// gradients will not be tracked.
    /// The variable uses a float tensor initialized with ones.
    pub fn ones_no_train(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_ones_no_train(name, dims).unwrap()
    }

    /// Creates a new variable.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized as per the
    /// related argument.
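    ///
    /// # Example
    ///
    /// A minimal sketch; `Init` is assumed to be re-exported by the `nn`
    /// module, as suggested by the `use super::Init;` import above.
    ///
    /// ```no_run
    /// use tch_plus::{nn::Init, nn::VarStore, Device};
    ///
    /// let vs = VarStore::new(Device::Cpu);
    /// let _w = vs.root().var("weight", &[128, 64], Init::Randn { mean: 0., stdev: 0.02 });
    /// ```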
    pub fn var(&self, name: &str, dims: &[i64], init: Init) -> Tensor {
        self.f_var(name, dims, init).unwrap()
    }

    /// Creates a new variable initialized with zeros.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized with zeros.
    pub fn zeros(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_zeros(name, dims).unwrap()
    }

    /// Creates a new variable initialized with ones.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized with ones.
    pub fn ones(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_ones(name, dims).unwrap()
    }

    /// Creates a new variable initialized randomly with normal distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// standard normal distribution.
    pub fn randn_standard(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_randn_standard(name, dims).unwrap()
    }

    /// Creates a new variable initialized randomly with normal distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// normal distribution with the specified mean and standard deviation.
    pub fn randn(&self, name: &str, dims: &[i64], mean: f64, stdev: f64) -> Tensor {
        self.f_randn(name, dims, mean, stdev).unwrap()
    }

    /// Creates a new variable initialized randomly with uniform distribution.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// uniform distribution between the specified bounds.
    pub fn uniform(&self, name: &str, dims: &[i64], lo: f64, up: f64) -> Tensor {
        self.f_uniform(name, dims, lo, up).unwrap()
    }

    /// Creates a new variable initialized randomly with kaiming uniform.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// uniform distribution whose bounds follow Kaiming initialization.
    pub fn kaiming_uniform(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_kaiming_uniform(name, dims).unwrap()
    }

    /// Creates a new variable initialized randomly with kaiming normal.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly using a
    /// normal distribution whose standard deviation follows Kaiming initialization.
    pub fn kaiming_normal(&self, name: &str, dims: &[i64]) -> Tensor {
        self.f_kaiming_normal(name, dims).unwrap()
    }

    /// Creates a new variable initialized randomly with an orthogonal matrix.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized randomly with an orthogonal
    /// matrix as described in *Exact solutions to the nonlinear dynamics
    /// of learning in deep linear neural networks* - Saxe, A. et al. (2013).
    /// The input tensor must have at least 2 dimensions, and for tensors
    /// with more than 2 dimensions the trailing dimensions are flattened.
    pub fn orthogonal(&self, name: &str, dims: &[i64], gain: f64) -> Tensor {
        self.f_orthogonal(name, dims, gain).unwrap()
    }

    /// Creates a new variable initialized by copying an existing tensor.
    ///
    /// The new variable is named according to the name parameter and
    /// has the specified shape. The variable is trainable, its gradient
    /// will be tracked.
    /// The variable uses a float tensor initialized by copying some
    /// given tensor.
    pub fn var_copy(&self, name: &str, t: &Tensor) -> Tensor {
        self.f_var_copy(name, t).unwrap()
    }

    /// Gets the tensor corresponding to a given name if present.
    pub fn get(&self, name: &str) -> Option<Tensor> {
        let path = self.path(name);
        let variables = self.var_store.variables_.lock().unwrap();
        variables.named_variables.get(&path).map(|v| v.shallow_clone())
    }

    /// Gets the entry corresponding to a given name for in-place manipulation.
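    ///
    /// # Example
    ///
    /// A minimal sketch (crate name `tch_plus` assumed):
    ///
    /// ```no_run
    /// use tch_plus::{nn::VarStore, Device};
    ///
    /// let vs = VarStore::new(Device::Cpu);
    /// let root = vs.root();
    /// // Reuses the variable if it already exists, otherwise creates it
    /// // initialized with zeros.
    /// let _w = root.entry("weight").or_zeros(&[3, 4]);
    /// ```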
    pub fn entry<'b>(&'b self, name: &'b str) -> Entry<'b> {
        let variables = self.var_store.variables_.lock().unwrap();
        Entry { name, variables, path: self }
    }
}

impl Entry<'_> {
    /// Returns the existing entry if present, otherwise creates a new variable.
    ///
    /// If this entry name matches the name of a variable stored in the
    /// var store, the corresponding tensor is returned. Otherwise a new
    /// variable is added to the var-store with the entry name and is
    /// initialized according to the init parameter.
    pub fn or_var(self, dims: &[i64], init: Init) -> Tensor {
        let v = super::init(init, dims, self.path.device());
        self.path.get_or_add_with_lock(self.name, v, true, self.variables)
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_var_copy(self, tensor: &Tensor) -> Tensor {
        let mut v = self.or_zeros(&tensor.size());
        crate::no_grad(|| v.copy_(tensor));
        v
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_kaiming_uniform(self, dims: &[i64]) -> Tensor {
        self.or_var(dims, super::init::DEFAULT_KAIMING_UNIFORM)
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_kaiming_normal(self, dims: &[i64]) -> Tensor {
        self.or_var(dims, super::init::DEFAULT_KAIMING_NORMAL)
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_orthogonal(self, dims: &[i64], gain: f64) -> Tensor {
        self.or_var(dims, Init::Orthogonal { gain })
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_ones(self, dims: &[i64]) -> Tensor {
        self.or_var(dims, Init::Const(1.))
    }

    /// Returns the existing entry if present, otherwise creates a new non-trainable variable.
    pub fn or_ones_no_train(self, dims: &[i64]) -> Tensor {
        let o = Tensor::ones(dims, (Kind::Float, self.path.device()));
        self.path.get_or_add_with_lock(self.name, o, false, self.variables)
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_randn(self, dims: &[i64], mean: f64, stdev: f64) -> Tensor {
        self.or_var(dims, Init::Randn { mean, stdev })
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_randn_standard(self, dims: &[i64]) -> Tensor {
        let init = Init::Randn { mean: 0., stdev: 1. };
        self.or_var(dims, init)
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_uniform(self, dims: &[i64], lo: f64, up: f64) -> Tensor {
        self.or_var(dims, Init::Uniform { lo, up })
    }

    /// Returns the existing entry if present, otherwise creates a new variable.
    pub fn or_zeros(self, dims: &[i64]) -> Tensor {
        self.or_var(dims, Init::Const(0.))
    }

    /// Returns the existing entry if present, otherwise creates a new non-trainable variable.
    pub fn or_zeros_no_train(self, dims: &[i64]) -> Tensor {
        let z = Tensor::zeros(dims, (Kind::Float, self.path.device()));
        self.path.get_or_add_with_lock(self.name, z, false, self.variables)
    }
}

impl<'a, T> Div<T> for &mut Path<'a>
where
    T: std::string::ToString,
{
    type Output = Path<'a>;

    fn div(self, rhs: T) -> Self::Output {
        self.sub(rhs.to_string())
    }
}

impl<'a, T> Div<T> for &Path<'a>
where
    T: std::string::ToString,
{
    type Output = Path<'a>;

    fn div(self, rhs: T) -> Self::Output {
        self.sub(rhs.to_string())
    }
}

impl<'a, T> Div<T> for Path<'a>
where
    T: std::string::ToString,
{
    type Output = Path<'a>;

    fn div(self, rhs: T) -> Self::Output {
        self.sub(rhs.to_string())
    }
}