// rust_bert/common/resources/buffer.rs

1use crate::common::error::RustBertError;
2use crate::resources::{Resource, ResourceProvider};
3use std::path::PathBuf;
4use std::sync::{Arc, RwLock};
5
/// # In-memory raw buffer resource
///
/// Holds the raw bytes of a resource (for example the contents of a serialized weights
/// file read with `std::fs::read`) so they can be served through [`ResourceProvider`]
/// without going through the file system.
///
/// The bytes live behind an `Arc<RwLock<_>>`: cloning the `Arc` stored in `data` lets
/// several owners refer to the very same buffer.
#[derive(Debug)]
pub struct BufferResource {
    /// Shared, lock-protected bytes making up the resource.
    pub data: Arc<RwLock<Vec<u8>>>,
}
12
13impl ResourceProvider for BufferResource {
14    /// Not implemented for this resource type
15    ///
16    /// # Returns
17    ///
18    /// * `RustBertError::UnsupportedError`
19    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
20        Err(RustBertError::UnsupportedError)
21    }
22
23    /// Gets a wrapper referring to the in-memory resource.
24    ///
25    /// # Returns
26    ///
27    /// * `Resource` referring to the resource data
28    ///
29    /// # Example
30    ///
31    /// ```no_run
32    /// use rust_bert::resources::{BufferResource, ResourceProvider};
33    /// let data = std::fs::read("path/to/rust_model.ot").unwrap();
34    /// let weights_resource = BufferResource::from(data);
35    /// let weights = weights_resource.get_resource();
36    /// ```
37    fn get_resource(&self) -> Result<Resource, RustBertError> {
38        Ok(Resource::Buffer(self.data.write().unwrap()))
39    }
40}
41
42impl From<Vec<u8>> for BufferResource {
43    fn from(data: Vec<u8>) -> Self {
44        Self {
45            data: Arc::new(RwLock::new(data)),
46        }
47    }
48}
49
50impl From<Vec<u8>> for Box<dyn ResourceProvider> {
51    fn from(data: Vec<u8>) -> Self {
52        Box::new(BufferResource {
53            data: Arc::new(RwLock::new(data)),
54        })
55    }
56}
57
58impl From<RwLock<Vec<u8>>> for BufferResource {
59    fn from(lock: RwLock<Vec<u8>>) -> Self {
60        Self {
61            data: Arc::new(lock),
62        }
63    }
64}
65
66impl From<RwLock<Vec<u8>>> for Box<dyn ResourceProvider> {
67    fn from(lock: RwLock<Vec<u8>>) -> Self {
68        Box::new(BufferResource {
69            data: Arc::new(lock),
70        })
71    }
72}