1use std::{
2 fs::File,
3 io::{self, Seek, Write},
4 path::Path,
5 thread::panicking,
6};
7
8use angsd_saf::version::Version;
9
10use super::{to_u64, to_usize, Header};
11
12use crate::{
13 em::StreamEmSite,
14 io::{Intersect, ReadSite},
15 saf::Site,
16};
17
18pub struct Writer<W> {
24 writers: Vec<W>,
25 header: Header,
26 current: usize,
27 finish_flag: bool, }
29
30impl<W> Writer<W> {
31 fn is_finished(&self) -> bool {
33 self.current >= to_usize(self.header.sites())
34 }
35
36 fn new(writers: Vec<W>, header: Header) -> Self {
38 let finish_flag = header.sites() == 0;
39
40 Self {
41 writers,
42 header,
43 current: 0,
44 finish_flag,
45 }
46 }
47
48 fn try_drop(&mut self) -> io::Result<()> {
50 if self.is_finished() | self.finish_flag {
51 Ok(())
52 } else {
53 Err(io::Error::new(
54 io::ErrorKind::InvalidData,
55 "closing pseudo-shuffled SAF file writer before it was filled",
56 ))
57 }
58 }
59
60 pub fn try_finish(mut self) -> io::Result<()> {
65 let result = self.try_drop();
66 self.finish_flag = true;
68 result
69 }
70}
71
72impl Writer<io::BufWriter<File>> {
73 pub fn create<P>(path: P, header: Header) -> io::Result<Self>
84 where
85 P: AsRef<Path>,
86 {
87 let file_size = header.file_size();
88
89 let mut f = File::create(&path)?;
90 f.set_len(to_u64(file_size))?;
91 header.write(&mut f)?;
92
93 let writers = header
94 .block_offsets()
95 .map(|offset| open_writer_at_offset(&path, to_u64(offset)))
96 .collect::<io::Result<Vec<_>>>()?;
97
98 Ok(Self::new(writers, header))
99 }
100
101 pub fn write_intersect<const D: usize, R, V>(
105 mut self,
106 mut intersect: Intersect<D, R, V>,
107 ) -> io::Result<()>
108 where
109 Intersect<D, R, V>: ReadSite<Site = Site<D>>,
110 R: io::BufRead + io::Seek,
111 V: Version,
112 {
113 let shape = intersect
114 .get()
115 .get_readers()
116 .iter()
117 .map(|reader| reader.index().alleles() + 1)
118 .collect::<Vec<_>>()
119 .try_into()
120 .unwrap();
121 let mut site = Site::from_shape(shape);
122
123 while intersect.read_site_unnormalised(&mut site)?.is_not_done() {
124 self.write_site(site.as_slice())?
125 }
126
127 self.try_finish()
128 }
129
130 pub fn write_site(&mut self, values: &[f32]) -> io::Result<()> {
136 if self.is_finished() {
137 return Err(io::Error::new(
138 io::ErrorKind::InvalidData,
139 "attempted to write more sites to writer than allocated",
140 ));
141 } else if values.len() != self.header.width() {
142 return Err(io::Error::new(
143 io::ErrorKind::InvalidData,
144 "number of values provided to writer does not match provided shape",
145 ));
146 }
147
148 let next_idx = self.current % self.writers.len();
149 let writer = &mut self.writers[next_idx];
150 for v in values {
151 writer.write_all(&v.to_le_bytes())?;
152 }
153
154 self.current += 1;
155
156 Ok(())
157 }
158
159 pub fn write_disjoint_site<I>(&mut self, values_iter: I) -> io::Result<()>
166 where
167 I: IntoIterator,
168 I::Item: AsRef<[f32]>,
169 I::IntoIter: ExactSizeIterator,
170 {
171 let values_iter = values_iter.into_iter();
172 let shape = self.header.shape();
173
174 if self.is_finished() {
175 return Err(io::Error::new(
176 io::ErrorKind::InvalidData,
177 "attempted to write more sites to writer than allocated",
178 ));
179 } else if values_iter.len() != shape.len() {
180 return Err(io::Error::new(
181 io::ErrorKind::InvalidData,
182 "more value slices provided for writing than shapes provided in header",
183 ));
184 }
185
186 let next_idx = self.current % self.writers.len();
187 let writer = &mut self.writers[next_idx];
188
189 for (values, &shape) in values_iter.zip(shape) {
190 if values.as_ref().len() != shape {
191 return Err(io::Error::new(
192 io::ErrorKind::InvalidData,
193 "provided values does not fit corresponding header shape",
194 ));
195 }
196
197 for v in values.as_ref() {
198 writer.write_all(&v.to_le_bytes())?
199 }
200 }
201
202 self.current += 1;
203
204 Ok(())
205 }
206}
207
208impl<W> Drop for Writer<W> {
209 fn drop(&mut self) {
210 if !panicking() {
213 self.try_drop().unwrap()
214 }
215 }
216}
217
/// Opens the existing file at `path` for writing and returns a buffered writer positioned
/// `offset` bytes from the start of the file.
///
/// The file is neither created nor truncated.
///
/// # Errors
///
/// Returns any I/O error raised while opening the file or seeking to `offset`.
fn open_writer_at_offset<P>(path: P, offset: u64) -> io::Result<io::BufWriter<File>>
where
    P: AsRef<Path>,
{
    File::options()
        .write(true)
        .open(path.as_ref())
        .and_then(|mut file| {
            // Position the raw file first; the buffer wrapped around it afterwards starts empty.
            file.seek(io::SeekFrom::Start(offset))?;

            Ok(io::BufWriter::new(file))
        })
}
228
#[cfg(test)]
mod tests {
    use super::*;

    use std::{
        io::{Read, SeekFrom},
        mem::size_of,
    };

    use tempfile::NamedTempFile;

    /// Values expected on disk after the header when ten sites of three values each, with site
    /// `i` holding the value `i` throughout, are written into four blocks: sites are dealt out
    /// round-robin, so block `b` holds sites `b`, `b + 4`, `b + 8`, ... in that order.
    #[rustfmt::skip]
    const EXPECTED_SHUFFLED: [f32; 30] = [
        // Block 0
        0., 0., 0.,
        4., 4., 4.,
        8., 8., 8.,
        // Block 1
        1., 1., 1.,
        5., 5., 5.,
        9., 9., 9.,
        // Block 2
        2., 2., 2.,
        6., 6., 6.,
        // Block 3
        3., 3., 3.,
        7., 7., 7.,
    ];

    /// Finishing a writer that has received none of its declared sites is an error.
    #[test]
    fn test_writer_too_few_sites_error() -> io::Result<()> {
        let tmp = NamedTempFile::new()?;

        let writer = Writer::create(tmp.path(), Header::new(514, vec![15, 7], 20))?;
        let err = writer.try_finish().unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::InvalidData);

        tmp.close()
    }

    /// Writing one site more than the header declares is an error.
    #[test]
    fn test_writer_too_many_sites_error() -> io::Result<()> {
        let tmp = NamedTempFile::new()?;

        let header = Header::new(2, vec![1, 2], 2);
        let width = header.width();
        let mut writer = Writer::create(tmp.path(), header)?;

        let values = vec![0.0; width];
        for _ in 0..2 {
            writer.write_site(&values)?;
        }

        let overflow = writer.write_site(&values);
        assert_eq!(overflow.unwrap_err().kind(), io::ErrorKind::InvalidData);

        tmp.close()
    }

    /// Creating a writer sizes the file as given by the header and positions each block writer
    /// at its block offset.
    #[test]
    fn test_create_writer() -> io::Result<()> {
        let tmp = NamedTempFile::new()?;

        let header = Header::new(514, vec![15, 7], 20);
        let mut writer = Writer::create(tmp.path(), header.clone())?;

        let on_disk_len = tmp.as_file().metadata()?.len() as usize;
        assert_eq!(on_disk_len, header.file_size());

        let mut initial_offsets = Vec::with_capacity(writer.writers.len());
        for block_writer in writer.writers.iter_mut() {
            let position = block_writer.get_mut().stream_position()?;
            initial_offsets.push(to_usize(position));
        }
        let expected_offsets: Vec<_> = header.block_offsets().collect();
        assert_eq!(initial_offsets, expected_offsets);

        // Nothing was written, so finishing fails; the error is irrelevant to this test.
        let _error = writer.try_finish();
        tmp.close()
    }

    /// Writes `sites` to a fresh temporary file through `write_fn` and checks that the values
    /// found after the header equal `expected`.
    fn test_shuffled<I, F>(
        header: Header,
        sites: I,
        expected: &[f32],
        mut write_fn: F,
    ) -> io::Result<()>
    where
        I: IntoIterator,
        F: FnMut(&mut Writer<io::BufWriter<File>>, I::Item) -> io::Result<()>,
    {
        let mut tmp = NamedTempFile::new()?;
        let data_start = header.header_size() as u64;

        let mut writer = Writer::create(tmp.path(), header)?;
        for site in sites {
            write_fn(&mut writer, site)?;
        }
        // Consuming the writer also drops its buffered block writers before reading back.
        writer.try_finish().unwrap();

        tmp.seek(SeekFrom::Start(data_start))?;
        let mut raw = Vec::new();
        tmp.read_to_end(&mut raw)?;

        let written = raw
            .chunks(size_of::<f32>())
            .map(|bytes| f32::from_le_bytes(bytes.try_into().unwrap()))
            .collect::<Vec<f32>>();

        assert_eq!(written, expected);

        tmp.close()
    }

    /// Sites written as contiguous slices end up in shuffled block order.
    #[test]
    fn test_writer_shuffle() -> io::Result<()> {
        let header = Header::new(10, vec![1, 2], 4);

        // Site `i` is `[i, i, i]`.
        let sites: Vec<[f32; 3]> = (0..10u8).map(|i| [f32::from(i); 3]).collect();

        test_shuffled(header, sites, &EXPECTED_SHUFFLED, |writer, site| {
            writer.write_site(&site)
        })
    }

    /// Sites written as disjoint slices end up in the same shuffled block order.
    #[test]
    fn test_writer_disjoint_shuffle() -> io::Result<()> {
        let header = Header::new(10, vec![1, 2], 4);

        // Site `i` is split as `[i]` and `[i, i]`, matching the header shape `[1, 2]`.
        let sites: Vec<Vec<Vec<f32>>> = (0..10u8)
            .map(|i| {
                let value = f32::from(i);
                vec![vec![value; 1], vec![value; 2]]
            })
            .collect();

        test_shuffled(header, sites, &EXPECTED_SHUFFLED, |writer, site| {
            writer.write_disjoint_site(site)
        })
    }
}