tract_onnx/
data_resolver.rs

1use std::fs::File;
2use std::io::{BufRead, BufReader};
3use std::path::Path;
4use tract_hir::internal::*;
5
6use tract_hir::internal::TractResult;
7
8#[cfg(not(target_family = "wasm"))]
9pub fn default() -> Box<dyn ModelDataResolver> {
10    Box::new(MmapDataResolver)
11}
12
/// Builds the resolver used by default when compiling for the `wasm` target family.
///
/// The returned resolver is a [`FopenDataResolver`].
#[cfg(target_family = "wasm")]
pub fn default() -> Box<dyn ModelDataResolver> {
    let resolver = FopenDataResolver;
    Box::new(resolver)
}
17
18pub trait ModelDataResolver {
19    fn read_bytes_from_path(
20        &self,
21        buf: &mut Vec<u8>,
22        p: &Path,
23        offset: usize,
24        length: Option<usize>,
25    ) -> TractResult<()>;
26}
27
28pub struct FopenDataResolver;
29
30impl ModelDataResolver for FopenDataResolver {
31    fn read_bytes_from_path(
32        &self,
33        buf: &mut Vec<u8>,
34        p: &Path,
35        offset: usize,
36        length: Option<usize>,
37    ) -> TractResult<()> {
38        let file = File::open(p).with_context(|| format!("Opening {p:?}"))?;
39        let file_size = file.metadata()?.len() as usize;
40        let length = length.unwrap_or(file_size - offset);
41        buf.reserve(length);
42
43        let mut reader = BufReader::new(file);
44        reader.seek_relative(offset as i64)?;
45        while reader.fill_buf()?.len() > 0 {
46            let num_read = std::cmp::min(reader.buffer().len(), length - buf.len());
47            buf.extend_from_slice(&reader.buffer()[..num_read]);
48            if buf.len() == length {
49                break;
50            }
51            reader.consume(reader.buffer().len());
52        }
53        Ok(())
54    }
55}
56
57pub struct MmapDataResolver;
58
59impl ModelDataResolver for MmapDataResolver {
60    fn read_bytes_from_path(
61        &self,
62        buf: &mut Vec<u8>,
63        p: &Path,
64        offset: usize,
65        length: Option<usize>,
66    ) -> TractResult<()> {
67        let file = File::open(p).with_context(|| format!("Opening {p:?}"))?;
68        let mmap = unsafe { memmap2::Mmap::map(&file)? };
69        match length {
70            Some(length) => buf.extend_from_slice(&mmap[offset..offset + length]),
71            None => buf.extend_from_slice(&mmap[offset..]),
72        }
73        Ok(())
74    }
75}