typeline_core/utils/
text_write.rs

1use core::panic;
2use std::io::ErrorKind;
3
4use bstr::ByteSlice;
5
/// A sink that accepts utf-8 text.
///
/// Implementors provide `write_text_unchecked` and `flush_text`; every other
/// method has a default implementation layered on top of those two.
pub trait TextWrite {
    /// Attempts to write `buf` and reports how many bytes were accepted.
    ///
    /// # Safety
    /// Assuming that the write succeeds, the result must be valid utf-8.
    /// If a previous, partial success has split a utf-8 character, the next
    /// write must complete it.
    ///
    /// NOTE: this means that if writes always succeed, `buf` must always be
    /// valid utf-8.
    unsafe fn write_text_unchecked(
        &mut self,
        buf: &[u8],
    ) -> std::io::Result<usize>;

    /// Writes the whole of `buf`, looping over partial writes and retrying
    /// when a write fails with `ErrorKind::Interrupted`.
    ///
    /// # Errors
    /// Returns `ErrorKind::WriteZero` if a write accepts zero bytes while
    /// data is still pending, and passes on any other non-`Interrupted`
    /// error from `write_text_unchecked`.
    ///
    /// # Safety
    /// Assuming that the write succeeds, the result must be valid utf-8.
    /// If a previous, partial success has split a utf-8 character, the next
    /// write must complete it.
    ///
    /// NOTE: this means that if writes always succeed, `buf` must always be
    /// valid utf-8.
    unsafe fn write_all_text_unchecked(
        &mut self,
        buf: &[u8],
    ) -> std::io::Result<()> {
        let mut pending = buf;
        while !pending.is_empty() {
            // SAFETY: `pending` is always the not yet written tail of `buf`,
            // so the caller's contract carries over to every partial write.
            match unsafe { self.write_text_unchecked(pending) } {
                Ok(0) => return Err(ErrorKind::WriteZero.into()),
                Ok(written) => pending = &pending[written..],
                // An interrupted write is simply attempted again.
                Err(err) if err.kind() == ErrorKind::Interrupted => {}
                Err(err) => return Err(err),
            }
        }
        Ok(())
    }

    /// Writes the whole string `buf`.
    fn write_all_text(&mut self, buf: &str) -> std::io::Result<()> {
        let bytes = buf.as_bytes();
        // SAFETY: `buf` is a `&str`, so `bytes` is valid utf-8 on its own.
        unsafe { self.write_all_text_unchecked(bytes) }
    }

    /// Formats `args` and writes the resulting text.
    ///
    /// # Errors
    /// Returns the I/O error reported by `write_all_text` if one occurred.
    /// If formatting fails without an I/O error having been recorded, the
    /// `std::fmt::Error` is wrapped in an `ErrorKind::Other` error.
    fn write_text_fmt(
        &mut self,
        args: std::fmt::Arguments<'_>,
    ) -> std::io::Result<()> {
        // Shim that lets `std::fmt::write` drive this sink. `fmt::Error`
        // carries no payload, so the real I/O error is stashed on the side.
        struct FmtBridge<'a, T: ?Sized + 'a> {
            sink: &'a mut T,
            io_error: Option<std::io::Error>,
        }

        impl<T: TextWrite + ?Sized> std::fmt::Write for FmtBridge<'_, T> {
            fn write_str(&mut self, s: &str) -> std::fmt::Result {
                self.sink.write_all_text(s).map_err(|err| {
                    self.io_error = Some(err);
                    std::fmt::Error
                })
            }
        }

        let mut bridge = FmtBridge {
            sink: self,
            io_error: None,
        };
        match std::fmt::write(&mut bridge, args) {
            Ok(()) => Ok(()),
            // Prefer the stashed I/O error; without one the failure came
            // from the formatting machinery itself.
            Err(fmt_err) => Err(bridge.io_error.unwrap_or_else(|| {
                std::io::Error::new(ErrorKind::Other, fmt_err)
            })),
        }
    }

    /// Flushes the sink. What that entails is up to the implementor.
    fn flush_text(&mut self) -> std::io::Result<()>;
}
84
85impl<W: TextWrite + ?Sized> TextWrite for &mut W {
86    unsafe fn write_text_unchecked(
87        &mut self,
88        buf: &[u8],
89    ) -> std::io::Result<usize> {
90        unsafe { (**self).write_text_unchecked(buf) }
91    }
92
93    fn flush_text(&mut self) -> std::io::Result<()> {
94        (**self).flush_text()
95    }
96
97    unsafe fn write_all_text_unchecked(
98        &mut self,
99        buf: &[u8],
100    ) -> std::io::Result<()> {
101        unsafe { (**self).write_all_text_unchecked(buf) }
102    }
103
104    fn write_all_text(&mut self, buf: &str) -> std::io::Result<()> {
105        (**self).write_all_text(buf)
106    }
107
108    fn write_text_fmt(
109        &mut self,
110        args: std::fmt::Arguments<'_>,
111    ) -> std::io::Result<()> {
112        (**self).write_text_fmt(args)
113    }
114}
115
116#[derive(Default, Clone, derive_more::Deref, derive_more::DerefMut)]
117pub struct TextWriteIoAdapter<W: std::io::Write>(pub W);
118impl<W: std::io::Write> TextWrite for TextWriteIoAdapter<W> {
119    unsafe fn write_text_unchecked(
120        &mut self,
121        buf: &[u8],
122    ) -> std::io::Result<usize> {
123        self.0.write(buf)
124    }
125    unsafe fn write_all_text_unchecked(
126        &mut self,
127        buf: &[u8],
128    ) -> std::io::Result<()> {
129        self.0.write_all(buf)
130    }
131    fn write_all_text(&mut self, buf: &str) -> std::io::Result<()> {
132        unsafe { self.write_all_text_unchecked(buf.as_bytes()) }
133    }
134    fn flush_text(&mut self) -> std::io::Result<()> {
135        self.0.flush()
136    }
137}
138impl<W: std::io::Write> std::io::Write for TextWriteIoAdapter<W> {
139    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
140        self.0.write(buf)
141    }
142
143    fn flush(&mut self) -> std::io::Result<()> {
144        self.0.flush()
145    }
146
147    fn write_vectored(
148        &mut self,
149        bufs: &[std::io::IoSlice<'_>],
150    ) -> std::io::Result<usize> {
151        self.0.write_vectored(bufs)
152    }
153
154    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
155        self.0.write_all(buf)
156    }
157
158    fn write_fmt(
159        &mut self,
160        fmt: std::fmt::Arguments<'_>,
161    ) -> std::io::Result<()> {
162        self.0.write_fmt(fmt)
163    }
164}
165impl<W: std::io::Write> MaybeTextWrite for TextWriteIoAdapter<W> {
166    fn as_text_write(&mut self) -> &mut dyn TextWrite {
167        self
168    }
169    fn as_io_write(&mut self) -> &mut dyn std::io::Write {
170        self
171    }
172    fn deref_dyn(&mut self) -> &mut dyn MaybeTextWrite {
173        self
174    }
175}
176impl<W: std::io::Write> From<W> for TextWriteIoAdapter<W> {
177    fn from(base: W) -> Self {
178        Self(base)
179    }
180}
181
182#[derive(Default, Clone, derive_more::Deref, derive_more::DerefMut)]
183pub struct TextWriteFormatAdapter<W: std::fmt::Write>(pub W);
184
185impl<W: std::fmt::Write> TextWrite for TextWriteFormatAdapter<W> {
186    unsafe fn write_text_unchecked(
187        &mut self,
188        buf: &[u8],
189    ) -> std::io::Result<usize> {
190        // SAFETY: because we never partially succeed, the state after this
191        // call (or any other call in this trait) will always be valid
192        // utf-8. Therefore any given `buf` must also be valid utf-8 by
193        // itself, due to the precondition of this trait method
194        match std::fmt::Write::write_str(&mut self.0, unsafe {
195            std::str::from_utf8_unchecked(buf)
196        }) {
197            Ok(()) => Ok(buf.len()),
198            Err(e) => Err(std::io::Error::new(ErrorKind::Other, e)),
199        }
200    }
201
202    fn flush_text(&mut self) -> std::io::Result<()> {
203        Ok(())
204    }
205}
206impl<W: std::fmt::Write> From<W> for TextWriteFormatAdapter<W> {
207    fn from(base: W) -> Self {
208        Self(base)
209    }
210}
211
212impl<W: std::fmt::Write> std::fmt::Write for TextWriteFormatAdapter<W> {
213    fn write_str(&mut self, s: &str) -> std::fmt::Result {
214        self.0.write_str(s)
215    }
216    fn write_char(&mut self, c: char) -> std::fmt::Result {
217        self.0.write_char(c)
218    }
219    fn write_fmt(
220        &mut self,
221        args: std::fmt::Arguments<'_>,
222    ) -> std::fmt::Result {
223        self.0.write_fmt(args)
224    }
225}
226
/// A writer that can be driven both as a [`TextWrite`] and as a raw
/// [`std::io::Write`].
///
/// The methods hand out trait-object views, so callers can pick an
/// interface at runtime.
pub trait MaybeTextWrite: TextWrite + std::io::Write {
    /// Borrows a `dyn TextWrite` through which text can be written.
    fn as_text_write(&mut self) -> &mut dyn TextWrite;
    /// Borrows a `dyn std::io::Write` through which raw bytes can be
    /// written.
    fn as_io_write(&mut self) -> &mut dyn std::io::Write;
    /// Borrows this writer as a `dyn MaybeTextWrite` trait object.
    fn deref_dyn(&mut self) -> &mut dyn MaybeTextWrite;
}
232
233impl MaybeTextWrite for &mut dyn MaybeTextWrite {
234    fn as_text_write(&mut self) -> &mut dyn TextWrite {
235        (**self).as_text_write()
236    }
237    fn as_io_write(&mut self) -> &mut dyn std::io::Write {
238        (**self).as_io_write()
239    }
240
241    fn deref_dyn(&mut self) -> &mut dyn MaybeTextWrite {
242        *self
243    }
244}
245impl<W: MaybeTextWrite> MaybeTextWrite for &mut W {
246    fn as_text_write(&mut self) -> &mut dyn TextWrite {
247        (**self).as_text_write()
248    }
249    fn as_io_write(&mut self) -> &mut dyn std::io::Write {
250        (**self).as_io_write()
251    }
252
253    fn deref_dyn(&mut self) -> &mut dyn MaybeTextWrite {
254        *self
255    }
256}
257
258#[derive(Default, Clone)]
259pub struct MaybeTextWriteFlaggedAdapter<W> {
260    base: W,
261    is_utf8: bool,
262}
263impl<W: MaybeTextWrite> MaybeTextWriteFlaggedAdapter<W> {
264    pub fn new(base: W) -> Self {
265        Self {
266            base,
267            is_utf8: true,
268        }
269    }
270    pub fn into_inner(self) -> W {
271        self.base
272    }
273    pub fn is_utf8(&self) -> bool {
274        self.is_utf8
275    }
276    pub unsafe fn set_is_utf8(&mut self, is_utf8: bool) {
277        self.is_utf8 = is_utf8;
278    }
279}
280impl<W: MaybeTextWrite> TextWrite for MaybeTextWriteFlaggedAdapter<W> {
281    unsafe fn write_text_unchecked(
282        &mut self,
283        buf: &[u8],
284    ) -> std::io::Result<usize> {
285        unsafe { self.base.write_text_unchecked(buf) }
286    }
287
288    fn flush_text(&mut self) -> std::io::Result<()> {
289        self.base.flush_text()
290    }
291
292    unsafe fn write_all_text_unchecked(
293        &mut self,
294        buf: &[u8],
295    ) -> std::io::Result<()> {
296        unsafe { self.base.write_all_text_unchecked(buf) }
297    }
298
299    fn write_all_text(&mut self, buf: &str) -> std::io::Result<()> {
300        self.base.write_all_text(buf)
301    }
302
303    fn write_text_fmt(
304        &mut self,
305        args: std::fmt::Arguments<'_>,
306    ) -> std::io::Result<()> {
307        self.base.write_text_fmt(args)
308    }
309}
310impl<W: MaybeTextWrite> std::io::Write for MaybeTextWriteFlaggedAdapter<W> {
311    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
312        self.is_utf8 = false;
313        self.base.write(buf)
314    }
315
316    fn flush(&mut self) -> std::io::Result<()> {
317        self.base.flush()
318    }
319
320    fn write_vectored(
321        &mut self,
322        bufs: &[std::io::IoSlice<'_>],
323    ) -> std::io::Result<usize> {
324        self.is_utf8 = false;
325        self.base.write_vectored(bufs)
326    }
327
328    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
329        self.is_utf8 = false;
330        self.base.write_all(buf)
331    }
332
333    fn write_fmt(
334        &mut self,
335        fmt: std::fmt::Arguments<'_>,
336    ) -> std::io::Result<()> {
337        self.is_utf8 = false;
338        self.base.write_fmt(fmt)
339    }
340
341    fn by_ref(&mut self) -> &mut Self
342    where
343        Self: Sized,
344    {
345        self
346    }
347}
348impl<W: MaybeTextWrite> MaybeTextWrite for MaybeTextWriteFlaggedAdapter<W> {
349    fn as_text_write(&mut self) -> &mut dyn TextWrite {
350        &mut self.base // we just pass through anyways
351    }
352    fn as_io_write(&mut self) -> &mut dyn std::io::Write {
353        self
354    }
355    fn deref_dyn(&mut self) -> &mut dyn MaybeTextWrite {
356        self
357    }
358}
359
360#[derive(Clone)]
361pub struct MaybeTextWritePanicAdapter<W: TextWrite>(pub W);
362
363impl<W: TextWrite> std::io::Write for MaybeTextWritePanicAdapter<W> {
364    fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> {
365        panic!("std::io::Write::write called on a MaybeTextWritePanicAdapter")
366    }
367
368    fn flush(&mut self) -> std::io::Result<()> {
369        panic!("std::io::Write::flush called on a MaybeTextWritePanicAdapter")
370    }
371}
372
373impl<W: TextWrite> TextWrite for MaybeTextWritePanicAdapter<W> {
374    unsafe fn write_text_unchecked(
375        &mut self,
376        buf: &[u8],
377    ) -> std::io::Result<usize> {
378        unsafe { self.0.write_text_unchecked(buf) }
379    }
380    fn flush_text(&mut self) -> std::io::Result<()> {
381        self.0.flush_text()
382    }
383    unsafe fn write_all_text_unchecked(
384        &mut self,
385        buf: &[u8],
386    ) -> std::io::Result<()> {
387        unsafe { self.0.write_all_text_unchecked(buf) }
388    }
389    fn write_all_text(&mut self, buf: &str) -> std::io::Result<()> {
390        self.0.write_all_text(buf)
391    }
392    fn write_text_fmt(
393        &mut self,
394        args: std::fmt::Arguments<'_>,
395    ) -> std::io::Result<()> {
396        self.0.write_text_fmt(args)
397    }
398}
399
400impl<W: TextWrite> MaybeTextWrite for MaybeTextWritePanicAdapter<W> {
401    fn as_text_write(&mut self) -> &mut dyn TextWrite {
402        self
403    }
404    fn as_io_write(&mut self) -> &mut dyn std::io::Write {
405        self
406    }
407    fn deref_dyn(&mut self) -> &mut dyn MaybeTextWrite {
408        self
409    }
410}
411
412#[derive(Clone)]
413pub struct MaybeTextWriteLossyAdapter<W: TextWrite>(pub W);
414
415impl<W: TextWrite> std::io::Write for MaybeTextWriteLossyAdapter<W> {
416    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
417        self.0.write_all_text(&buf.to_str_lossy())?;
418        Ok(buf.len())
419    }
420
421    fn flush(&mut self) -> std::io::Result<()> {
422        Ok(())
423    }
424}
425
426impl<W: TextWrite> TextWrite for MaybeTextWriteLossyAdapter<W> {
427    unsafe fn write_text_unchecked(
428        &mut self,
429        buf: &[u8],
430    ) -> std::io::Result<usize> {
431        unsafe { self.0.write_text_unchecked(buf) }
432    }
433    fn flush_text(&mut self) -> std::io::Result<()> {
434        self.0.flush_text()
435    }
436    unsafe fn write_all_text_unchecked(
437        &mut self,
438        buf: &[u8],
439    ) -> std::io::Result<()> {
440        unsafe { self.0.write_all_text_unchecked(buf) }
441    }
442    fn write_all_text(&mut self, buf: &str) -> std::io::Result<()> {
443        self.0.write_all_text(buf)
444    }
445    fn write_text_fmt(
446        &mut self,
447        args: std::fmt::Arguments<'_>,
448    ) -> std::io::Result<()> {
449        self.0.write_text_fmt(args)
450    }
451}
452
453impl TextWrite for String {
454    unsafe fn write_text_unchecked(
455        &mut self,
456        buf: &[u8],
457    ) -> std::io::Result<usize> {
458        unsafe { self.write_all_text_unchecked(buf).unwrap_unchecked() }
459        Ok(buf.len())
460    }
461
462    unsafe fn write_all_text_unchecked(
463        &mut self,
464        buf: &[u8],
465    ) -> std::io::Result<()> {
466        self.push_str(unsafe { std::str::from_utf8_unchecked(buf) });
467        Ok(())
468    }
469
470    fn flush_text(&mut self) -> std::io::Result<()> {
471        Ok(())
472    }
473}
474
/// A write target that stores nothing and instead compares what is written
/// to it against an expected byte sequence.
#[derive(Clone)]
pub struct ByteComparingStream<'a> {
    /// The expected bytes.
    pub source: &'a [u8],
    /// Offset into `source` at which the next comparison starts.
    pub index: usize,
    /// `false` once a mismatch has been recorded.
    pub equal: bool,
}

impl<'a> ByteComparingStream<'a> {
    /// Starts a comparison against `source` at offset 0 with no mismatch
    /// recorded.
    pub fn new(source: &'a [u8]) -> Self {
        ByteComparingStream {
            source,
            index: 0,
            equal: true,
        }
    }

    /// Returns `true` if no mismatch was recorded and the comparison
    /// offset has reached the end of `source`.
    pub fn equal_and_done(&self) -> bool {
        let consumed_everything = self.index == self.source.len();
        self.equal && consumed_everything
    }
}
494
495impl<'a> TextWrite for ByteComparingStream<'a> {
496    unsafe fn write_text_unchecked(
497        &mut self,
498        buf: &[u8],
499    ) -> std::io::Result<usize> {
500        <Self as std::io::Write>::write(self, buf)
501    }
502
503    fn flush_text(&mut self) -> std::io::Result<()> {
504        Ok(())
505    }
506}
507
508impl<'a> std::io::Write for ByteComparingStream<'a> {
509    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
510        let res = Ok(buf.len());
511        if !self.equal {
512            return res;
513        }
514        let end = self.index + buf.len();
515        if end > self.source.len() {
516            self.equal = false;
517            return res;
518        }
519        if &self.source[self.index..end] != buf {
520            self.equal = false;
521        }
522        self.index = end;
523        res
524    }
525
526    fn flush(&mut self) -> std::io::Result<()> {
527        Ok(())
528    }
529}
530
531impl MaybeTextWrite for ByteComparingStream<'_> {
532    fn as_text_write(&mut self) -> &mut dyn TextWrite {
533        self
534    }
535    fn as_io_write(&mut self) -> &mut dyn std::io::Write {
536        self
537    }
538    fn deref_dyn(&mut self) -> &mut dyn MaybeTextWrite {
539        self
540    }
541}