1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use super::*;
use crate::common::error::RustBertError;
use cached_path::{Cache, Options, ProgressBar};
use dirs::cache_dir;
use lazy_static::lazy_static;
use std::path::PathBuf;

/// # Remote resource that will be downloaded and cached locally on demand
///
/// Declaring a `RemoteResource` does not trigger any download; it only records where the
/// resource lives remotely and which cache subdirectory it should be stored under.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct RemoteResource {
    /// Remote path/url for the resource
    pub url: String,
    /// Local subdirectory of the cache root where this resource is saved
    pub cache_subdir: String,
}

impl RemoteResource {
    /// Creates a new RemoteResource from an URL and a custom local path. Note that this does not
    /// download the resource (only declares the remote and local locations)
    ///
    /// # Arguments
    ///
    /// * `url` - `&str` Location of the remote resource
    /// * `cache_subdir` - `&str` Local subdirectory of the cache root to save the resource to
    ///
    /// # Returns
    ///
    /// * `RemoteResource` RemoteResource object
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::RemoteResource;
    /// let config_resource = RemoteResource::new("configs", "http://config_json_location");
    /// ```
    pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
        RemoteResource {
            url: url.to_string(),
            cache_subdir: cache_subdir.to_string(),
        }
    }

    /// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
    /// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares
    /// the remote and local locations)
    ///
    /// # Arguments
    ///
    /// * `name_url_tuple` - `(&str, &str)` Location of the name of model and remote resource
    ///
    /// # Returns
    ///
    /// * `RemoteResource` RemoteResource object
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::RemoteResource;
    /// let model_resource = RemoteResource::from_pretrained((
    ///     "distilbert-sst2",
    ///     "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
    /// ));
    /// ```
    pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
        let cache_subdir = name_url_tuple.0.to_string();
        let url = name_url_tuple.1.to_string();
        RemoteResource { url, cache_subdir }
    }
}

impl ResourceProvider for RemoteResource {
    /// Gets the local path for a remote resource.
    ///
    /// The remote resource is downloaded and cached. Then the path
    /// to the local cache is returned.
    ///
    /// # Returns
    ///
    /// * `PathBuf` pointing to the resource file
    ///
    /// # Errors
    ///
    /// Returns a `RustBertError` if the global `CACHE` fails to resolve the resource
    /// for `url` under `cache_subdir`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::{RemoteResource, ResourceProvider};
    /// let config_resource = RemoteResource::new("http://config_json_location", "configs");
    /// let config_path = config_resource.get_local_path();
    /// ```
    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
        // Store the resource under its own subdirectory of the cache root.
        let options = Options::default().subdir(&self.cache_subdir);
        // `?` converts the cache error into a `RustBertError`.
        let cached_path = CACHE.cached_path_with_options(&self.url, &options)?;
        Ok(cached_path)
    }
}

lazy_static! {
    /// # Global cache directory
    /// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
    /// location. Otherwise defaults to `$XDG_CACHE_HOME/.rustbert`, or corresponding user cache for
    /// the current system.
    ///
    /// # Panics
    ///
    /// The cache is built lazily on first access. That first access panics if the cache
    /// directory cannot be determined or if the cache builder returns an error.
    #[derive(Copy, Clone, Debug)]
    pub static ref CACHE: Cache = Cache::builder()
        .dir(_get_cache_directory())
        .progress_bar(Some(ProgressBar::Light))
        .build()
        .expect("failed to initialize the rust-bert resource cache");
}

/// Resolves the root directory used by the global resource cache.
///
/// When the `RUSTBERT_CACHE` environment variable is set, its value is used verbatim as the
/// cache root. Otherwise the root is the `.rustbert` directory inside the directory returned
/// by `cache_dir()`.
///
/// # Panics
///
/// Panics if `RUSTBERT_CACHE` is not set and `cache_dir()` returns `None`.
fn _get_cache_directory() -> PathBuf {
    // `var_os` (rather than `var`) so that a path which is not valid Unicode is still honored
    // instead of silently falling back to the default location.
    if let Some(custom_root) = std::env::var_os("RUSTBERT_CACHE") {
        return PathBuf::from(custom_root);
    }

    let mut default_root = cache_dir()
        .expect("could not determine the user cache directory; set RUSTBERT_CACHE to override");
    default_root.push(".rustbert");
    default_root
}