winsfs_core/io/shuffle/
header.rs

1use std::{io, iter::once, mem::size_of};
2
3use super::{to_u16, to_u32, to_u64, to_usize};
4
/// The magic number identifying a pseudo-shuffled SAF file.
///
/// It occupies the first 8 bytes of the file and spells `safvshuf` in ASCII.
pub const MAGIC_NUMBER: [u8; 8] = [b's', b'a', b'f', b'v', b's', b'h', b'u', b'f'];
7
8/// The header for a pseudo-shuffled SAF file.
9///
10/// The header is written at the top of the file, and contains information about the size and layout
11/// of the file.
12#[derive(Clone, Debug, Eq, PartialEq)]
13pub struct Header {
14    sites: usize,
15    shape: Vec<usize>,
16    blocks: usize,
17}
18
19impl Header {
20    /// Returns the number of blocks used for shuffling.
21    pub fn blocks(&self) -> usize {
22        self.blocks
23    }
24
25    /// Returns the size (in bytes) of the data that the file is expected to contain.
26    pub fn data_size(&self) -> usize {
27        to_usize(self.sites) * self.width() * size_of::<f32>()
28    }
29
30    /// Returns an iterator over the byte offset of the start of each block.
31    pub(super) fn block_offsets(&self) -> impl Iterator<Item = usize> {
32        once(self.header_size())
33            .chain(self.block_sizes().take(self.blocks - 1))
34            .scan(0, |acc, x| {
35                *acc += x;
36                Some(*acc)
37            })
38    }
39
40    /// Returns an iterator over the number of sites per block.
41    pub(super) fn block_sites(&self) -> impl Iterator<Item = usize> {
42        let div = self.sites / self.blocks;
43        let rem = self.sites % self.blocks;
44
45        (0..self.blocks).map(move |i| if i < rem { div + 1 } else { div })
46    }
47
48    /// Returns an iterator over the number of bytes per block.
49    pub(super) fn block_sizes(&self) -> impl Iterator<Item = usize> {
50        let width = self.width();
51        self.block_sites()
52            .map(move |sites| sites * width * size_of::<f32>())
53    }
54
55    /// Returns the size (in bytes) of the entire file.
56    ///
57    /// This is equal to the size of the header and the size of the data.
58    pub fn file_size(&self) -> usize {
59        self.header_size() + self.data_size()
60    }
61
62    /// Returns the size (in bytes) of the header as it will be written to a file.
63    pub fn header_size(&self) -> usize {
64        let shape_size = size_of::<u8>() + self.shape.len() * size_of::<u32>();
65
66        size_of::<[u8; 8]>() + size_of::<u64>() + shape_size + size_of::<u16>()
67    }
68
69    /// Creates a new header.
70    pub fn new(sites: usize, shape: Vec<usize>, blocks: usize) -> Self {
71        Self {
72            sites,
73            shape,
74            blocks,
75        }
76    }
77
78    /// Reads the header, including the magic number, from a reader.
79    pub(super) fn read<R>(mut reader: R) -> io::Result<Self>
80    where
81        R: io::Read,
82    {
83        let mut magic = [0; MAGIC_NUMBER.len()];
84        reader.read_exact(&mut magic)?;
85
86        if magic != MAGIC_NUMBER {
87            return Err(io::Error::new(
88                io::ErrorKind::InvalidData,
89                format!(
90                    "invalid or unsupported SAF magic number \
91                    (found '{magic:02x?}', expected '{MAGIC_NUMBER:02x?}')"
92                ),
93            ));
94        }
95
96        let mut sites_buf = [0u8; size_of::<u64>()];
97        reader.read_exact(&mut sites_buf)?;
98        let sites = to_usize(u64::from_le_bytes(sites_buf));
99
100        let mut shape_len_buf = [0u8; size_of::<u8>()];
101        reader.read_exact(&mut shape_len_buf)?;
102        let shape_len = u8::from_le_bytes(shape_len_buf);
103
104        let mut shape_buf = [0u8; size_of::<u32>()];
105        let mut shape = Vec::with_capacity(shape_len.into());
106        for _ in 0..shape_len {
107            reader.read_exact(&mut shape_buf)?;
108            shape.push(to_usize(u32::from_le_bytes(shape_buf)));
109        }
110
111        let mut blocks_buf = [0u8; size_of::<u16>()];
112        reader.read_exact(&mut blocks_buf)?;
113        let blocks = usize::from(u16::from_le_bytes(blocks_buf));
114
115        Ok(Self::new(sites, shape, blocks))
116    }
117
118    /// Returns the shape of each site in the file.
119    pub fn shape(&self) -> &[usize] {
120        &self.shape
121    }
122
123    /// Returns the number of sites in the file.
124    pub fn sites(&self) -> usize {
125        self.sites
126    }
127
128    /// Returns the width of each site, i.e. the total number of values.
129    pub(super) fn width(&self) -> usize {
130        self.shape.iter().sum()
131    }
132
133    /// Writes the header, including the magic number, to a writer.
134    pub(super) fn write<W>(&self, mut writer: W) -> io::Result<()>
135    where
136        W: io::Write,
137    {
138        writer.write_all(&MAGIC_NUMBER)?;
139
140        let sites = to_u64(self.sites);
141        writer.write_all(&sites.to_le_bytes())?;
142
143        let shape_len: u8 = self.shape.len().try_into().map_err(|_| {
144            io::Error::new(
145                io::ErrorKind::InvalidInput,
146                "number of header dimensions exceeds {u8::MAX}",
147            )
148        })?;
149        writer.write_all(&shape_len.to_le_bytes())?;
150        for &v in self.shape.iter() {
151            writer.write_all(&to_u32(v).to_le_bytes())?;
152        }
153
154        writer.write_all(&to_u16(self.blocks).to_le_bytes())?;
155
156        Ok(())
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized header with 105 sites, shape `[7, 5]`, and 10 blocks.
    #[rustfmt::skip]
    const TEST_HEADER: &[u8] = &[
        0x73, 0x61, 0x66, 0x76, 0x73, 0x68, 0x75, 0x66, // magic number ("safvshuf")
        0x69, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 105u64 sites
        0x02,                                           // 2u8 shape dimensions
        0x07, 0x00, 0x00, 0x00,                         // 7u32 = shape[0]
        0x05, 0x00, 0x00, 0x00,                         // 5u32 = shape[1]
        0x0A, 0x00,                                     // 10u16 blocks
    ];

    #[test]
    fn test_write_header() -> io::Result<()> {
        let header = Header::new(105, vec![7, 5], 10);
        let mut dest = Vec::new();
        header.write(&mut dest)?;

        let expected = TEST_HEADER;
        assert_eq!(dest, expected);

        Ok(())
    }

    #[test]
    fn test_read_header() -> io::Result<()> {
        let src = TEST_HEADER;
        let header = Header::read(src)?;

        let expected = Header::new(105, vec![7, 5], 10);
        assert_eq!(header, expected);

        Ok(())
    }

    #[test]
    fn test_write_read_roundtrip() -> io::Result<()> {
        let header = Header::new(1005, vec![7, 5, 11], 20);
        let mut dest = Vec::new();
        header.write(&mut dest)?;

        // The reported header size must match what is actually written.
        assert_eq!(dest.len(), header.header_size());
        assert_eq!(Header::read(dest.as_slice())?, header);

        Ok(())
    }

    #[test]
    fn test_read_header_fails_wrong_magic() {
        let mut wrong_header = TEST_HEADER.to_vec();
        wrong_header[0] = 0;

        let result = Header::read(wrong_header.as_slice());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_header_fails_truncated() {
        let truncated = &TEST_HEADER[..TEST_HEADER.len() - 1];

        let result = Header::read(truncated);
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_header_size() {
        assert_eq!(Header::new(105, vec![7], 10).header_size(), 23);
        assert_eq!(Header::new(1005, vec![7, 5], 20).header_size(), 27);
        assert_eq!(Header::new(15, vec![7, 5, 11], 5).header_size(), 31);
    }

    #[test]
    fn test_data_size() {
        assert_eq!(Header::new(105, vec![7], 10).data_size(), 2940);
        assert_eq!(Header::new(1005, vec![7, 5], 20).data_size(), 48240);
        assert_eq!(Header::new(15, vec![7, 5, 11], 5).data_size(), 1380);
    }

    #[test]
    fn test_file_size() {
        assert_eq!(Header::new(105, vec![7], 10).file_size(), 23 + 2940);
        assert_eq!(Header::new(1005, vec![7, 5], 20).file_size(), 27 + 48240);
    }

    #[test]
    fn test_block_sites_even() {
        let header = Header::new(100, vec![3, 9], 5);
        let expected = vec![20; 5];
        assert_eq!(header.block_sites().collect::<Vec<_>>(), expected);
    }

    #[test]
    fn test_block_sites_not_even() {
        let header = Header::new(99, vec![3, 9], 5);
        let expected: Vec<_> = vec![20, 20, 20, 20, 19];
        assert_eq!(header.block_sites().collect::<Vec<_>>(), expected);

        let header = Header::new(101, vec![3, 9], 5);
        let expected: Vec<_> = vec![21, 20, 20, 20, 20];
        assert_eq!(header.block_sites().collect::<Vec<_>>(), expected);

        let header = Header::new(10, vec![1, 2], 4);
        let expected: Vec<_> = vec![3, 3, 2, 2];
        assert_eq!(header.block_sites().collect::<Vec<_>>(), expected);
    }

    #[test]
    fn test_block_sizes() {
        let header = Header::new(10, vec![1, 2], 4);
        let expected: Vec<_> = vec![36, 36, 24, 24];
        assert_eq!(header.block_sizes().collect::<Vec<_>>(), expected);
    }

    #[test]
    fn test_block_offsets() {
        let header = Header::new(10, vec![1, 2], 4);
        let x = header.header_size();
        let expected: Vec<_> = vec![x, x + 36, x + 72, x + 96];
        assert_eq!(header.block_offsets().collect::<Vec<_>>(), expected);
    }
}