treasury_import/
lib.rs

//! Contains everything required to create a treasury importers library.
2//!
3//!
4//! # Usage
5//!
6//! ```
7//! struct FooImporter;
8//!
9//! impl treasury_import::Importer for FooImporter {
10//!     fn import(
11//!         &self,
12//!         source: &std::path::Path,
13//!         output: &std::path::Path,
14//!         _sources: &impl treasury_import::Sources,
15//!         _dependencies: &impl treasury_import::Dependencies,
16//!     ) -> Result<(), treasury_import::ImportError> {
//!         match std::fs::copy(source, output) {
//!             Ok(_) => Ok(()),
//!             Err(err) => Err(treasury_import::ImportError::Other {
//!                 reason: err.to_string(),
//!             }),
//!         }
21//!     }
22//! }
23//!
24//!
25//! // Define all required exports.
26//! treasury_import::make_treasury_importers_library! {
27//!     // [extensions list]  <name> : <source-format> -> <target-format> = <expr>;
28//!     // <expr> must have type &'static I where I: Importer
29//!     // Use `Box::leak(importer)` if importer instance cannot be constructed in constant expression.
30//!     [foo] foo : foo -> foo = &FooImporter;
31//! }
32//! ```
33
34use std::{borrow::Cow, mem::size_of, path::Path, str::Utf8Error};
35
36#[cfg(unix)]
37use std::{ffi::OsStr, os::unix::ffi::OsStrExt};
38
39#[cfg(target_os = "wasi")]
40use std::{ffi::OsStr, os::wasi::ffi::OsStrExt};
41
42#[cfg(windows)]
43use std::{
44    ffi::OsString,
45    os::windows::ffi::{OsStrExt, OsStringExt},
46};
47
48use dependencies::DependenciesFFI;
49use sources::SourcesFFI;
50
51mod dependencies;
52mod sources;
53
54pub use dependencies::Dependencies;
55pub use sources::Sources;
56pub use treasury_id::AssetId;
57
58pub const MAGIC: u32 = u32::from_le_bytes(*b"TRES");
59
60pub type MagicType = u32;
61pub const MAGIC_NAME: &'static str = "TREASURY_DYLIB_MAGIC";
62
63pub type VersionFnType = unsafe extern "C" fn() -> u32;
64pub const VERSION_FN_NAME: &'static str = "treasury_importer_ffi_version";
65
66pub type ExportImportersFnType = unsafe extern "C" fn(buffer: *mut ImporterFFI, count: u32) -> u32;
67pub const EXPORT_IMPORTERS_FN_NAME: &'static str = "treasury_export_importers";
68
69pub fn version() -> u32 {
70    let major = env!("CARGO_PKG_VERSION_MAJOR");
71    let version = major.parse().unwrap();
72    assert_ne!(
73        version,
74        u32::MAX,
75        "Major version hits u32::MAX. Oh no. Upgrade to u64",
76    );
77    version
78}
79
/// Initial size of the buffer the host hands to an importer for its encoded result.
const RESULT_BUF_LEN_START: usize = 1 << 10;
/// Initial length for path buffers.
/// NOTE(review): not referenced in this file; presumably used by the `sources` /
/// `dependencies` submodules.
const PATH_BUF_LEN_START: usize = 1 << 10;
/// Largest buffer size the host agrees to allocate when an FFI call reports
/// `BUFFER_IS_TOO_SMALL`.
const ANY_BUF_LEN_LIMIT: usize = 64 * 1024;

// Status codes exchanged across the FFI boundary: positive values ask for more input,
// zero is success, negative values are failures.
const REQUIRE_SOURCES: i32 = 2;
const REQUIRE_DEPENDENCIES: i32 = 1;
const SUCCESS: i32 = 0;
// NOTE(review): `NOT_FOUND` and `NOT_UTF8` are not used in this file; presumably the
// submodules' callbacks return them.
const NOT_FOUND: i32 = -1;
const NOT_UTF8: i32 = -2;
const BUFFER_IS_TOO_SMALL: i32 = -3;
const OTHER_ERROR: i32 = -6;

/// Unit of a platform OS string as passed across the FFI: bytes on Unix/WASI,
/// UTF-16 code units on Windows.
#[cfg(any(unix, target_os = "wasi"))]
type OsChar = u8;
#[cfg(windows)]
type OsChar = u16;
97
/// A dependency requested by an importer: an asset built from `source` into the
/// `target` format.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dependency {
    pub source: String,
    pub target: String,
}

/// Error returned by the `Importer::import` method.
///
/// `RequireSources` and `RequireDependencies` are not failures of the importer itself.
/// They report what else the import needs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportError {
    /// Importer requires data from other sources.
    RequireSources {
        /// URLs relative to source path.
        sources: Vec<String>,
    },

    /// Importer requires following dependencies.
    RequireDependencies { dependencies: Vec<Dependency> },

    /// Importer failed to import the asset.
    Other {
        /// Failure reason.
        reason: String,
    },
}

impl std::fmt::Display for ImportError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ImportError::RequireSources { sources } => {
                write!(f, "importer requires sources: {}", sources.join(", "))
            }
            ImportError::RequireDependencies { dependencies } => {
                f.write_str("importer requires dependencies: ")?;
                for (i, dep) in dependencies.iter().enumerate() {
                    if i > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{} -> {}", dep.source, dep.target)?;
                }
                Ok(())
            }
            ImportError::Other { reason } => f.write_str(reason),
        }
    }
}

impl std::error::Error for ImportError {}

/// Returns `Ok(())` if `missing` is empty, otherwise
/// `Err(ImportError::RequireDependencies)` carrying all of `missing`.
pub fn ensure_dependencies(missing: Vec<Dependency>) -> Result<(), ImportError> {
    if missing.is_empty() {
        Ok(())
    } else {
        Err(ImportError::RequireDependencies {
            dependencies: missing,
        })
    }
}

/// Returns `Ok(())` if `missing` is empty, otherwise
/// `Err(ImportError::RequireSources)` carrying all of `missing`.
pub fn ensure_sources(missing: Vec<String>) -> Result<(), ImportError> {
    if missing.is_empty() {
        Ok(())
    } else {
        Err(ImportError::RequireSources { sources: missing })
    }
}
139
/// Trait for an importer.
///
/// The `Send + Sync` bound exists because `ImporterFFI::new` stores a `&'static`
/// reference to the importer in a descriptor that is declared `Send` and `Sync`.
pub trait Importer: Send + Sync {
    /// Reads data from `source` path and writes result at `output` path.
    ///
    /// `sources` and `dependencies` give access to other source files and to dependency
    /// assets; see [`Sources`] and [`Dependencies`] for their exact contracts.
    ///
    /// # Errors
    ///
    /// * [`ImportError::RequireSources`]: the import needs additional sources, given as
    ///   URLs relative to `source` (presumably the caller provides them and retries;
    ///   not visible here).
    /// * [`ImportError::RequireDependencies`]: the import needs the listed
    ///   dependencies (see [`ensure_dependencies`]).
    /// * [`ImportError::Other`]: the import failed, with a human-readable reason.
    fn import(
        &self,
        source: &Path,
        output: &Path,
        sources: &impl Sources,
        dependencies: &impl Dependencies,
    ) -> Result<(), ImportError>;
}
151
/// Type-erased stand-in for the concrete importer type referenced by `ImporterFFI`.
///
/// Only used behind a pointer: `ImporterFFI::new` casts `*const I` to it, and
/// `importer_import_ffi::<I>` casts it back.
#[repr(transparent)]
struct ImporterOpaque(u8);

/// Signature of the per-importer FFI trampoline (`importer_import_ffi::<I>`).
///
/// Paths are passed as pointer + length in platform OS characters (`OsChar`).
/// On entry `*result_len` is the capacity of `result_ptr`. On return it holds the size
/// of the encoded result, or the required size when `BUFFER_IS_TOO_SMALL` is returned.
/// The return value is one of the status codes defined in this file.
type ImporterImportFn = unsafe extern "C" fn(
    importer: *const ImporterOpaque,
    source_ptr: *const OsChar,
    source_len: u32,
    output_ptr: *const OsChar,
    output_len: u32,
    sources: *mut sources::SourcesOpaque,
    sources_get: sources::SourcesGetFn,
    dependencies: *mut dependencies::DependenciesOpaque,
    dependencies_get: dependencies::DependenciesGetFn,
    result_ptr: *mut u8,
    result_len: *mut u32,
) -> i32;
168
169unsafe extern "C" fn importer_import_ffi<I>(
170    importer: *const ImporterOpaque,
171    source_ptr: *const OsChar,
172    source_len: u32,
173    output_ptr: *const OsChar,
174    output_len: u32,
175    sources: *mut sources::SourcesOpaque,
176    sources_get: sources::SourcesGetFn,
177    dependencies: *mut dependencies::DependenciesOpaque,
178    dependencies_get: dependencies::DependenciesGetFn,
179    result_ptr: *mut u8,
180    result_len: *mut u32,
181) -> i32
182where
183    I: Importer,
184{
185    let source = std::slice::from_raw_parts(source_ptr, source_len as usize);
186    let output = std::slice::from_raw_parts(output_ptr, output_len as usize);
187
188    #[cfg(any(unix, target_os = "wasi"))]
189    let source = OsStr::from_bytes(source);
190    #[cfg(any(unix, target_os = "wasi"))]
191    let output = OsStr::from_bytes(output);
192
193    #[cfg(windows)]
194    let source = OsString::from_wide(source);
195    #[cfg(windows)]
196    let output = OsString::from_wide(output);
197
198    let sources = SourcesFFI {
199        opaque: sources,
200        get: sources_get,
201    };
202
203    let dependencies = DependenciesFFI {
204        opaque: dependencies,
205        get: dependencies_get,
206    };
207
208    let importer = &*(importer as *const I);
209    let result = importer.import(source.as_ref(), output.as_ref(), &sources, &dependencies);
210
211    match result {
212        Ok(()) => SUCCESS,
213        Err(ImportError::RequireSources { sources }) => {
214            let len_required = sources
215                .iter()
216                .fold(0, |acc, p| acc + p.len() + size_of::<u32>())
217                + size_of::<u32>();
218
219            assert!(u32::try_from(len_required).is_ok());
220
221            if *result_len < len_required as u32 {
222                *result_len = len_required as u32;
223                return BUFFER_IS_TOO_SMALL;
224            }
225
226            std::ptr::copy_nonoverlapping(
227                (sources.len() as u32).to_le_bytes().as_ptr(),
228                result_ptr,
229                size_of::<u32>(),
230            );
231
232            let mut offset = size_of::<u32>();
233
234            for url in &sources {
235                let len = url.len();
236
237                std::ptr::copy_nonoverlapping(
238                    (len as u32).to_le_bytes().as_ptr(),
239                    result_ptr.add(offset),
240                    size_of::<u32>(),
241                );
242                offset += size_of::<u32>();
243
244                std::ptr::copy_nonoverlapping(
245                    url.as_ptr(),
246                    result_ptr.add(offset),
247                    len as u32 as usize,
248                );
249                offset += len;
250            }
251
252            debug_assert_eq!(len_required, offset);
253
254            *result_len = len_required as u32;
255            REQUIRE_SOURCES
256        }
257        Err(ImportError::RequireDependencies { dependencies }) => {
258            let len_required = dependencies.iter().fold(0, |acc, dep| {
259                acc + dep.source.len() + dep.target.len() + size_of::<u32>() * 2
260            }) + size_of::<u32>();
261
262            assert!(u32::try_from(len_required).is_ok());
263
264            if *result_len < len_required as u32 {
265                *result_len = len_required as u32;
266                return BUFFER_IS_TOO_SMALL;
267            }
268
269            std::ptr::copy_nonoverlapping(
270                (dependencies.len() as u32).to_le_bytes().as_ptr(),
271                result_ptr,
272                size_of::<u32>(),
273            );
274
275            let mut offset = size_of::<u32>();
276
277            for dep in &dependencies {
278                for s in [&dep.source, &dep.target] {
279                    let len = s.len();
280
281                    std::ptr::copy_nonoverlapping(
282                        (len as u32).to_le_bytes().as_ptr(),
283                        result_ptr.add(offset),
284                        size_of::<u32>(),
285                    );
286                    offset += size_of::<u32>();
287
288                    std::ptr::copy_nonoverlapping(
289                        s.as_ptr(),
290                        result_ptr.add(offset),
291                        len as u32 as usize,
292                    );
293                    offset += len;
294                }
295            }
296
297            debug_assert_eq!(len_required, offset);
298
299            *result_len = len_required as u32;
300            REQUIRE_DEPENDENCIES
301        }
302        Err(ImportError::Other { reason }) => {
303            if *result_len < reason.len() as u32 {
304                *result_len = reason.len() as u32;
305                return BUFFER_IS_TOO_SMALL;
306            }
307
308            let error_buf = std::slice::from_raw_parts_mut(result_ptr, reason.len());
309            error_buf.copy_from_slice(reason.as_bytes());
310            *result_len = reason.len() as u32;
311            OTHER_ERROR
312        }
313    }
314}
315
/// Width in bytes of one extension slot in [`ImporterFFI`].
const MAX_EXTENSION_LEN: usize = 16;
/// Number of extension slots in [`ImporterFFI`].
const MAX_EXTENSION_COUNT: usize = 256;
/// Width in bytes of the `name`, `format` and `target` buffers of [`ImporterFFI`].
const MAX_FFI_NAME_LEN: usize = 256;

/// C-compatible descriptor of one importer exported from an importers library.
///
/// Built by [`ImporterFFI::new`] inside the library (through
/// `make_treasury_importers_library!`) and read by the host through its accessors and
/// [`ImporterFFI::import`]. The struct is `#[repr(C)]` and crosses the dylib boundary,
/// so reordering or resizing fields changes the ABI seen by already-built libraries.
#[repr(C)]
pub struct ImporterFFI {
    // Type-erased `&'static I`, handed back to `import`.
    importer: *const ImporterOpaque,
    // `importer_import_ffi::<I>` for the same `I`.
    import: ImporterImportFn,
    // UTF-8 name padded with NUL bytes; no terminator when the full buffer is used.
    name: [u8; MAX_FFI_NAME_LEN],
    // Source format, encoded like `name`.
    format: [u8; MAX_FFI_NAME_LEN],
    // Target format, encoded like `name`.
    target: [u8; MAX_FFI_NAME_LEN],
    // Extensions, each NUL-padded; the list ends at the first slot whose first byte is 0.
    extensions: [[u8; MAX_EXTENSION_LEN]; MAX_EXTENSION_COUNT],
}

/// Exporting non thread-safe importers breaks the contract of the FFI.
/// The potential unsoundness is covered by `load_dylib_importers` unsafety.
/// There is no way to guarantee that dynamic library will uphold the contract,
/// making `load_dylib_importers` inevitably unsound.
unsafe impl Send for ImporterFFI {}
unsafe impl Sync for ImporterFFI {}
336
337impl ImporterFFI {
338    pub fn new<'a, I>(
339        importer: &'static I,
340        name: &str,
341        format: &str,
342        target: &str,
343        extensions: &[&'a str],
344    ) -> Self
345    where
346        I: Importer,
347    {
348        let importer = importer as *const I as *const ImporterOpaque;
349
350        assert!(
351            name.len() <= MAX_FFI_NAME_LEN,
352            "Importer name should fit into {} bytes",
353            MAX_FFI_NAME_LEN
354        );
355        assert!(
356            format.len() <= MAX_FFI_NAME_LEN,
357            "Importer format should fit into {} bytes",
358            MAX_FFI_NAME_LEN
359        );
360        assert!(
361            target.len() <= MAX_FFI_NAME_LEN,
362            "Importer target should fit into {} bytes",
363            MAX_FFI_NAME_LEN
364        );
365        assert!(
366            extensions.len() < MAX_EXTENSION_COUNT,
367            "Importer should support no more than {} extensions",
368            MAX_EXTENSION_COUNT,
369        );
370        assert!(
371            extensions.iter().all(|e| e.len() < MAX_EXTENSION_LEN),
372            "Importer extensions should fit into {} bytes",
373            MAX_EXTENSION_LEN,
374        );
375
376        assert!(!name.is_empty(), "Importer name should not be empty");
377        assert!(!format.is_empty(), "Importer format should not be empty");
378        assert!(!target.is_empty(), "Importer target should not be empty");
379        assert!(
380            extensions.iter().all(|e| !e.is_empty()),
381            "Importer extensions should not be empty"
382        );
383
384        assert!(
385            !name.contains('\0'),
386            "Importer name should not contain '\\0' byte"
387        );
388        assert!(
389            !format.contains('\0'),
390            "Importer format should not contain '\\0' byte"
391        );
392        assert!(
393            !target.contains('\0'),
394            "Importer target should not contain '\\0' byte"
395        );
396        assert!(
397            extensions.iter().all(|e| !e.contains('\0')),
398            "Importer extensions should not contain '\\0' byte"
399        );
400
401        let mut name_buf = [0; MAX_FFI_NAME_LEN];
402        name_buf[..name.len()].copy_from_slice(name.as_bytes());
403
404        let mut format_buf = [0; MAX_FFI_NAME_LEN];
405        format_buf[..format.len()].copy_from_slice(format.as_bytes());
406
407        let mut target_buf = [0; MAX_FFI_NAME_LEN];
408        target_buf[..target.len()].copy_from_slice(target.as_bytes());
409
410        let mut extensions_buf = [[0; MAX_EXTENSION_LEN]; MAX_EXTENSION_COUNT];
411
412        for (i, &extension) in extensions.iter().enumerate() {
413            extensions_buf[i][..extension.len()].copy_from_slice(extension.as_bytes());
414        }
415
416        ImporterFFI {
417            importer,
418            import: importer_import_ffi::<I>,
419            name: name_buf,
420            format: format_buf,
421            target: target_buf,
422            extensions: extensions_buf,
423        }
424    }
425
426    pub fn name(&self) -> Result<&str, Utf8Error> {
427        match self.name.iter().position(|b| *b == 0) {
428            None => std::str::from_utf8(&self.name),
429            Some(i) => std::str::from_utf8(&self.name[..i]),
430        }
431    }
432
433    pub fn name_lossy(&self) -> Cow<'_, str> {
434        match self.name.iter().position(|b| *b == 0) {
435            None => String::from_utf8_lossy(&self.name),
436            Some(i) => String::from_utf8_lossy(&self.name[..i]),
437        }
438    }
439
440    pub fn format(&self) -> Result<&str, Utf8Error> {
441        match self.format.iter().position(|b| *b == 0) {
442            None => std::str::from_utf8(&self.format),
443            Some(i) => std::str::from_utf8(&self.format[..i]),
444        }
445    }
446
447    pub fn target(&self) -> Result<&str, Utf8Error> {
448        match self.target.iter().position(|b| *b == 0) {
449            None => std::str::from_utf8(&self.target),
450            Some(i) => std::str::from_utf8(&self.target[..i]),
451        }
452    }
453
454    pub fn extensions(&self) -> impl Iterator<Item = Result<&str, Utf8Error>> {
455        let iter = self
456            .extensions
457            .iter()
458            .take_while(|extension| extension[0] != 0);
459
460        iter.map(|extension| match extension.iter().position(|b| *b == 0) {
461            None => std::str::from_utf8(extension),
462            Some(i) => std::str::from_utf8(&extension[..i]),
463        })
464    }
465
466    pub fn import<'a, S, D>(
467        &self,
468        source: &Path,
469        output: &Path,
470        sources: &mut S,
471        dependencies: &mut D,
472    ) -> Result<(), ImportError>
473    where
474        S: FnMut(&str) -> Option<&'a Path> + 'a,
475        D: FnMut(&str, &str) -> Option<AssetId>,
476    {
477        let os_str = source.as_os_str();
478
479        #[cfg(any(unix, target_os = "wasi"))]
480        let source: &[u8] = os_str.as_bytes();
481
482        #[cfg(windows)]
483        let os_str_wide = os_str.encode_wide().collect::<Vec<u16>>();
484
485        #[cfg(windows)]
486        let source: &[u16] = &*os_str_wide;
487
488        let os_str = output.as_os_str();
489
490        #[cfg(any(unix, target_os = "wasi"))]
491        let output: &[u8] = os_str.as_bytes();
492
493        #[cfg(windows)]
494        let os_str_wide = os_str.encode_wide().collect::<Vec<u16>>();
495
496        #[cfg(windows)]
497        let output: &[u16] = &*os_str_wide;
498
499        let sources = SourcesFFI::new(sources);
500        let dependencies = DependenciesFFI::new(dependencies);
501
502        let mut result_buf = vec![0; RESULT_BUF_LEN_START];
503        let mut result_len = result_buf.len() as u32;
504
505        let result = loop {
506            let result = unsafe {
507                (self.import)(
508                    self.importer,
509                    source.as_ptr(),
510                    source.len() as u32,
511                    output.as_ptr(),
512                    output.len() as u32,
513                    sources.opaque,
514                    sources.get,
515                    dependencies.opaque,
516                    dependencies.get,
517                    result_buf.as_mut_ptr(),
518                    &mut result_len,
519                )
520            };
521
522            if result == BUFFER_IS_TOO_SMALL {
523                if result_len > ANY_BUF_LEN_LIMIT as u32 {
524                    return Err(ImportError::Other {
525                        reason: format!(
526                            "Result does not fit into limit '{}', '{}' required",
527                            ANY_BUF_LEN_LIMIT, result_len
528                        ),
529                    });
530                }
531
532                result_buf.resize(result_len as usize, 0);
533            }
534            break result;
535        };
536
537        match result {
538            SUCCESS => Ok(()),
539            REQUIRE_SOURCES => unsafe {
540                let mut u32buf = [0; size_of::<u32>()];
541                std::ptr::copy_nonoverlapping(
542                    result_buf[..size_of::<u32>()].as_ptr(),
543                    u32buf.as_mut_ptr(),
544                    size_of::<u32>(),
545                );
546                let count = u32::from_le_bytes(u32buf);
547
548                let mut offset = size_of::<u32>();
549
550                let mut sources = Vec::new();
551                for _ in 0..count {
552                    std::ptr::copy_nonoverlapping(
553                        result_buf[offset..][..size_of::<u32>()].as_ptr(),
554                        u32buf.as_mut_ptr(),
555                        size_of::<u32>(),
556                    );
557                    offset += size_of::<u32>();
558                    let len = u32::from_le_bytes(u32buf);
559                    let mut source = vec![0; len as usize];
560                    std::ptr::copy_nonoverlapping(
561                        result_buf[offset..][..len as usize].as_ptr(),
562                        source.as_mut_ptr(),
563                        len as usize,
564                    );
565                    offset += len as usize;
566                    match String::from_utf8(source) {
567                            Ok(source) => sources.push(source),
568                            Err(_) => return Err(ImportError::Other {
569                                reason: "`Importer::import` requires sources, but one of the sources is not UTF-8"
570                                    .to_owned(),
571                            }),
572                        }
573                }
574
575                Err(ImportError::RequireSources { sources })
576            },
577            REQUIRE_DEPENDENCIES => unsafe {
578                let mut u32buf = [0; size_of::<u32>()];
579                std::ptr::copy_nonoverlapping(
580                    result_buf[..size_of::<u32>()].as_ptr(),
581                    u32buf.as_mut_ptr(),
582                    size_of::<u32>(),
583                );
584                let count = u32::from_le_bytes(u32buf);
585                let mut offset = size_of::<u32>();
586
587                let mut dependencies = Vec::new();
588                for _ in 0..count {
589                    let mut decode_string = || {
590                        std::ptr::copy_nonoverlapping(
591                            result_buf[offset..][..size_of::<u32>()].as_ptr(),
592                            u32buf.as_mut_ptr(),
593                            size_of::<u32>(),
594                        );
595                        offset += size_of::<u32>();
596                        let len = u32::from_le_bytes(u32buf);
597
598                        let mut string = vec![0; len as usize];
599                        std::ptr::copy_nonoverlapping(
600                            result_buf[offset..][..len as usize].as_ptr(),
601                            string.as_mut_ptr(),
602                            len as usize,
603                        );
604                        offset += len as usize;
605
606                        match String::from_utf8(string) {
607                                Ok(string) => Ok(string),
608                                Err(_) => return Err(ImportError::Other { reason: "`Importer::import` requires dependencies, but one of the strings is not UTF-8".to_owned() }),
609                            }
610                    };
611
612                    let source = decode_string()?;
613                    let target = decode_string()?;
614
615                    dependencies.push(Dependency { source, target });
616                }
617
618                Err(ImportError::RequireDependencies { dependencies })
619            },
620            OTHER_ERROR => {
621                debug_assert!(result_len <= result_buf.len() as u32);
622
623                let error = &result_buf[..result_len as usize];
624                let error_lossy = String::from_utf8_lossy(error);
625
626                Err(ImportError::Other {
627                    reason: error_lossy.into_owned(),
628                })
629            }
630            _ => Err(ImportError::Other {
631                reason: format!(
632                    "Unexpected return code from `Importer::import` FFI: {}",
633                    result
634                ),
635            }),
636        }
637    }
638}
639
/// Defines the exports required for an importers library.
///
/// Accepts a repetition of the following pattern:
///
/// `[<extensions>] <importer name> : <source format> -> <target format> = <importer expression>;`
///
/// The bracketed extension list is optional and contains bare identifiers.
/// `<importer expression>` must have type `&'static I` for some `I: Importer`.
///
/// The macro emits three `#[no_mangle]` symbols:
///
/// * `TREASURY_DYLIB_MAGIC`: a static holding [`MAGIC`].
/// * `treasury_importer_ffi_version`: returns [`version()`].
/// * `treasury_export_importers(buffer, count)`: writes at most `count` [`ImporterFFI`]
///   descriptors into `buffer` and returns the total number of importers defined.
///   With `count = 0` nothing is written (and `buffer` is never touched), so the call
///   only reports how many slots are needed.
#[macro_export]
macro_rules! make_treasury_importers_library {
    ($(
        $([$( $ext:ident ),* $(,)?])? $($name:ident).+ : $($format:ident).+ -> $($target:ident).+ = $importer:expr;
    )*) => {
        #[no_mangle]
        pub static TREASURY_DYLIB_MAGIC: u32 = $crate::MAGIC;

        #[no_mangle]
        pub unsafe extern "C" fn treasury_importer_ffi_version() -> u32 {
            $crate::version()
        }

        #[no_mangle]
        pub unsafe extern "C" fn treasury_export_importers(buffer: *mut $crate::ImporterFFI, count: u32) -> u32 {
            let mut len: u32 = 0;
            $(
                // Write only while there is room, but keep counting so the caller
                // learns how many slots it needs.
                if len < count {
                    ::core::ptr::write(
                        buffer.add(len as usize),
                        $crate::ImporterFFI::new(
                            $importer,
                            ::core::stringify!($($name).+),
                            ::core::stringify!($($format).+),
                            ::core::stringify!($($target).+),
                            &[ $($(::core::stringify!($ext)),*)? ],
                        ),
                    );
                }
                len += 1;
            )*

            len
        }
    };
}