treasury_import/
sources.rs

1#[cfg(unix)]
2use std::os::unix::ffi::{OsStrExt, OsStringExt};
3
4#[cfg(target_os = "wasi")]
5use std::os::wasi::ffi::{OsStrExt, OsStringExt};
6
7#[cfg(windows)]
8use std::os::windows::ffi::{OsStrExt, OsStringExt};
9use std::{
10    ffi::OsString,
11    path::{Path, PathBuf},
12};
13
14use crate::{
15    OsChar, ANY_BUF_LEN_LIMIT, BUFFER_IS_TOO_SMALL, NOT_FOUND, NOT_UTF8, PATH_BUF_LEN_START,
16    SUCCESS,
17};
18
/// A provider that resolves source names to filesystem paths.
pub trait Sources {
    /// Resolve `source` to a path.
    ///
    /// `Ok(None)` means the source is not available; `Err` means the lookup
    /// itself failed.
    fn get(&self, source: &str) -> Result<Option<PathBuf>, String>;

    /// Same as [`Sources::get`], but additionally records `source` in
    /// `missing` when the lookup yields `Ok(None)`.
    ///
    /// Errors are passed through unchanged and are not recorded.
    fn get_or_append(
        &self,
        source: &str,
        missing: &mut Vec<String>,
    ) -> Result<Option<PathBuf>, String> {
        let found = self.get(source)?;
        if found.is_none() {
            missing.push(source.to_owned());
        }
        Ok(found)
    }
}
38
39#[repr(transparent)]
40pub struct SourcesOpaque(u8);
41
42pub type SourcesGetFn = unsafe extern "C" fn(
43    sources: *mut SourcesOpaque,
44    source_ptr: *const u8,
45    source_len: u32,
46    path_ptr: *mut OsChar,
47    path_len: *mut u32,
48) -> i32;
49
50unsafe extern "C" fn sources_get_ffi<'a, F>(
51    sources: *mut SourcesOpaque,
52    source_ptr: *const u8,
53    source_len: u32,
54    path_ptr: *mut OsChar,
55    path_len: *mut u32,
56) -> i32
57where
58    F: FnMut(&str) -> Option<&'a Path> + 'a,
59{
60    let source =
61        match std::str::from_utf8(std::slice::from_raw_parts(source_ptr, source_len as usize)) {
62            Ok(source) => source,
63            Err(_) => return NOT_UTF8,
64        };
65
66    let f = sources as *mut F;
67    let f = &mut *f;
68
69    match f(source) {
70        None => return NOT_FOUND,
71        Some(path) => {
72            let os_str = path.as_os_str();
73
74            #[cfg(any(unix, target_os = "wasi"))]
75            let path: &[u8] = os_str.as_bytes();
76
77            #[cfg(windows)]
78            let os_str_wide = os_str.encode_wide().collect::<Vec<u16>>();
79
80            #[cfg(windows)]
81            let path: &[u16] = &*os_str_wide;
82
83            if *path_len < path.len() as u32 {
84                *path_len = path.len() as u32;
85                return BUFFER_IS_TOO_SMALL;
86            }
87
88            std::ptr::copy_nonoverlapping(path.as_ptr(), path_ptr, path.len() as u32 as usize);
89            *path_len = path.len() as u32;
90
91            return SUCCESS;
92        }
93    }
94}
95
96pub struct SourcesFFI {
97    pub opaque: *mut SourcesOpaque,
98    pub get: SourcesGetFn,
99}
100
101impl SourcesFFI {
102    pub fn new<'a, F>(f: &mut F) -> Self
103    where
104        F: FnMut(&str) -> Option<&'a Path> + 'a,
105    {
106        SourcesFFI {
107            opaque: f as *const F as _,
108            get: sources_get_ffi::<F>,
109        }
110    }
111}
112
113impl Sources for SourcesFFI {
114    fn get(&self, source: &str) -> Result<Option<PathBuf>, String> {
115        let mut path_buf = vec![0; PATH_BUF_LEN_START];
116        let mut path_len = path_buf.len() as u32;
117
118        loop {
119            let result = unsafe {
120                (self.get)(
121                    self.opaque,
122                    source.as_ptr(),
123                    source.len() as u32,
124                    path_buf.as_mut_ptr(),
125                    &mut path_len,
126                )
127            };
128
129            if result == BUFFER_IS_TOO_SMALL {
130                if path_len > ANY_BUF_LEN_LIMIT as u32 {
131                    return Err(format!(
132                        "Source path does not fit into limit '{}', '{}' required",
133                        ANY_BUF_LEN_LIMIT, path_len
134                    ));
135                }
136
137                path_buf.resize(path_len as usize, 0);
138                continue;
139            }
140
141            return match result {
142                SUCCESS => {
143                    #[cfg(any(unix, target_os = "wasi"))]
144                    let path = OsString::from_vec(path_buf).into();
145
146                    #[cfg(windows)]
147                    let path = OsString::from_wide(&path_buf).into();
148
149                    Ok(Some(path))
150                }
151                NOT_FOUND => return Ok(None),
152                NOT_UTF8 => Err(format!("Source is not UTF8 while stored in `str`")),
153                _ => Err(format!(
154                    "Unexpected return code from `Sources::get` FFI: {}",
155                    result
156                )),
157            };
158        }
159    }
160}