rust_bert/common/resources/
mod.rs

//! # Resource definitions for model weights, vocabularies and configuration files
//!
//! This crate relies on the concept of Resources to access the data used by the models.
//! This includes:
//! - model weights
//! - configuration files
//! - vocabularies
//! - (optional) merges files for BPE-based tokenizers
//!
//! These are expected in the pipeline configurations or are used as utilities to reference the
//! resource location. Three types of resources are pre-defined:
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL (only available with the `remote` feature)
//! - BufferResource: refers to a buffer that contains file contents for a resource (currently only
//!                   usable for weights)
//!
//! For `LocalResource` and `RemoteResource`, the local location of the file can be retrieved using
//! `get_local_path`, which makes it possible to reference the resource file location regardless of
//! whether it is a remote or local resource. Default implementations for a number of
//! `RemoteResources` are available as pre-trained models in each model module.
21
22mod buffer;
23mod local;
24
25use crate::common::error::RustBertError;
26pub use buffer::BufferResource;
27pub use local::LocalResource;
28use std::fmt::Debug;
29use std::ops::DerefMut;
30use std::path::PathBuf;
31use std::sync::RwLockWriteGuard;
32use tch::nn::VarStore;
33use tch::{Device, Kind};
34
/// # Representation of a resource handed out by a `ResourceProvider`
///
/// A resource is either the location of a file or a guard giving access to an in-memory
/// byte buffer. The lifetime `'a` is the lifetime of the lock guard held by the `Buffer`
/// variant.
#[derive(Debug)]
pub enum Resource<'a> {
    /// Path of a file holding the resource.
    PathBuf(PathBuf),
    /// Write guard over a byte buffer holding the resource contents. The underlying lock
    /// stays held for as long as this value is alive.
    Buffer(RwLockWriteGuard<'a, Vec<u8>>),
}
39
40/// # Resource Trait that can provide the location or data for the model, and location of
41/// configuration or vocabulary resources
42pub trait ResourceProvider: Debug + Send + Sync {
43    /// Provides the local path for a resource.
44    ///
45    /// # Returns
46    ///
47    /// * `PathBuf` pointing to the resource file
48    ///
49    /// # Example
50    ///
51    /// ```no_run
52    /// use rust_bert::resources::{LocalResource, ResourceProvider};
53    /// use std::path::PathBuf;
54    /// let config_resource = LocalResource {
55    ///     local_path: PathBuf::from("path/to/config.json"),
56    /// };
57    /// let config_path = config_resource.get_local_path();
58    /// ```
59    fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
60
61    /// Provides access to an underlying resource.
62    ///
63    /// # Returns
64    ///
65    /// * `Resource` wrapping a representation of a resource.
66    ///
67    /// # Example
68    ///
69    /// ```no_run
70    /// use rust_bert::resources::{BufferResource, LocalResource, ResourceProvider};
71    /// ```
72    fn get_resource(&self) -> Result<Resource, RustBertError>;
73}
74
75impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> {
76    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
77        T::get_local_path(self)
78    }
79    fn get_resource(&self) -> Result<Resource, RustBertError> {
80        T::get_resource(self)
81    }
82}
83
84/// Load the provided `VarStore` with model weights from the provided `ResourceProvider`
85pub fn load_weights(
86    rp: &(impl ResourceProvider + ?Sized),
87    vs: &mut VarStore,
88    kind: Option<Kind>,
89    device: Device,
90) -> Result<(), RustBertError> {
91    match rp.get_resource()? {
92        Resource::Buffer(mut data) => vs.load_from_stream(std::io::Cursor::new(data.deref_mut())),
93        Resource::PathBuf(path) => vs.load(path),
94    }?;
95    cast_var_store(vs, kind, device);
96    Ok(())
97}
98
99#[cfg(feature = "remote")]
100mod remote;
101use crate::pipelines::common::cast_var_store;
102#[cfg(feature = "remote")]
103pub use remote::RemoteResource;