1#![warn(
4 clippy::all,
5 clippy::restriction,
6 clippy::pedantic,
7 clippy::nursery,
8 clippy::cargo
9)]
10
11use crate::TapirError::*;
12use std::fmt::{Debug, Display, Formatter};
13use std::io::{self, Write};
14use std::iter;
15
/// Locates the `]` that closes the `[` at `index` in `program`.
///
/// `lma` / `lmv` are a one-entry memo of the previous lookup: the index that
/// was asked about and the answer that was produced. The function returns
/// `(answer, new_lma, new_lmv)` so the caller can store the updated memo.
///
/// The answer is the index of the matching `]`. It is `0` if the bracket is
/// unmatched; when `program[index]` is `[` a real answer is always greater
/// than `index`, so `0` is unambiguous. If the byte at `index` is not `[`,
/// `index` itself is returned.
fn match_right(index: usize, program: &[u8], lma: usize, lmv: usize) -> (usize, usize, usize) {
    // Memo hit: same bracket as last time.
    if index == lma {
        return (lmv, lma, lmv);
    }

    let mut nesting = 0;
    let closer = program
        .iter()
        .enumerate()
        .skip(index)
        .find_map(|(pos, &byte)| {
            match byte {
                b'[' => nesting += 1,
                b']' => nesting -= 1,
                _ => {}
            }
            if nesting == 0 {
                Some(pos)
            } else {
                None
            }
        });

    match closer {
        Some(pos) => (pos, index, pos),
        // Nothing was scanned at all (`index` is past the end): report `index`.
        None if nesting == 0 => (index, index, index),
        // Ran off the end while still inside the bracket: unmatched.
        None => (0, index, lmv),
    }
}
41
/// Locates the `[` that opens the `]` at `index` in `program`.
///
/// `lma` / `lmv` are a one-entry memo of the previous lookup: the index that
/// was asked about and the answer that was produced. The function returns
/// `(answer, new_lma, new_lmv)` so the caller can store the updated memo.
///
/// The answer is `Some(position of the matching '[')`, or `None` when the
/// bracket is unmatched. If the byte at `index` is not `]`, `Some(index)` is
/// returned.
///
/// # Panics
///
/// Panics if `index >= program.len()` (and the memo does not hit).
fn match_left(
    index: usize,
    program: &[u8],
    lma: usize,
    lmv: usize,
) -> (Option<usize>, usize, usize) {
    // Memo hit: same bracket as last time.
    if lma == index {
        return (Some(lmv), lma, lmv);
    }

    // Walk leftwards over `index, index-1, ..., 0`. A bounded reversed range
    // never steps below 0, so an unmatched `]` yields `None` instead of a
    // `usize` subtraction overflow (a panic in debug builds).
    let mut depth: i32 = 0;
    for pos in (0..=index).rev() {
        match program[pos] {
            b'[' => depth += 1,
            b']' => depth -= 1,
            _ => {}
        }
        if depth == 0 {
            return (Some(pos), index, pos);
        }
    }

    (None, index, lmv)
}
74
/// Decodes one line written in the "enhanced" input syntax.
///
/// Recognised sequences:
/// * `\n`, `\r`, `\t`, `\\`, `\#` — the corresponding single byte;
/// * `#ddd` — exactly three decimal digits giving a byte value `000`..=`255`.
///
/// Every other byte is copied through unchanged. An empty input yields an
/// empty output.
///
/// # Errors
///
/// Returns a human-readable message when:
/// * a `\` is the last byte, or is followed by an unknown escape code;
/// * a `#` is followed by fewer than three bytes, or by a non-digit;
/// * a `#ddd` literal is larger than 255.
fn enhanced(s: &[u8]) -> Result<Vec<u8>, String> {
    let mut bytes: Vec<u8> = Vec::with_capacity(s.len());

    let mut i = 0;
    while i < s.len() {
        match s[i] {
            b'\\' => {
                let code = match s.get(i + 1) {
                    Some(&code) => code,
                    None => return Err("expected an escape code but found nothing".to_string()),
                };
                bytes.push(match code {
                    b'n' => b'\n',
                    b'r' => b'\r',
                    b't' => b'\t',
                    b'\\' => b'\\',
                    b'#' => b'#',
                    other => {
                        return Err(format!(
                            "expected an escape code but found '{}'",
                            other as char
                        ));
                    }
                });
                // Skip the backslash and its escape code.
                i += 2;
            }
            b'#' => {
                let digits = match s.get(i + 1..i + 4) {
                    Some(digits) => digits,
                    None => {
                        return Err(
                            "expected a byte literal but didn't find enough digits".to_string()
                        );
                    }
                };
                let mut value: u8 = 0;
                for &digit in digits {
                    if !digit.is_ascii_digit() {
                        return Err(format!("expected a digit but found '{}'", digit as char));
                    }
                    // Checked arithmetic: `#256`..`#999` do not fit in a byte and
                    // must be rejected rather than overflow the `u8` accumulator.
                    value = value
                        .checked_mul(10)
                        .and_then(|v| v.checked_add(digit - b'0'))
                        .ok_or_else(|| {
                            format!(
                                "byte literal '#{}' is out of range (maximum is #255)",
                                String::from_utf8_lossy(digits)
                            )
                        })?;
                }
                bytes.push(value);
                // Skip the `#` and its three digits.
                i += 4;
            }
            other => {
                bytes.push(other);
                i += 1;
            }
        }
    }

    Ok(bytes)
}
133
134pub fn enhanced_input<I, E>(input: I) -> impl Iterator<Item = Result<u8, EnhancedInputError<E>>>
143where
144 I: IntoIterator<Item = Result<u8, E>>,
145{
146 let mut input = input.into_iter();
147
148 let mut line: Vec<u8> = Vec::new();
149
150 iter::from_fn(move || match input.next() {
151 Some(Ok(b'\n')) => {
152 let r = Some(Some(Ok(line.clone())));
153 line.clear();
154 r
155 }
156 Some(Ok(c)) => {
157 line.push(c);
158 Some(None)
159 }
160 Some(Err(e)) => Some(Some(Err(EnhancedInputError::InputException(e)))),
161 None => None,
162 })
163 .filter_map(|x| match x {
164 None => None,
165 Some(Ok(line)) => match enhanced(&line) {
166 Ok(u8s) => Some(Ok(u8s)),
167 Err(e) => Some(Err(EnhancedInputError::MalformedInput(e))),
168 },
169 Some(Err(e)) => Some(Err(e)),
170 })
171 .flat_map(|result| match result {
172 Ok(vec) => vec.into_iter().map(Ok).collect::<Vec<_>>().into_iter(),
173 Err(e) => std::iter::once(Err(e)).collect::<Vec<_>>().into_iter(),
174 })
175}
176
/// Errors produced by the iterator returned from [`enhanced_input`].
#[derive(Debug)]
pub enum EnhancedInputError<E> {
    /// A line could not be decoded; holds a description of the problem.
    MalformedInput(String),
    /// The underlying input iterator yielded this error.
    InputException(E),
}

impl<E: Display> Display for EnhancedInputError<E> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MalformedInput(s) => write!(f, "Malformed input: {}", s),
            // Input errors are shown exactly as the underlying error formats them.
            Self::InputException(e) => write!(f, "{}", e),
        }
    }
}

/// Lets the error be boxed as `Box<dyn Error>` and propagated with `?`.
impl<E: Debug + Display> std::error::Error for EnhancedInputError<E> {}
193
/// Errors that can abort [`interpret`].
///
/// `interpret` returns each error paired with the index of the instruction
/// in the program that raised it.
#[derive(Debug)]
pub enum TapirError<E> {
    /// A `[` or `]` has no matching bracket.
    BracketError,
    /// A `<` tried to move the memory pointer left of cell 0.
    MemPtrUnderflowError,
    /// The input iterator yielded this error while executing `,`.
    InputExceptionError(E),
    /// A `,` found the input exhausted and EOF retrying was disabled.
    MissingInputError,
    /// An I/O error reported by the output writer.
    OutputError(io::Error),
}

impl<E: Display> Display for TapirError<E> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InputExceptionError(e) => write!(f, "InputExceptionError: {}", e),
            Self::OutputError(e) => write!(f, "OutputError: {}", e),
            Self::BracketError => write!(f, "BracketError"),
            Self::MemPtrUnderflowError => write!(f, "MemPtrUnderflowError"),
            Self::MissingInputError => write!(f, "MissingInputError"),
        }
    }
}

impl<E: Debug + Display> std::error::Error for TapirError<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::OutputError(e) => Some(e),
            // `E` carries no `Error + 'static` bound, so input errors cannot be
            // exposed as a source; their text is included by `Display` instead.
            Self::BracketError
            | Self::MemPtrUnderflowError
            | Self::InputExceptionError(_)
            | Self::MissingInputError => None,
        }
    }
}

/// Maps each error kind to a distinct integer code (2–6).
///
/// NOTE(review): presumably used as the process exit status by the binary —
/// confirm before renumbering. Implemented as `From` (rather than `Into`) so
/// both `i32::from(err)` and `let code: i32 = err.into()` work.
impl<E> From<TapirError<E>> for i32 {
    fn from(err: TapirError<E>) -> Self {
        match err {
            TapirError::BracketError => 2,
            TapirError::MemPtrUnderflowError => 3,
            TapirError::InputExceptionError(_) => 4,
            TapirError::MissingInputError => 5,
            TapirError::OutputError(_) => 6,
        }
    }
}
228
229#[inline]
243pub fn interpret<I, E, O>(
244 mem: &mut Vec<u8>,
245 program: &[u8],
246 input: I,
247 mut output: O,
248 eof_retry: bool,
249) -> Result<(), (TapirError<E>, usize)>
250where
251 I: IntoIterator<Item = Result<u8, E>>,
252 O: Write,
253{
254 let mut mem_ptr: usize = 0;
255 let mut ins_ptr: usize = 0;
256
257 let mut last_match_left_arg = 0;
258 let mut last_match_left_val = 0;
259 let mut last_match_right_arg = usize::MAX;
260 let mut last_match_right_val = 0;
261
262 if mem.is_empty() {
263 mem.push(0);
264 } let mut input = input.into_iter();
267
268 while ins_ptr < program.len() {
269 match program[ins_ptr] as char {
270 '>' => {
271 mem_ptr += 1;
272 if mem_ptr == mem.len() {
273 mem.push(0);
274 mem.resize(mem.capacity(), 0);
275 } else {
276 debug_assert!(mem_ptr < mem.len(), "if we somehow increment by > 1");
277 }
278 }
279 '<' => {
280 let new_mem_ptr = mem_ptr - 1;
281 if new_mem_ptr > mem_ptr {
282 return Err((MemPtrUnderflowError, ins_ptr));
283 }
284 mem_ptr = new_mem_ptr;
285 }
286 '+' => mem[mem_ptr] += 1,
287 '-' => mem[mem_ptr] -= 1,
288 '.' => {
289 print!("{}", mem[mem_ptr] as char);
290 }
291 ',' => {
292 if let Err(e) = output.flush() {
293 return Err((OutputError(e), ins_ptr));
294 }
295
296 let incoming_byte: u8 = match input.next() {
300 Some(Err(e)) => return Err((InputExceptionError(e), ins_ptr)),
301 Some(Ok(val)) => val,
302 None => {
303 if eof_retry {
304 match (|| loop {
305 match input.next() {
306 Some(Ok(val)) => return Ok(val),
307 Some(Err(e)) => return Err(InputExceptionError(e)),
308 None => (),
309 }
310 })() {
311 Ok(val) => val,
312 Err(e) => return Err((e, ins_ptr)),
313 }
314 } else {
315 return Err((MissingInputError, ins_ptr));
316 }
317 }
318 };
319
320 mem[mem_ptr] = incoming_byte;
321 }
322 '[' => {
323 if mem[mem_ptr] == 0 {
324 match match_right(ins_ptr, program, last_match_right_arg, last_match_right_val)
325 {
326 (0, _, _) => return Err((BracketError, ins_ptr)),
327 (n, lma, lmv) => {
328 ins_ptr = n;
329 last_match_right_arg = lma;
330 last_match_right_val = lmv;
331 }
332 }
333 }
334 }
335 ']' => {
336 if mem[mem_ptr] != 0 {
337 match match_left(ins_ptr, program, last_match_left_arg, last_match_left_val) {
338 (None, _, _) => return Err((BracketError, ins_ptr)),
339 (Some(n), lma, lmv) => {
340 ins_ptr = n;
341 last_match_left_arg = lma;
342 last_match_left_val = lmv;
343 }
344 }
345 }
346 }
347 _ => {}
348 }
349 ins_ptr += 1;
350 }
351
352 if let Err(e) = output.flush() {
353 return Err((OutputError(e), ins_ptr));
354 }
355
356 Ok(())
357}