// winsfs_core/io/shuffle/writer.rs

1use std::{
2    fs::File,
3    io::{self, Seek, Write},
4    path::Path,
5    thread::panicking,
6};
7
8use angsd_saf::version::Version;
9
10use super::{to_u64, to_usize, Header};
11
12use crate::{
13    em::StreamEmSite,
14    io::{Intersect, ReadSite},
15    saf::Site,
16};
17
18/// A pseudo-shuffled SAF file writer.
19///
20/// Note that the writer has a fallible drop check.
21/// See [`Writer::create`] and [`Writer::try_finish`] for more, as well as
22/// the [module docs](index.html#write) for general usage..
23pub struct Writer<W> {
24    writers: Vec<W>,
25    header: Header,
26    current: usize,
27    finish_flag: bool, // Flag used for drop check
28}
29
30impl<W> Writer<W> {
31    /// Check if reader is finished.
32    fn is_finished(&self) -> bool {
33        self.current >= to_usize(self.header.sites())
34    }
35
36    /// Creates a new writer.
37    fn new(writers: Vec<W>, header: Header) -> Self {
38        let finish_flag = header.sites() == 0;
39
40        Self {
41            writers,
42            header,
43            current: 0,
44            finish_flag,
45        }
46    }
47
48    /// Fallible drop check, used in both the actual Drop impl and try_finish.
49    fn try_drop(&mut self) -> io::Result<()> {
50        if self.is_finished() | self.finish_flag {
51            Ok(())
52        } else {
53            Err(io::Error::new(
54                io::ErrorKind::InvalidData,
55                "closing pseudo-shuffled SAF file writer before it was filled",
56            ))
57        }
58    }
59
60    /// Consumes the writer fallibly.
61    ///
62    /// This can be use to drop the writer and handle an error in the drop. See [`Writer::create`]
63    /// for more information on when the writer can be dropped.
64    pub fn try_finish(mut self) -> io::Result<()> {
65        let result = self.try_drop();
66        // Set the flag here so that drop check doesn't panic now
67        self.finish_flag = true;
68        result
69    }
70}
71
72impl Writer<io::BufWriter<File>> {
73    /// Creates a new pseudo-shuffled SAF file writer.
74    ///
75    /// Note that this will pre-allocate the full disk space needed to fit the data described in
76    /// the header. If the path already exists, it will be overwritten. The header information will
77    /// be written to the file.
78    ///
79    /// Since the full file space is pre-allocated, and since data is not written sequentially,
80    /// it is considered an error if less sites are written than specified in the `header`.
81    /// This condition is checked when dropping the reader, and the drop check will panic if the
82    /// check is failed. See [`Writer::try_finish`] to handle the result of this check.
83    pub fn create<P>(path: P, header: Header) -> io::Result<Self>
84    where
85        P: AsRef<Path>,
86    {
87        let file_size = header.file_size();
88
89        let mut f = File::create(&path)?;
90        f.set_len(to_u64(file_size))?;
91        header.write(&mut f)?;
92
93        let writers = header
94            .block_offsets()
95            .map(|offset| open_writer_at_offset(&path, to_u64(offset)))
96            .collect::<io::Result<Vec<_>>>()?;
97
98        Ok(Self::new(writers, header))
99    }
100
101    /// Writes an entire reader to the writer.
102    ///
103    /// Assumes that the reader contains the appropriate number of sites.
104    pub fn write_intersect<const D: usize, R, V>(
105        mut self,
106        mut intersect: Intersect<D, R, V>,
107    ) -> io::Result<()>
108    where
109        Intersect<D, R, V>: ReadSite<Site = Site<D>>,
110        R: io::BufRead + io::Seek,
111        V: Version,
112    {
113        let shape = intersect
114            .get()
115            .get_readers()
116            .iter()
117            .map(|reader| reader.index().alleles() + 1)
118            .collect::<Vec<_>>()
119            .try_into()
120            .unwrap();
121        let mut site = Site::from_shape(shape);
122
123        while intersect.read_site_unnormalised(&mut site)?.is_not_done() {
124            self.write_site(site.as_slice())?
125        }
126
127        self.try_finish()
128    }
129
130    /// Writes a single site to the writer.
131    ///
132    /// No more sites can be written than specified in the header specified to [`Writer::create`].
133    /// Also, the number of values in `site` must match the sum of the shape provided in the header.
134    /// If either of those conditions are not met, an error will be returned.
135    pub fn write_site(&mut self, values: &[f32]) -> io::Result<()> {
136        if self.is_finished() {
137            return Err(io::Error::new(
138                io::ErrorKind::InvalidData,
139                "attempted to write more sites to writer than allocated",
140            ));
141        } else if values.len() != self.header.width() {
142            return Err(io::Error::new(
143                io::ErrorKind::InvalidData,
144                "number of values provided to writer does not match provided shape",
145            ));
146        }
147
148        let next_idx = self.current % self.writers.len();
149        let writer = &mut self.writers[next_idx];
150        for v in values {
151            writer.write_all(&v.to_le_bytes())?;
152        }
153
154        self.current += 1;
155
156        Ok(())
157    }
158
159    /// Writes a single site split across multiple slices to the writer.
160    ///
161    /// The different slices here may for instance correspond to different populations. As for
162    /// [`Writer::write_site`], no more sites can be than specified in the header specified to
163    /// [`Writer::create`]. The provided sites must match the shape provided in the header.
164    /// If either of those conditions are not met, an error will be returned.
165    pub fn write_disjoint_site<I>(&mut self, values_iter: I) -> io::Result<()>
166    where
167        I: IntoIterator,
168        I::Item: AsRef<[f32]>,
169        I::IntoIter: ExactSizeIterator,
170    {
171        let values_iter = values_iter.into_iter();
172        let shape = self.header.shape();
173
174        if self.is_finished() {
175            return Err(io::Error::new(
176                io::ErrorKind::InvalidData,
177                "attempted to write more sites to writer than allocated",
178            ));
179        } else if values_iter.len() != shape.len() {
180            return Err(io::Error::new(
181                io::ErrorKind::InvalidData,
182                "more value slices provided for writing than shapes provided in header",
183            ));
184        }
185
186        let next_idx = self.current % self.writers.len();
187        let writer = &mut self.writers[next_idx];
188
189        for (values, &shape) in values_iter.zip(shape) {
190            if values.as_ref().len() != shape {
191                return Err(io::Error::new(
192                    io::ErrorKind::InvalidData,
193                    "provided values does not fit corresponding header shape",
194                ));
195            }
196
197            for v in values.as_ref() {
198                writer.write_all(&v.to_le_bytes())?
199            }
200        }
201
202        self.current += 1;
203
204        Ok(())
205    }
206}
207
208impl<W> Drop for Writer<W> {
209    fn drop(&mut self) {
210        // Don't check if writer is finished if already unwinding from panic,
211        // or we will likely get a double panic
212        if !panicking() {
213            self.try_drop().unwrap()
214        }
215    }
216}
217
/// Opens an existing file at `path` for writing and returns a buffered writer positioned at
/// byte `offset`.
///
/// The file is neither truncated nor created; opening fails if it does not already exist.
fn open_writer_at_offset<P>(path: P, offset: u64) -> io::Result<io::BufWriter<File>>
where
    P: AsRef<Path>,
{
    File::options()
        .write(true)
        .open(path.as_ref())
        .and_then(|mut handle| {
            handle.seek(io::SeekFrom::Start(offset))?;
            Ok(io::BufWriter::new(handle))
        })
}
228
#[cfg(test)]
mod tests {
    use super::*;

    use std::{
        io::{Read, SeekFrom},
        mem::size_of,
    };

    use tempfile::NamedTempFile;

    #[test]
    fn test_writer_too_few_sites_error() -> io::Result<()> {
        let file = NamedTempFile::new()?;
        let path = file.path();

        let header = Header::new(514, vec![15, 7], 20);
        let writer = Writer::create(path, header)?;

        assert_eq!(
            writer.try_finish().unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );

        file.close()
    }

    #[test]
    fn test_writer_too_many_sites_error() -> io::Result<()> {
        let file = NamedTempFile::new()?;
        let path = file.path();

        let header = Header::new(2, vec![1, 2], 2);
        let mut writer = Writer::create(path, header.clone())?;

        let values = vec![0.0; header.width()];
        writer.write_site(values.as_slice())?;
        writer.write_site(values.as_slice())?;

        let result = writer.write_site(values.as_slice());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);

        file.close()
    }

    #[test]
    fn test_writer_wrong_width_error() -> io::Result<()> {
        let file = NamedTempFile::new()?;
        let path = file.path();

        let header = Header::new(2, vec![1, 2], 2);
        let mut writer = Writer::create(path, header)?;

        // Header width is 3, so two values must be rejected.
        let result = writer.write_site(&[0.0_f32; 2]);
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);

        // The writer was never filled; consume it fallibly so the drop check does not panic.
        assert!(writer.try_finish().is_err());

        file.close()
    }

    #[test]
    fn test_writer_disjoint_wrong_shape_error() -> io::Result<()> {
        let file = NamedTempFile::new()?;
        let path = file.path();

        let header = Header::new(2, vec![1, 2], 2);
        let mut writer = Writer::create(path, header)?;

        // Fewer slices than shapes in the header.
        let result = writer.write_disjoint_site(vec![&[0.0_f32][..]]);
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);

        // Correct number of slices, but lengths swapped relative to the header shape.
        let result = writer.write_disjoint_site(vec![&[0.0_f32, 0.0][..], &[0.0_f32][..]]);
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);

        // The writer was never filled; consume it fallibly so the drop check does not panic.
        assert!(writer.try_finish().is_err());

        file.close()
    }

    #[test]
    #[should_panic]
    fn test_writer_drop_unfilled_panics() {
        let file = NamedTempFile::new().unwrap();

        let header = Header::new(2, vec![1, 2], 2);
        let _writer = Writer::create(file.path(), header).unwrap();

        // `_writer` is dropped here without having been filled, which trips the drop check.
    }

    #[test]
    fn test_create_writer() -> io::Result<()> {
        let file = NamedTempFile::new()?;
        let path = file.path();

        let header = Header::new(514, vec![15, 7], 20);
        let mut writer = Writer::create(path, header.clone())?;

        assert_eq!(
            file.as_file().metadata()?.len() as usize,
            header.file_size(),
        );

        let initial_offsets = writer
            .writers
            .iter_mut()
            .map(|writer| writer.get_mut().seek(SeekFrom::Current(0)).map(to_usize))
            .collect::<io::Result<Vec<_>>>()?;
        let expected_offsets = header.block_offsets().collect::<Vec<_>>();
        assert_eq!(initial_offsets, expected_offsets);

        let _error = writer.try_finish();
        file.close()
    }

    // Helper for testing that writing the provided sites with the given header produced
    // the expected data when read back in
    fn test_shuffled<I, F>(
        header: Header,
        sites: I,
        expected: &[f32],
        mut write_fn: F,
    ) -> io::Result<()>
    where
        I: IntoIterator,
        F: FnMut(&mut Writer<io::BufWriter<File>>, I::Item) -> io::Result<()>,
    {
        let mut file = NamedTempFile::new()?;
        let path = file.path();

        let mut writer = Writer::create(path, header.clone())?;

        for site in sites {
            write_fn(&mut writer, site)?;
        }

        // Consume the writer so that all buffered block data reaches the file
        writer.try_finish().unwrap();

        let mut data = Vec::new();
        file.seek(SeekFrom::Start(header.header_size() as u64))?;
        file.read_to_end(&mut data)?;

        let written: Vec<f32> = data
            .chunks(size_of::<f32>())
            .map(|bytes| f32::from_le_bytes(bytes.try_into().unwrap()))
            .collect();

        assert_eq!(written, expected);

        file.close()
    }

    #[test]
    fn test_writer_shuffle() -> io::Result<()> {
        let header = Header::new(10, vec![1, 2], 4);

        let sites = vec![
            &[0., 0., 0.],
            &[1., 1., 1.],
            &[2., 2., 2.],
            &[3., 3., 3.],
            &[4., 4., 4.],
            &[5., 5., 5.],
            &[6., 6., 6.],
            &[7., 7., 7.],
            &[8., 8., 8.],
            &[9., 9., 9.],
        ];

        #[rustfmt::skip]
        let expected = vec![
            0., 0., 0.,
            4., 4., 4.,
            8., 8., 8.,
            1., 1., 1.,
            5., 5., 5.,
            9., 9., 9.,
            2., 2., 2.,
            6., 6., 6.,
            3., 3., 3.,
            7., 7., 7.,
        ];

        test_shuffled(header, sites, expected.as_slice(), |writer, site| {
            writer.write_site(site)
        })
    }

    #[test]
    fn test_writer_disjoint_shuffle() -> io::Result<()> {
        let header = Header::new(10, vec![1, 2], 4);

        let sites = vec![
            vec![&[0.][..], &[0., 0.][..]],
            vec![&[1.][..], &[1., 1.][..]],
            vec![&[2.][..], &[2., 2.][..]],
            vec![&[3.][..], &[3., 3.][..]],
            vec![&[4.][..], &[4., 4.][..]],
            vec![&[5.][..], &[5., 5.][..]],
            vec![&[6.][..], &[6., 6.][..]],
            vec![&[7.][..], &[7., 7.][..]],
            vec![&[8.][..], &[8., 8.][..]],
            vec![&[9.][..], &[9., 9.][..]],
        ];

        #[rustfmt::skip]
        let expected = vec![
            0., 0., 0.,
            4., 4., 4.,
            8., 8., 8.,
            1., 1., 1.,
            5., 5., 5.,
            9., 9., 9.,
            2., 2., 2.,
            6., 6., 6.,
            3., 3., 3.,
            7., 7., 7.,
        ];

        test_shuffled(header, sites, expected.as_slice(), |writer, site| {
            writer.write_disjoint_site(site)
        })
    }
}