1use std::{io, iter::once, mem::size_of};
2
3use super::{to_u16, to_u32, to_u64, to_usize};
4
/// Eight-byte magic number written first in every header: the ASCII bytes of `safvshuf`.
pub const MAGIC_NUMBER: [u8; 8] = [b's', b'a', b'f', b'v', b's', b'h', b'u', b'f'];
7
/// Header of a SAF file.
///
/// Records how many sites the file holds, the size of each dimension, and how many blocks the
/// site data is divided into.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header {
    /// Number of sites stored after the header.
    sites: usize,
    /// Size of each dimension; their sum is the number of values per site.
    shape: Vec<usize>,
    /// Number of blocks the sites are split across.
    blocks: usize,
}
18
19impl Header {
20 pub fn blocks(&self) -> usize {
22 self.blocks
23 }
24
25 pub fn data_size(&self) -> usize {
27 to_usize(self.sites) * self.width() * size_of::<f32>()
28 }
29
30 pub(super) fn block_offsets(&self) -> impl Iterator<Item = usize> {
32 once(self.header_size())
33 .chain(self.block_sizes().take(self.blocks - 1))
34 .scan(0, |acc, x| {
35 *acc += x;
36 Some(*acc)
37 })
38 }
39
40 pub(super) fn block_sites(&self) -> impl Iterator<Item = usize> {
42 let div = self.sites / self.blocks;
43 let rem = self.sites % self.blocks;
44
45 (0..self.blocks).map(move |i| if i < rem { div + 1 } else { div })
46 }
47
48 pub(super) fn block_sizes(&self) -> impl Iterator<Item = usize> {
50 let width = self.width();
51 self.block_sites()
52 .map(move |sites| sites * width * size_of::<f32>())
53 }
54
55 pub fn file_size(&self) -> usize {
59 self.header_size() + self.data_size()
60 }
61
62 pub fn header_size(&self) -> usize {
64 let shape_size = size_of::<u8>() + self.shape.len() * size_of::<u32>();
65
66 size_of::<[u8; 8]>() + size_of::<u64>() + shape_size + size_of::<u16>()
67 }
68
69 pub fn new(sites: usize, shape: Vec<usize>, blocks: usize) -> Self {
71 Self {
72 sites,
73 shape,
74 blocks,
75 }
76 }
77
78 pub(super) fn read<R>(mut reader: R) -> io::Result<Self>
80 where
81 R: io::Read,
82 {
83 let mut magic = [0; MAGIC_NUMBER.len()];
84 reader.read_exact(&mut magic)?;
85
86 if magic != MAGIC_NUMBER {
87 return Err(io::Error::new(
88 io::ErrorKind::InvalidData,
89 format!(
90 "invalid or unsupported SAF magic number \
91 (found '{magic:02x?}', expected '{MAGIC_NUMBER:02x?}')"
92 ),
93 ));
94 }
95
96 let mut sites_buf = [0u8; size_of::<u64>()];
97 reader.read_exact(&mut sites_buf)?;
98 let sites = to_usize(u64::from_le_bytes(sites_buf));
99
100 let mut shape_len_buf = [0u8; size_of::<u8>()];
101 reader.read_exact(&mut shape_len_buf)?;
102 let shape_len = u8::from_le_bytes(shape_len_buf);
103
104 let mut shape_buf = [0u8; size_of::<u32>()];
105 let mut shape = Vec::with_capacity(shape_len.into());
106 for _ in 0..shape_len {
107 reader.read_exact(&mut shape_buf)?;
108 shape.push(to_usize(u32::from_le_bytes(shape_buf)));
109 }
110
111 let mut blocks_buf = [0u8; size_of::<u16>()];
112 reader.read_exact(&mut blocks_buf)?;
113 let blocks = usize::from(u16::from_le_bytes(blocks_buf));
114
115 Ok(Self::new(sites, shape, blocks))
116 }
117
118 pub fn shape(&self) -> &[usize] {
120 &self.shape
121 }
122
123 pub fn sites(&self) -> usize {
125 self.sites
126 }
127
128 pub(super) fn width(&self) -> usize {
130 self.shape.iter().sum()
131 }
132
133 pub(super) fn write<W>(&self, mut writer: W) -> io::Result<()>
135 where
136 W: io::Write,
137 {
138 writer.write_all(&MAGIC_NUMBER)?;
139
140 let sites = to_u64(self.sites);
141 writer.write_all(&sites.to_le_bytes())?;
142
143 let shape_len: u8 = self.shape.len().try_into().map_err(|_| {
144 io::Error::new(
145 io::ErrorKind::InvalidInput,
146 "number of header dimensions exceeds {u8::MAX}",
147 )
148 })?;
149 writer.write_all(&shape_len.to_le_bytes())?;
150 for &v in self.shape.iter() {
151 writer.write_all(&to_u32(v).to_le_bytes())?;
152 }
153
154 writer.write_all(&to_u16(self.blocks).to_le_bytes())?;
155
156 Ok(())
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialized form of `Header::new(105, vec![7, 5], 10)`.
    #[rustfmt::skip]
    const TEST_HEADER: &[u8] = &[
        // magic number: "safvshuf"
        0x73, 0x61, 0x66, 0x76, 0x73, 0x68, 0x75, 0x66,
        // sites: 105 as little-endian u64
        0x69, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        // number of dimensions: 2 as u8
        0x02,
        // shape: [7, 5] as little-endian u32 values
        0x07, 0x00, 0x00, 0x00,
        0x05, 0x00, 0x00, 0x00,
        // blocks: 10 as little-endian u16
        0x0A, 0x00,
    ];

    #[test]
    fn test_write_header() -> io::Result<()> {
        let header = Header::new(105, vec![7, 5], 10);
        let mut dest = Vec::new();
        header.write(&mut dest)?;

        let expected = TEST_HEADER;
        assert_eq!(dest, expected);

        Ok(())
    }

    #[test]
    fn test_read_header() -> io::Result<()> {
        let src = TEST_HEADER;
        let header = Header::read(src)?;

        let expected = Header::new(105, vec![7, 5], 10);
        assert_eq!(header, expected);

        Ok(())
    }

    #[test]
    fn test_write_read_roundtrip() -> io::Result<()> {
        let header = Header::new(1005, vec![7, 5, 11], 20);
        let mut buf = Vec::new();
        header.write(&mut buf)?;

        assert_eq!(buf.len(), header.header_size());
        assert_eq!(Header::read(buf.as_slice())?, header);

        Ok(())
    }

    #[test]
    fn test_read_header_fails_wrong_magic() {
        let mut wrong_header = TEST_HEADER.to_vec();
        wrong_header[0] = 0;

        let result = Header::read(wrong_header.as_slice());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_header_fails_truncated() {
        let truncated = &TEST_HEADER[..TEST_HEADER.len() - 1];

        let result = Header::read(truncated);
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_write_header_fails_too_many_dimensions() {
        let header = Header::new(1, vec![1; usize::from(u8::MAX) + 1], 1);

        let result = header.write(Vec::new());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn test_header_size() {
        assert_eq!(Header::new(105, vec![7], 10).header_size(), 23);
        assert_eq!(Header::new(1005, vec![7, 5], 20).header_size(), 27);
        assert_eq!(Header::new(15, vec![7, 5, 11], 5).header_size(), 31);
    }

    #[test]
    fn test_data_size() {
        assert_eq!(Header::new(105, vec![7], 10).data_size(), 2940);
        assert_eq!(Header::new(1005, vec![7, 5], 20).data_size(), 48240);
        assert_eq!(Header::new(15, vec![7, 5, 11], 5).data_size(), 1380);
    }

    #[test]
    fn test_block_sites_even() {
        let header = Header::new(100, vec![3, 9], 5);
        let expected = vec![20; 5];
        assert_eq!(header.block_sites().collect::<Vec<_>>(), expected);
    }

    #[test]
    fn test_block_sites_not_even() {
        let header = Header::new(99, vec![3, 9], 5);
        let expected: Vec<_> = vec![20, 20, 20, 20, 19];
        assert_eq!(header.block_sites().collect::<Vec<_>>(), expected);

        let header = Header::new(101, vec![3, 9], 5);
        let expected: Vec<_> = vec![21, 20, 20, 20, 20];
        assert_eq!(header.block_sites().collect::<Vec<_>>(), expected);

        let header = Header::new(10, vec![1, 2], 4);
        let expected: Vec<_> = vec![3, 3, 2, 2];
        assert_eq!(header.block_sites().collect::<Vec<_>>(), expected);
    }

    #[test]
    fn test_block_sizes() {
        let header = Header::new(10, vec![1, 2], 4);
        let expected: Vec<_> = vec![36, 36, 24, 24];
        assert_eq!(header.block_sizes().collect::<Vec<_>>(), expected);
    }

    #[test]
    fn test_block_offsets() {
        let header = Header::new(10, vec![1, 2], 4);
        let x = header.header_size();
        let expected: Vec<_> = vec![x, x + 36, x + 72, x + 96];
        assert_eq!(header.block_offsets().collect::<Vec<_>>(), expected);
    }
}