1use std::io::{BufRead, BufReader, Read};
2
3use memchr::{memchr, memchr3};
4
5use crate::error::ReaderError;
6use crate::event::Event;
7
/// Capacity, in bytes, of the internal read buffer used by `FastaReader::new` (128 KiB).
const DEFAULT_BUFFER_SIZE: usize = 1 << 17;
9
/// Position of the parser within the FASTA stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Before the first record: `\n` / `\r` bytes are skipped and a `>` is expected next.
    Start,
    /// Inside a header line, after its leading `>` and before its terminating `\n`.
    Id,
    /// After a header line: reading sequence data until the next `>` or end of input.
    Sequence,
}
16
17pub struct FastaReader<R> {
19 reader: BufReader<R>,
20 pending_consume: usize,
21 state: State,
22}
23
24impl<R: Read> FastaReader<R> {
25 pub fn new(reader: R) -> Self {
27 Self::with_capacity(DEFAULT_BUFFER_SIZE, reader)
28 }
29
30 pub fn with_capacity(capacity: usize, reader: R) -> Self {
32 Self {
33 reader: BufReader::with_capacity(capacity, reader),
34 pending_consume: 0,
35 state: State::Start,
36 }
37 }
38
39 pub fn next_event(&mut self) -> Option<Result<Event<'_>, ReaderError>> {
41 loop {
42 if self.pending_consume > 0 {
43 self.reader.consume(self.pending_consume);
44 self.pending_consume = 0;
45 }
46
47 let buf = match self.reader.fill_buf() {
48 Ok(b) if b.is_empty() => return None,
49 Ok(b) => b,
50 Err(e) => return Some(Err(e.into())),
51 };
52
53 let buf_ptr = buf.as_ptr();
54 let buf_len = buf.len();
55
56 match self.state {
57 State::Start => {
58 let first_non_ws = buf.iter().position(|&b| b != b'\n' && b != b'\r');
59
60 match first_non_ws {
61 Some(0) => {
62 if buf[0] == b'>' {
63 self.state = State::Id;
64 self.pending_consume = 1;
65 continue; } else {
67 return Some(Err(ReaderError::InvalidFormat {
68 message: format!(
69 "Expected '>' at start of FASTA record, found '{}'",
70 buf[0] as char
71 ),
72 }));
73 }
74 }
75 Some(pos) => {
76 self.pending_consume = pos;
77 continue;
78 }
79 None => {
80 self.pending_consume = buf_len;
81 continue;
82 }
83 }
84 }
85
86 State::Id => {
87 if let Some(newline_pos) = memchr(b'\n', buf) {
88 let end = if newline_pos > 0 && buf[newline_pos - 1] == b'\r' {
89 newline_pos - 1
90 } else {
91 newline_pos
92 };
93
94 self.state = State::Sequence;
95 self.pending_consume = newline_pos + 1;
96
97 if end > 0 {
98 let slice = unsafe { std::slice::from_raw_parts(buf_ptr, end) };
99 return Some(Ok(Event::IdChunk(slice)));
100 } else {
101 continue;
102 }
103 } else {
104 self.pending_consume = buf_len;
105 let slice = unsafe { std::slice::from_raw_parts(buf_ptr, buf_len) };
106 return Some(Ok(Event::IdChunk(slice)));
107 }
108 }
109
110 State::Sequence => {
111 let first_byte = buf[0];
112
113 if first_byte == b'\n' {
114 self.pending_consume = 1;
115 continue;
116 }
117 if first_byte == b'\r' {
118 self.pending_consume = if buf_len > 1 && buf[1] == b'\n' { 2 } else { 1 };
119 continue;
120 }
121 if first_byte == b'>' {
122 self.state = State::Id;
123 self.pending_consume = 1;
124 return Some(Ok(Event::NextRecord));
125 }
126
127 let chunk_end = memchr3(b'\n', b'\r', b'>', buf).unwrap_or(buf_len);
128 if chunk_end == 0 {
129 self.pending_consume = 1;
130 continue;
131 }
132
133 self.pending_consume = chunk_end;
134 let slice = unsafe { std::slice::from_raw_parts(buf_ptr, chunk_end) };
135 return Some(Ok(Event::SeqChunk(slice)));
136 }
137 }
138 }
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Reader type shared by the tests: a `FastaReader` over an in-memory byte slice.
    type SliceReader<'a> = FastaReader<Cursor<&'a [u8]>>;

    /// Opens a reader with the default buffer size over `data`.
    fn open(data: &[u8]) -> SliceReader<'_> {
        FastaReader::new(Cursor::new(data))
    }

    /// Asserts that the next event is an `IdChunk` carrying exactly `expected`.
    fn expect_id(reader: &mut SliceReader<'_>, expected: &[u8]) {
        let event = reader.next_event().unwrap().unwrap();
        assert!(matches!(event, Event::IdChunk(id) if id == expected));
    }

    /// Asserts that the next event is a `SeqChunk` carrying exactly `expected`.
    fn expect_seq(reader: &mut SliceReader<'_>, expected: &[u8]) {
        let event = reader.next_event().unwrap().unwrap();
        assert!(matches!(event, Event::SeqChunk(s) if s == expected));
    }

    /// Asserts that the next event is `NextRecord`.
    fn expect_next_record(reader: &mut SliceReader<'_>) {
        let event = reader.next_event().unwrap().unwrap();
        assert!(matches!(event, Event::NextRecord));
    }

    /// Asserts that the reader has reached the end of its input.
    fn expect_end(reader: &mut SliceReader<'_>) {
        assert!(reader.next_event().is_none());
    }

    #[test]
    fn test_single_record() {
        let mut reader = open(b">seq1 description\nACGT\nTGCA\n");

        expect_id(&mut reader, b"seq1 description");
        expect_seq(&mut reader, b"ACGT");
        expect_seq(&mut reader, b"TGCA");
        expect_end(&mut reader);
    }

    #[test]
    fn test_multiple_records() {
        let mut reader = open(b">seq1\nACGT\n>seq2\nTGCA\n");

        expect_id(&mut reader, b"seq1");
        expect_seq(&mut reader, b"ACGT");
        expect_next_record(&mut reader);
        expect_id(&mut reader, b"seq2");
        expect_seq(&mut reader, b"TGCA");
        expect_end(&mut reader);
    }

    #[test]
    fn test_crlf_line_endings() {
        let mut reader = open(b">seq1\r\nACGT\r\nTGCA\r\n");

        expect_id(&mut reader, b"seq1");
        expect_seq(&mut reader, b"ACGT");
        expect_seq(&mut reader, b"TGCA");
        expect_end(&mut reader);
    }

    #[test]
    fn test_small_buffer() {
        // A 4-byte buffer forces both the header and the sequence to arrive in pieces.
        let data: &[u8] = b">seq1\nACGTACGTACGT\n";
        let mut reader = FastaReader::with_capacity(4, Cursor::new(data));

        let mut id = Vec::new();
        let mut seq = Vec::new();

        while let Some(event) = reader.next_event() {
            match event {
                Ok(Event::IdChunk(chunk)) => id.extend_from_slice(chunk),
                Ok(Event::SeqChunk(chunk)) => seq.extend_from_slice(chunk),
                Ok(_) => panic!("Unexpected event"),
                Err(e) => panic!("Error: {}", e),
            }
        }

        assert_eq!(id.as_slice(), &b"seq1"[..]);
        assert_eq!(seq.as_slice(), &b"ACGTACGTACGT"[..]);
    }
}