tract_onnx/
data_resolver.rs1use std::fs::File;
2use std::io::{BufRead, BufReader};
3use std::path::Path;
4use tract_hir::internal::*;
5
6use tract_hir::internal::TractResult;
7
8#[cfg(not(target_family = "wasm"))]
9pub fn default() -> Box<dyn ModelDataResolver> {
10 Box::new(MmapDataResolver)
11}
12
/// Returns the platform's default [`ModelDataResolver`].
///
/// On `wasm` targets this is [`FopenDataResolver`], which reads the file through a
/// buffered reader instead of memory-mapping it.
#[cfg(target_family = "wasm")]
pub fn default() -> Box<dyn ModelDataResolver> {
    Box::new(FopenDataResolver) as Box<dyn ModelDataResolver>
}
17
18pub trait ModelDataResolver {
19 fn read_bytes_from_path(
20 &self,
21 buf: &mut Vec<u8>,
22 p: &Path,
23 offset: usize,
24 length: Option<usize>,
25 ) -> TractResult<()>;
26}
27
28pub struct FopenDataResolver;
29
30impl ModelDataResolver for FopenDataResolver {
31 fn read_bytes_from_path(
32 &self,
33 buf: &mut Vec<u8>,
34 p: &Path,
35 offset: usize,
36 length: Option<usize>,
37 ) -> TractResult<()> {
38 let file = File::open(p).with_context(|| format!("Opening {p:?}"))?;
39 let file_size = file.metadata()?.len() as usize;
40 let length = length.unwrap_or(file_size - offset);
41 buf.reserve(length);
42
43 let mut reader = BufReader::new(file);
44 reader.seek_relative(offset as i64)?;
45 while reader.fill_buf()?.len() > 0 {
46 let num_read = std::cmp::min(reader.buffer().len(), length - buf.len());
47 buf.extend_from_slice(&reader.buffer()[..num_read]);
48 if buf.len() == length {
49 break;
50 }
51 reader.consume(reader.buffer().len());
52 }
53 Ok(())
54 }
55}
56
57pub struct MmapDataResolver;
58
59impl ModelDataResolver for MmapDataResolver {
60 fn read_bytes_from_path(
61 &self,
62 buf: &mut Vec<u8>,
63 p: &Path,
64 offset: usize,
65 length: Option<usize>,
66 ) -> TractResult<()> {
67 let file = File::open(p).with_context(|| format!("Opening {p:?}"))?;
68 let mmap = unsafe { memmap2::Mmap::map(&file)? };
69 match length {
70 Some(length) => buf.extend_from_slice(&mmap[offset..offset + length]),
71 None => buf.extend_from_slice(&mmap[offset..]),
72 }
73 Ok(())
74 }
75}