rust_bert/common/resources/mod.rs
1//! # Resource definitions for model weights, vocabularies and configuration files
2//!
3//! This crate relies on the concept of Resources to access the data used by the models.
4//! This includes:
5//! - model weights
6//! - configuration files
7//! - vocabularies
8//! - (optional) merges files for BPE-based tokenizers
9//!
//! These are expected in the pipelines configurations or are used as utilities to reference the
//! resource location. Three types of resources are pre-defined:
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL
//! - BufferResource: refers to a buffer that contains file contents for a resource (currently only
//!   usable for weights)
16//!
//! For `LocalResource` and `RemoteResource`, the local location of the file can be retrieved using
//! `get_local_path`, which makes it possible to reference the resource file location
//! regardless of whether it is a remote or a local resource. Default implementations for a number
//! of `RemoteResources` are available as pre-trained models in each model module.
21
22mod buffer;
23mod local;
24
25use crate::common::error::RustBertError;
26pub use buffer::BufferResource;
27pub use local::LocalResource;
28use std::fmt::Debug;
29use std::ops::DerefMut;
30use std::path::PathBuf;
31use std::sync::RwLockWriteGuard;
32use tch::nn::VarStore;
33use tch::{Device, Kind};
34
/// A resource handed out by a [`ResourceProvider`].
///
/// A resource is either the location of a file on disk, or exclusive access to an in-memory
/// byte buffer.
#[derive(Debug)]
pub enum Resource<'a> {
    /// Path to the resource file.
    PathBuf(PathBuf),
    /// Write guard over a byte buffer. The lock it was obtained from stays held for as long as
    /// this value is alive.
    Buffer(RwLockWriteGuard<'a, Vec<u8>>),
}
39
40/// # Resource Trait that can provide the location or data for the model, and location of
41/// configuration or vocabulary resources
42pub trait ResourceProvider: Debug + Send + Sync {
43 /// Provides the local path for a resource.
44 ///
45 /// # Returns
46 ///
47 /// * `PathBuf` pointing to the resource file
48 ///
49 /// # Example
50 ///
51 /// ```no_run
52 /// use rust_bert::resources::{LocalResource, ResourceProvider};
53 /// use std::path::PathBuf;
54 /// let config_resource = LocalResource {
55 /// local_path: PathBuf::from("path/to/config.json"),
56 /// };
57 /// let config_path = config_resource.get_local_path();
58 /// ```
59 fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
60
61 /// Provides access to an underlying resource.
62 ///
63 /// # Returns
64 ///
65 /// * `Resource` wrapping a representation of a resource.
66 ///
67 /// # Example
68 ///
69 /// ```no_run
70 /// use rust_bert::resources::{BufferResource, LocalResource, ResourceProvider};
71 /// ```
72 fn get_resource(&self) -> Result<Resource, RustBertError>;
73}
74
75impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> {
76 fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
77 T::get_local_path(self)
78 }
79 fn get_resource(&self) -> Result<Resource, RustBertError> {
80 T::get_resource(self)
81 }
82}
83
84/// Load the provided `VarStore` with model weights from the provided `ResourceProvider`
85pub fn load_weights(
86 rp: &(impl ResourceProvider + ?Sized),
87 vs: &mut VarStore,
88 kind: Option<Kind>,
89 device: Device,
90) -> Result<(), RustBertError> {
91 match rp.get_resource()? {
92 Resource::Buffer(mut data) => vs.load_from_stream(std::io::Cursor::new(data.deref_mut())),
93 Resource::PathBuf(path) => vs.load(path),
94 }?;
95 cast_var_store(vs, kind, device);
96 Ok(())
97}
98
99#[cfg(feature = "remote")]
100mod remote;
101use crate::pipelines::common::cast_var_store;
102#[cfg(feature = "remote")]
103pub use remote::RemoteResource;