winsfs_core/io/
adaptors.rs

1use std::io;
2
3use super::{ReadSite, ReadStatus, Rewind};
4
/// A reader adaptor that keeps track of how many sites have been read from the underlying source.
///
/// Constructed using [`ReadSite::enumerate`].
#[derive(Debug)]
pub struct Enumerate<R> {
    /// The wrapped reader that sites are read from.
    inner: R,
    /// Running count of sites, maintained by this adaptor's `ReadSite` implementation.
    sites_read: usize,
}
12
13impl<R> Enumerate<R>
14where
15    R: ReadSite,
16{
17    pub(super) fn new(inner: R) -> Self {
18        Self {
19            inner,
20            sites_read: 0,
21        }
22    }
23
24    /// Returns the number of sites read from the underlying source.
25    pub fn sites_read(&self) -> usize {
26        self.sites_read
27    }
28
29    /// Returns a reader adaptor which limits the number of sites read.
30    ///
31    /// See also [`ReadSite::take`].
32    pub fn take(self, max_sites: usize) -> Take<Self> {
33        Take::new(self, max_sites)
34    }
35}
36
37impl<R> Rewind for Enumerate<R>
38where
39    R: Rewind,
40{
41    fn is_done(&mut self) -> io::Result<bool> {
42        self.inner.is_done()
43    }
44
45    fn rewind(&mut self) -> io::Result<()> {
46        self.inner.rewind()
47    }
48}
49
50impl<R> ReadSite for Enumerate<R>
51where
52    R: ReadSite,
53{
54    type Site = R::Site;
55
56    fn read_site(&mut self, buf: &mut Self::Site) -> io::Result<ReadStatus> {
57        self.sites_read += 1;
58        self.inner.read_site(buf)
59    }
60
61    fn read_site_unnormalised(&mut self, buf: &mut Self::Site) -> io::Result<ReadStatus> {
62        self.sites_read += 1;
63        self.inner.read_site_unnormalised(buf)
64    }
65}
66
/// A reader adaptor that limits the number of sites that can be read from the underlying source.
///
/// Constructed using [`ReadSite::take`] or [`Enumerate::take`].
#[derive(Debug)]
pub struct Take<R> {
    /// The wrapped reader.
    inner: R,
    /// The maximum number of sites this adaptor will let through.
    max_sites: usize,
}
74
75impl<R> Take<Enumerate<R>>
76where
77    R: ReadSite,
78{
79    pub(super) fn new(inner: Enumerate<R>, max_sites: usize) -> Self {
80        Self { inner, max_sites }
81    }
82
83    /// Returns the maximum number of sites that can be read from the underlying source.
84    pub fn max_sites(&self) -> usize {
85        self.max_sites
86    }
87
88    /// Returns the number of sites read from the underlying source.
89    pub fn sites_read(&self) -> usize {
90        self.inner.sites_read()
91    }
92}
93
94impl<R> Rewind for Take<Enumerate<R>>
95where
96    R: Rewind,
97{
98    fn is_done(&mut self) -> io::Result<bool> {
99        self.inner.is_done()
100    }
101
102    fn rewind(&mut self) -> io::Result<()> {
103        self.inner.rewind()
104    }
105}
106
107impl<R> ReadSite for Take<Enumerate<R>>
108where
109    R: ReadSite,
110{
111    type Site = R::Site;
112
113    fn read_site(&mut self, buf: &mut Self::Site) -> io::Result<ReadStatus> {
114        if self.inner.sites_read() < self.max_sites {
115            self.inner.read_site(buf)
116        } else {
117            Ok(ReadStatus::Done)
118        }
119    }
120
121    fn read_site_unnormalised(&mut self, buf: &mut Self::Site) -> io::Result<ReadStatus> {
122        if self.inner.sites_read() < self.max_sites {
123            self.inner.read_site_unnormalised(buf)
124        } else {
125            Ok(ReadStatus::Done)
126        }
127    }
128}