rust_bert/common/resources/remote.rs
1use super::*;
2use crate::common::error::RustBertError;
3use cached_path::{Cache, Options, ProgressBar};
4use dirs::cache_dir;
5use lazy_static::lazy_static;
6use std::path::PathBuf;
7
/// # Remote resource, fetched and cached on the local file system when first requested
///
/// Only the remote location and the target cache subdirectory are stored here; creating a
/// value does not trigger any download.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RemoteResource {
    /// URL (or remote path) the resource is fetched from
    pub url: String,
    /// Subdirectory, relative to the cache root, in which the downloaded file is stored
    pub cache_subdir: String,
}
16
17impl RemoteResource {
18 /// Creates a new RemoteResource from an URL and a custom local path. Note that this does not
19 /// download the resource (only declares the remote and local locations)
20 ///
21 /// # Arguments
22 ///
23 /// * `url` - `&str` Location of the remote resource
24 /// * `cache_subdir` - `&str` Local subdirectory of the cache root to save the resource to
25 ///
26 /// # Returns
27 ///
28 /// * `RemoteResource` RemoteResource object
29 ///
30 /// # Example
31 ///
32 /// ```no_run
33 /// use rust_bert::resources::RemoteResource;
34 /// let config_resource = RemoteResource::new("http://config_json_location", "configs");
35 /// ```
36 pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
37 RemoteResource {
38 url: url.to_string(),
39 cache_subdir: cache_subdir.to_string(),
40 }
41 }
42
43 /// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
44 /// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares
45 /// the remote and local locations)
46 ///
47 /// # Arguments
48 ///
49 /// * `name_url_tuple` - `(&str, &str)` Location of the name of model and remote resource
50 ///
51 /// # Returns
52 ///
53 /// * `RemoteResource` RemoteResource object
54 ///
55 /// # Example
56 ///
57 /// ```no_run
58 /// use rust_bert::resources::RemoteResource;
59 /// let model_resource = RemoteResource::from_pretrained((
60 /// "distilbert-sst2",
61 /// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
62 /// ));
63 /// ```
64 pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
65 let cache_subdir = name_url_tuple.0.to_string();
66 let url = name_url_tuple.1.to_string();
67 RemoteResource { url, cache_subdir }
68 }
69}
70
71impl ResourceProvider for RemoteResource {
72 /// Gets the local path for a remote resource.
73 ///
74 /// The remote resource is downloaded and cached. Then the path
75 /// to the local cache is returned.
76 ///
77 /// # Returns
78 ///
79 /// * `PathBuf` pointing to the resource file
80 ///
81 /// # Example
82 ///
83 /// ```no_run
84 /// use rust_bert::resources::{LocalResource, ResourceProvider};
85 /// use std::path::PathBuf;
86 /// let config_resource = LocalResource {
87 /// local_path: PathBuf::from("path/to/config.json"),
88 /// };
89 /// let config_path = config_resource.get_local_path();
90 /// ```
91 fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
92 let cached_path = CACHE
93 .cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
94 Ok(cached_path)
95 }
96
97 /// Gets a wrapper around the local path for a remote resource.
98 ///
99 /// # Returns
100 ///
101 /// * `Resource` wrapping a `PathBuf` pointing to the resource file
102 ///
103 /// # Example
104 ///
105 /// ```no_run
106 /// use rust_bert::resources::{RemoteResource, ResourceProvider};
107 /// let config_resource = RemoteResource::new("http://config_json_location", "configs");
108 /// let config_path = config_resource.get_resource();
109 /// ```
110 fn get_resource(&self) -> Result<Resource, RustBertError> {
111 Ok(Resource::PathBuf(self.get_local_path()?))
112 }
113}
114
115lazy_static! {
116 #[derive(Copy, Clone, Debug)]
117/// # Global cache directory
118/// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
119/// location. Otherwise defaults to `$XDG_CACHE_HOME/.rustbert`, or corresponding user cache for
120/// the current system.
121 pub static ref CACHE: Cache = Cache::builder()
122 .dir(_get_cache_directory())
123 .progress_bar(Some(ProgressBar::Light))
124 .build().unwrap();
125}
126
127fn _get_cache_directory() -> PathBuf {
128 match std::env::var("RUSTBERT_CACHE") {
129 Ok(value) => PathBuf::from(value),
130 Err(_) => {
131 let mut home = cache_dir().unwrap();
132 home.push(".rustbert");
133 home
134 }
135 }
136}