treasury_import/
sources.rs1#[cfg(unix)]
2use std::os::unix::ffi::{OsStrExt, OsStringExt};
3
4#[cfg(target_os = "wasi")]
5use std::os::wasi::ffi::{OsStrExt, OsStringExt};
6
7#[cfg(windows)]
8use std::os::windows::ffi::{OsStrExt, OsStringExt};
9use std::{
10 ffi::OsString,
11 path::{Path, PathBuf},
12};
13
14use crate::{
15 OsChar, ANY_BUF_LEN_LIMIT, BUFFER_IS_TOO_SMALL, NOT_FOUND, NOT_UTF8, PATH_BUF_LEN_START,
16 SUCCESS,
17};
18
/// Resolves named sources to filesystem paths.
pub trait Sources {
    /// Looks up the path registered for `source`.
    ///
    /// Returns `Ok(Some(path))` when the source resolves, `Ok(None)` when it
    /// does not, and `Err(message)` when the lookup itself fails.
    fn get(&self, source: &str) -> Result<Option<PathBuf>, String>;

    /// Same as [`Sources::get`], but additionally records unresolved sources.
    ///
    /// When the lookup yields `Ok(None)`, an owned copy of `source` is pushed
    /// onto `missing`. Successful lookups and errors leave `missing` untouched.
    fn get_or_append(
        &self,
        source: &str,
        missing: &mut Vec<String>,
    ) -> Result<Option<PathBuf>, String> {
        let resolved = self.get(source)?;

        if resolved.is_none() {
            missing.push(source.to_owned());
        }

        Ok(resolved)
    }
}
38
/// Opaque pointee type for the state handed across the `Sources` FFI boundary.
///
/// Values of this type are never constructed in this file; it only gives
/// `*mut SourcesOpaque` pointers a distinct type. The pointer is cast back to
/// the concrete callback type by the function that receives it.
#[repr(transparent)]
pub struct SourcesOpaque(u8);
41
42pub type SourcesGetFn = unsafe extern "C" fn(
43 sources: *mut SourcesOpaque,
44 source_ptr: *const u8,
45 source_len: u32,
46 path_ptr: *mut OsChar,
47 path_len: *mut u32,
48) -> i32;
49
50unsafe extern "C" fn sources_get_ffi<'a, F>(
51 sources: *mut SourcesOpaque,
52 source_ptr: *const u8,
53 source_len: u32,
54 path_ptr: *mut OsChar,
55 path_len: *mut u32,
56) -> i32
57where
58 F: FnMut(&str) -> Option<&'a Path> + 'a,
59{
60 let source =
61 match std::str::from_utf8(std::slice::from_raw_parts(source_ptr, source_len as usize)) {
62 Ok(source) => source,
63 Err(_) => return NOT_UTF8,
64 };
65
66 let f = sources as *mut F;
67 let f = &mut *f;
68
69 match f(source) {
70 None => return NOT_FOUND,
71 Some(path) => {
72 let os_str = path.as_os_str();
73
74 #[cfg(any(unix, target_os = "wasi"))]
75 let path: &[u8] = os_str.as_bytes();
76
77 #[cfg(windows)]
78 let os_str_wide = os_str.encode_wide().collect::<Vec<u16>>();
79
80 #[cfg(windows)]
81 let path: &[u16] = &*os_str_wide;
82
83 if *path_len < path.len() as u32 {
84 *path_len = path.len() as u32;
85 return BUFFER_IS_TOO_SMALL;
86 }
87
88 std::ptr::copy_nonoverlapping(path.as_ptr(), path_ptr, path.len() as u32 as usize);
89 *path_len = path.len() as u32;
90
91 return SUCCESS;
92 }
93 }
94}
95
96pub struct SourcesFFI {
97 pub opaque: *mut SourcesOpaque,
98 pub get: SourcesGetFn,
99}
100
101impl SourcesFFI {
102 pub fn new<'a, F>(f: &mut F) -> Self
103 where
104 F: FnMut(&str) -> Option<&'a Path> + 'a,
105 {
106 SourcesFFI {
107 opaque: f as *const F as _,
108 get: sources_get_ffi::<F>,
109 }
110 }
111}
112
113impl Sources for SourcesFFI {
114 fn get(&self, source: &str) -> Result<Option<PathBuf>, String> {
115 let mut path_buf = vec![0; PATH_BUF_LEN_START];
116 let mut path_len = path_buf.len() as u32;
117
118 loop {
119 let result = unsafe {
120 (self.get)(
121 self.opaque,
122 source.as_ptr(),
123 source.len() as u32,
124 path_buf.as_mut_ptr(),
125 &mut path_len,
126 )
127 };
128
129 if result == BUFFER_IS_TOO_SMALL {
130 if path_len > ANY_BUF_LEN_LIMIT as u32 {
131 return Err(format!(
132 "Source path does not fit into limit '{}', '{}' required",
133 ANY_BUF_LEN_LIMIT, path_len
134 ));
135 }
136
137 path_buf.resize(path_len as usize, 0);
138 continue;
139 }
140
141 return match result {
142 SUCCESS => {
143 #[cfg(any(unix, target_os = "wasi"))]
144 let path = OsString::from_vec(path_buf).into();
145
146 #[cfg(windows)]
147 let path = OsString::from_wide(&path_buf).into();
148
149 Ok(Some(path))
150 }
151 NOT_FOUND => return Ok(None),
152 NOT_UTF8 => Err(format!("Source is not UTF8 while stored in `str`")),
153 _ => Err(format!(
154 "Unexpected return code from `Sources::get` FFI: {}",
155 result
156 )),
157 };
158 }
159 }
160}