rust_bert/common/resources/remote.rs

1use super::*;
2use crate::common::error::RustBertError;
3use cached_path::{Cache, Options, ProgressBar};
4use dirs::cache_dir;
5use lazy_static::lazy_static;
6use std::path::PathBuf;
7
/// # Remote resource that will be downloaded and cached locally on demand
///
/// A `RemoteResource` only records *where* a file lives remotely (`url`) and *where* it should
/// be stored locally (`cache_subdir`, relative to the cache root). Creating one performs no I/O.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RemoteResource {
    /// Remote path/url for the resource
    pub url: String,
    /// Local subdirectory of the cache root where this resource is saved
    pub cache_subdir: String,
}

impl RemoteResource {
    /// Declares a remote resource from a URL and a cache subdirectory.
    ///
    /// Nothing is downloaded here; the remote and local locations are only recorded.
    ///
    /// # Arguments
    ///
    /// * `url` - `&str` Location of the remote resource
    /// * `cache_subdir` - `&str` Subdirectory of the cache root the resource will be saved to
    ///
    /// # Returns
    ///
    /// * `RemoteResource` holding owned copies of both arguments
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::RemoteResource;
    /// let config_resource = RemoteResource::new("http://config_json_location", "configs");
    /// ```
    pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
        let url = String::from(url);
        let cache_subdir = String::from(cache_subdir);
        RemoteResource { url, cache_subdir }
    }

    /// Declares a remote resource from a `(name, url)` pair.
    ///
    /// The name is used as the cache subdirectory, so the file ends up under
    /// `<cache root>/<name>` (by default `~/.cache/.rustbert/<name>` on Linux).
    /// Nothing is downloaded here; the remote and local locations are only recorded.
    ///
    /// # Arguments
    ///
    /// * `name_url_tuple` - `(&str, &str)` The model name (used as cache subdirectory) followed
    ///   by the remote location of the resource
    ///
    /// # Returns
    ///
    /// * `RemoteResource` equivalent to `RemoteResource::new(url, name)`
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::RemoteResource;
    /// let model_resource = RemoteResource::from_pretrained((
    ///     "distilbert-sst2",
    ///     "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
    /// ));
    /// ```
    pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
        // The tuple is ordered (name, url) while `new` takes (url, subdir).
        let (name, url) = name_url_tuple;
        RemoteResource::new(url, name)
    }
}
70
71impl ResourceProvider for RemoteResource {
72    /// Gets the local path for a remote resource.
73    ///
74    /// The remote resource is downloaded and cached. Then the path
75    /// to the local cache is returned.
76    ///
77    /// # Returns
78    ///
79    /// * `PathBuf` pointing to the resource file
80    ///
81    /// # Example
82    ///
83    /// ```no_run
84    /// use rust_bert::resources::{LocalResource, ResourceProvider};
85    /// use std::path::PathBuf;
86    /// let config_resource = LocalResource {
87    ///     local_path: PathBuf::from("path/to/config.json"),
88    /// };
89    /// let config_path = config_resource.get_local_path();
90    /// ```
91    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
92        let cached_path = CACHE
93            .cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
94        Ok(cached_path)
95    }
96
97    /// Gets a wrapper around the local path for a remote resource.
98    ///
99    /// # Returns
100    ///
101    /// * `Resource` wrapping a `PathBuf` pointing to the resource file
102    ///
103    /// # Example
104    ///
105    /// ```no_run
106    /// use rust_bert::resources::{RemoteResource, ResourceProvider};
107    /// let config_resource = RemoteResource::new("http://config_json_location", "configs");
108    /// let config_path = config_resource.get_resource();
109    /// ```
110    fn get_resource(&self) -> Result<Resource, RustBertError> {
111        Ok(Resource::PathBuf(self.get_local_path()?))
112    }
113}
114
115lazy_static! {
116    #[derive(Copy, Clone, Debug)]
117/// # Global cache directory
118/// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
119/// location. Otherwise defaults to `$XDG_CACHE_HOME/.rustbert`, or corresponding user cache for
120/// the current system.
121    pub static ref CACHE: Cache = Cache::builder()
122        .dir(_get_cache_directory())
123        .progress_bar(Some(ProgressBar::Light))
124        .build().unwrap();
125}
126
127fn _get_cache_directory() -> PathBuf {
128    match std::env::var("RUSTBERT_CACHE") {
129        Ok(value) => PathBuf::from(value),
130        Err(_) => {
131            let mut home = cache_dir().unwrap();
132            home.push(".rustbert");
133            home
134        }
135    }
136}