1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#[cfg(unix)]
use std::os::unix::ffi::{OsStrExt, OsStringExt};

#[cfg(target_os = "wasi")]
use std::os::wasi::ffi::{OsStrExt, OsStringExt};

#[cfg(windows)]
use std::os::windows::ffi::{OsStrExt, OsStringExt};
use std::{
    ffi::OsString,
    path::{Path, PathBuf},
};

use crate::{
    OsChar, ANY_BUF_LEN_LIMIT, BUFFER_IS_TOO_SMALL, NOT_FOUND, NOT_UTF8, PATH_BUF_LEN_START,
    SUCCESS,
};

/// Lookup interface that resolves a source name to a filesystem path.
pub trait Sources {
    /// Resolves `source` to a path.
    ///
    /// * `Ok(Some(path))` - the source was resolved to `path`.
    /// * `Ok(None)` - the lookup completed but produced no path.
    /// * `Err(message)` - the lookup itself failed; `message` describes why.
    fn get(&self, source: &str) -> Result<Option<PathBuf>, String>;

    /// Same as [`Sources::get`], but additionally records unresolved names.
    ///
    /// When `get` yields `Ok(None)`, an owned copy of `source` is pushed onto
    /// `missing` before `Ok(None)` is returned. Resolved paths and errors are
    /// passed through unchanged and leave `missing` untouched.
    fn get_or_append(
        &self,
        source: &str,
        missing: &mut Vec<String>,
    ) -> Result<Option<PathBuf>, String> {
        // Errors short-circuit here, before `missing` can be modified.
        let resolved = self.get(source)?;

        if resolved.is_none() {
            missing.push(source.to_owned());
        }

        Ok(resolved)
    }
}

/// Opaque pointee type used to erase the concrete lookup closure type at the
/// FFI boundary.
///
/// Values of this type are never constructed in this file; it only appears
/// behind `*const SourcesOpaque`. `SourcesFFI::new` casts a `*const F` to this
/// pointer type and `sources_get_ffi::<F>` casts it back before calling the
/// closure.
#[repr(transparent)]
pub struct SourcesOpaque(u8);

/// C-ABI signature of the lookup callback stored in `SourcesFFI::get`.
///
/// The convention below is the one implemented by `sources_get_ffi` and relied
/// upon by `SourcesFFI::get` in this file:
///
/// * `sources` - type-erased pointer to the lookup state.
/// * `source_ptr` / `source_len` - pointer to and byte length of the source
///   name, which is validated as UTF-8 by the callee.
/// * `path_ptr` - output buffer for the resolved path, in `OsChar` units.
/// * `path_len` - in/out: on entry the capacity of `path_ptr` in units; on
///   `SUCCESS` the number of units written; on `BUFFER_IS_TOO_SMALL` the
///   number of units required.
///
/// Returns one of `SUCCESS`, `NOT_FOUND`, `NOT_UTF8` or `BUFFER_IS_TOO_SMALL`;
/// `SourcesFFI::get` reports any other value as an unexpected return code.
pub type SourcesGetFn = unsafe extern "C" fn(
    sources: *const SourcesOpaque,
    source_ptr: *const u8,
    source_len: u32,
    path_ptr: *mut OsChar,
    path_len: *mut u32,
) -> i32;

/// C-ABI trampoline that forwards a lookup to a Rust closure of type `F`.
///
/// The source name is decoded from `source_ptr`/`source_len`, passed to the
/// closure behind `sources`, and the resulting path is copied into the
/// caller-provided buffer.
///
/// Return codes:
/// * `NOT_UTF8` - the source name bytes are not valid UTF-8.
/// * `NOT_FOUND` - the closure returned `None`.
/// * `BUFFER_IS_TOO_SMALL` - `*path_len` is smaller than the path length;
///   `*path_len` is overwritten with the required length and nothing is copied.
/// * `SUCCESS` - the path was copied to `path_ptr` and `*path_len` holds the
///   number of units written.
///
/// # Safety
///
/// * `sources` must point to a live value of type `F`.
/// * `source_ptr` must be valid for reads of `source_len` bytes.
/// * `path_len` must be valid for reads and writes.
/// * `path_ptr` must be valid for writes of `*path_len` `OsChar` units.
unsafe extern "C" fn sources_get_ffi<'a, F>(
    sources: *const SourcesOpaque,
    source_ptr: *const u8,
    source_len: u32,
    path_ptr: *mut OsChar,
    path_len: *mut u32,
) -> i32
where
    F: Fn(&str) -> Option<&'a Path> + 'a,
{
    let name_bytes = std::slice::from_raw_parts(source_ptr, source_len as usize);
    let name = match std::str::from_utf8(name_bytes) {
        Ok(name) => name,
        Err(_) => return NOT_UTF8,
    };

    // Recover the concrete closure only after the name has been validated.
    let lookup = &*(sources as *const F);

    let found = match lookup(name) {
        Some(found) => found,
        None => return NOT_FOUND,
    };
    let os_str = found.as_os_str();

    // Platform-native code units of the path: bytes on unix/wasi, UTF-16 on windows.
    #[cfg(any(unix, target_os = "wasi"))]
    let units: &[u8] = os_str.as_bytes();

    #[cfg(windows)]
    let wide_units: Vec<u16> = os_str.encode_wide().collect();

    #[cfg(windows)]
    let units: &[u16] = wide_units.as_slice();

    let required = units.len() as u32;

    if *path_len < required {
        // Tell the caller how much room is needed so it can retry.
        *path_len = required;
        return BUFFER_IS_TOO_SMALL;
    }

    std::ptr::copy_nonoverlapping(units.as_ptr(), path_ptr, required as usize);
    *path_len = required;

    SUCCESS
}

/// FFI-friendly handle to a source lookup: a type-erased state pointer paired
/// with the C-ABI function that knows how to use it.
///
/// Implements [`Sources`] by calling `get` with `opaque`.
pub struct SourcesFFI {
    /// Type-erased pointer handed back to `get` as its first argument.
    /// It is a raw pointer, so nothing here keeps the pointee alive.
    pub opaque: *const SourcesOpaque,
    /// Callback that performs the lookup; see [`SourcesGetFn`] for the
    /// calling convention.
    pub get: SourcesGetFn,
}

impl SourcesFFI {
    pub fn new<'a, F>(f: &F) -> Self
    where
        F: Fn(&str) -> Option<&'a Path> + 'a,
    {
        SourcesFFI {
            opaque: f as *const F as _,
            get: sources_get_ffi::<F>,
        }
    }
}

impl Sources for SourcesFFI {
    /// Resolves `source` by calling the stored FFI callback.
    ///
    /// A scratch buffer of `PATH_BUF_LEN_START` units is tried first. When the
    /// callback reports `BUFFER_IS_TOO_SMALL`, the buffer is grown to the
    /// length the callback asked for and the call is repeated.
    ///
    /// # Errors
    ///
    /// Returns `Err` when:
    /// * `source` is longer than `u32::MAX` bytes and cannot be described to
    ///   the callback,
    /// * the callback requires a buffer larger than `ANY_BUF_LEN_LIMIT`,
    /// * the callback reports `NOT_UTF8`,
    /// * the callback returns an unknown code.
    ///
    /// `NOT_FOUND` is not an error and maps to `Ok(None)`.
    fn get(&self, source: &str) -> Result<Option<PathBuf>, String> {
        // The callback takes the length as `u32`; a silent wrap-around would
        // make it look up a truncated name.
        if source.len() > u32::MAX as usize {
            return Err(format!(
                "Source name is too long to pass through FFI: {} bytes",
                source.len()
            ));
        }
        let source_len = source.len() as u32;

        let mut path_buf = vec![0; PATH_BUF_LEN_START];
        let mut path_len = path_buf.len() as u32;

        loop {
            // In: `path_len` is the buffer capacity.
            // Out: units written (SUCCESS) or units required (BUFFER_IS_TOO_SMALL).
            let result = unsafe {
                (self.get)(
                    self.opaque,
                    source.as_ptr(),
                    source_len,
                    path_buf.as_mut_ptr(),
                    &mut path_len,
                )
            };

            if result == BUFFER_IS_TOO_SMALL {
                if path_len > ANY_BUF_LEN_LIMIT as u32 {
                    return Err(format!(
                        "Source path does not fit into limit '{}', '{}' required",
                        ANY_BUF_LEN_LIMIT, path_len
                    ));
                }

                // After the resize the buffer length equals `path_len`, so it
                // is already the right capacity value for the retry.
                path_buf.resize(path_len as usize, 0);
                continue;
            }

            return match result {
                SUCCESS => {
                    // Keep only the units the callback actually wrote; the
                    // rest of the buffer is zero padding that must not end up
                    // in the returned path.
                    path_buf.truncate(path_len as usize);

                    #[cfg(any(unix, target_os = "wasi"))]
                    let path = PathBuf::from(OsString::from_vec(path_buf));

                    #[cfg(windows)]
                    let path = PathBuf::from(OsString::from_wide(&path_buf));

                    Ok(Some(path))
                }
                NOT_FOUND => Ok(None),
                NOT_UTF8 => Err("Source is not UTF8 while stored in `str`".to_owned()),
                _ => Err(format!(
                    "Unexpected return code from `Sources::get` FFI: {}",
                    result
                )),
            };
        }
    }
}