1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
//! # Resource definitions for model weights, vocabularies and configuration files //! //! This crate relies on the concept of Resources to access the files used by the models. //! This includes: //! - model weights //! - configuration files //! - vocabularies //! - (optional) merges files for BPE-based tokenizers //! //! These are expected in the pipelines configurations or are used as utilities to reference to the //! resource location. Two types of resources exist: //! - LocalResource: points to a local file //! - RemoteResource: points to a remote file via a URL and a local cached file //! //! For both types of resources, the local location of teh file can be retrieved using //! `get_local_path`, allowing to reference the resource file location regardless if it is a remote //! or local resource. Default implementations for a number of `RemoteResources` are available as //! pre-trained models in each model module. use crate::common::error::RustBertError; use cached_path::{Cache, Options, ProgressBar}; use lazy_static::lazy_static; use std::env; use std::path::PathBuf; extern crate dirs; /// # Resource Enum pointing to model, configuration or vocabulary resources /// Can be of type: /// - LocalResource /// - RemoteResource #[derive(PartialEq, Clone)] pub enum Resource { Local(LocalResource), Remote(RemoteResource), } impl Resource { /// Gets the local path for a given resource. /// /// If the resource is a remote resource, it is downloaded and cached. Then the path /// to the local cache is returned. /// /// # Returns /// /// * `PathBuf` pointing to the resource file /// /// # Example /// /// ```no_run /// use rust_bert::resources::{LocalResource, Resource}; /// use std::path::PathBuf; /// let config_resource = Resource::Local(LocalResource { /// local_path: PathBuf::from("path/to/config.json"), /// }); /// let config_path = config_resource.get_local_path(); /// ``` pub fn get_local_path(&self) -> Result<PathBuf, RustBertError> { match self { Resource::Local(resource) => Ok(resource.local_path.clone()), Resource::Remote(resource) => { let cached_path = CACHE.cached_path_with_options( &resource.url, &Options::default().subdir(&resource.cache_subdir), )?; Ok(cached_path) } } } } /// # Local resource #[derive(PartialEq, Clone)] pub struct LocalResource { /// Local path for the resource pub local_path: PathBuf, } /// # Remote resource #[derive(PartialEq, Clone)] pub struct RemoteResource { /// Remote path/url for the resource pub url: String, /// Local subdirectory of the cache root where this resource is saved pub cache_subdir: String, } impl RemoteResource { /// Creates a new RemoteResource from an URL and a custom local path. Note that this does not /// download the resource (only declares the remote and local locations) /// /// # Arguments /// /// * `url` - `&str` Location of the remote resource /// * `cache_subdir` - `&str` Local subdirectory of the cache root to save the resource to /// /// # Returns /// /// * `RemoteResource` RemoteResource object /// /// # Example /// /// ```no_run /// use rust_bert::resources::{RemoteResource, Resource}; /// let config_resource = Resource::Remote(RemoteResource::new( /// "configs", /// "http://config_json_location", /// )); /// ``` pub fn new(url: &str, cache_subdir: &str) -> RemoteResource { RemoteResource { url: url.to_string(), cache_subdir: cache_subdir.to_string(), } } /// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to /// ~/.cache/.rusbert/model_name. Note that this does not download the resource (only declares /// the remote and local locations) /// /// # Arguments /// /// * `name_url_tuple` - `(&str, &str)` Location of the name of model and remote resource /// /// # Returns /// /// * `RemoteResource` RemoteResource object /// /// # Example /// /// ```no_run /// use rust_bert::resources::{RemoteResource, Resource}; /// let model_resource = Resource::Remote(RemoteResource::from_pretrained(( /// "distilbert-sst2", /// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot", /// ))); /// ``` pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource { let cache_subdir = name_url_tuple.0.to_string(); let url = name_url_tuple.1.to_string(); RemoteResource { url, cache_subdir } } } lazy_static! { #[derive(Copy, Clone, Debug)] /// # Global cache directory /// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that /// location. Otherwise defaults to `~/.cache/.rustbert`. pub static ref CACHE: Cache = Cache::builder() .dir(_get_cache_directory()) .progress_bar(Some(ProgressBar::Light)) .build().unwrap(); } fn _get_cache_directory() -> PathBuf { match env::var("RUSTBERT_CACHE") { Ok(value) => PathBuf::from(value), Err(_) => { let mut home = dirs::home_dir().unwrap(); home.push(".cache"); home.push(".rustbert"); home } } } #[deprecated( since = "0.9.1", note = "Please use `Resource.get_local_path()` instead" )] /// # (Download) the resource and return a path to its local path /// This function will download remote resource to their local path if they do not exist yet. /// Then for both `LocalResource` and `RemoteResource`, it will the local path to the resource. /// For `LocalResource` only the resource path is returned. /// /// # Arguments /// /// * `resource` - Pointer to the `&Resource` to optionally download and get the local path. /// /// # Returns /// /// * `&PathBuf` Local path for the resource /// /// # Example /// /// ```no_run /// use rust_bert::resources::{RemoteResource, Resource}; /// let model_resource = Resource::Remote(RemoteResource::from_pretrained(( /// "distilbert-sst2/model.ot", /// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot", /// ))); /// let local_path = model_resource.get_local_path(); /// ``` pub fn download_resource(resource: &Resource) -> Result<PathBuf, RustBertError> { resource.get_local_path() }