utf8_tokio/
lib.rs

1use ::core::future::Future;
2
3use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _};
4use tokio_util::bytes::{Buf as _, BytesMut};
5use tokio_util::codec::{Decoder, Encoder};
6
/// Builds the I/O error used to report a byte sequence that is not valid UTF-8.
///
/// The error has kind [`std::io::ErrorKind::InvalidInput`] and a fixed message.
fn invalid_utf8() -> std::io::Error {
    use std::io::{Error, ErrorKind};

    Error::new(ErrorKind::InvalidInput, "value is not valid UTF8")
}
10
11pub trait AsyncReadUtf8: AsyncRead {
12    #[cfg_attr(
13        feature = "tracing",
14        tracing::instrument(level = "trace", ret, skip_all)
15    )]
16    fn read_char_utf8(&mut self) -> impl Future<Output = std::io::Result<char>>
17    where
18        Self: Unpin,
19    {
20        async move {
21            let b = self.read_u8().await?;
22            let i = if b & 0x80 == 0 {
23                u32::from(b)
24            } else if b & 0b1110_0000 == 0b1100_0000 {
25                let b2 = self.read_u8().await?;
26                if b2 & 0b1100_0000 != 0b1000_0000 {
27                    return Err(invalid_utf8());
28                }
29                u32::from(b & 0b0001_1111) << 6 | u32::from(b2 & 0b0011_1111)
30            } else if b & 0b1111_0000 == 0b1110_0000 {
31                let mut buf = [0; 2];
32                self.read_exact(&mut buf).await?;
33                if buf[0] & 0b1100_0000 != 0b1000_0000 || buf[1] & 0b1100_0000 != 0b1000_0000 {
34                    return Err(invalid_utf8());
35                }
36                u32::from(b & 0b0000_1111) << 12
37                    | u32::from(buf[0] & 0b0011_1111) << 6
38                    | u32::from(buf[1] & 0b0011_1111)
39            } else if b & 0b1111_1000 == 0b1111_0000 {
40                let mut buf = [0; 3];
41                self.read_exact(&mut buf).await?;
42                if buf[0] & 0b1100_0000 != 0b1000_0000
43                    || buf[1] & 0b1100_0000 != 0b1000_0000
44                    || buf[2] & 0b1100_0000 != 0b1000_0000
45                {
46                    return Err(invalid_utf8());
47                }
48                u32::from(b & 0b0000_0111) << 18
49                    | u32::from(buf[0] & 0b0011_1111) << 12
50                    | u32::from(buf[1] & 0b0011_1111) << 6
51                    | u32::from(buf[2] & 0b0011_1111)
52            } else {
53                return Err(invalid_utf8());
54            };
55            i.try_into()
56                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
57        }
58    }
59}
60
61impl<T: AsyncRead> AsyncReadUtf8 for T {}
62
63pub trait AsyncWriteUtf8: AsyncWrite {
64    #[cfg_attr(
65        feature = "tracing",
66        tracing::instrument(level = "trace", ret, skip_all)
67    )]
68    fn write_char_utf8(&mut self, x: char) -> impl Future<Output = std::io::Result<()>>
69    where
70        Self: Unpin,
71    {
72        async move { self.write_all(x.encode_utf8(&mut [0; 4]).as_bytes()).await }
73    }
74}
75
76impl<T: AsyncWrite> AsyncWriteUtf8 for T {}
77
/// Stateless codec that decodes a byte buffer into individual [`char`]s and
/// encodes [`char`]s as UTF-8 bytes.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Utf8Codec;
80
81impl Decoder for Utf8Codec {
82    type Item = char;
83    type Error = std::io::Error;
84
85    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
86        let Some(b) = src.first().copied() else {
87            src.reserve(1);
88            return Ok(None);
89        };
90        let i = if b & 0x80 == 0 {
91            src.advance(1);
92            u32::from(b)
93        } else if b & 0b1110_0000 == 0b1100_0000 {
94            let Some(b2) = src.get(1).copied() else {
95                src.reserve(1);
96                return Ok(None);
97            };
98            if b2 & 0b1100_0000 != 0b1000_0000 {
99                return Err(invalid_utf8());
100            }
101            src.advance(2);
102            u32::from(b & 0b0001_1111) << 6 | u32::from(b2 & 0b0011_1111)
103        } else if b & 0b1111_0000 == 0b1110_0000 {
104            let Some(b2) = src.get(1).copied() else {
105                src.reserve(2);
106                return Ok(None);
107            };
108            let Some(b3) = src.get(2).copied() else {
109                src.reserve(1);
110                return Ok(None);
111            };
112            if b2 & 0b1100_0000 != 0b1000_0000 || b3 & 0b1100_0000 != 0b1000_0000 {
113                return Err(invalid_utf8());
114            }
115            src.advance(3);
116            u32::from(b & 0b0000_1111) << 12
117                | u32::from(b2 & 0b0011_1111) << 6
118                | u32::from(b3 & 0b0011_1111)
119        } else if b & 0b1111_1000 == 0b1111_0000 {
120            let Some(b2) = src.get(1).copied() else {
121                src.reserve(3);
122                return Ok(None);
123            };
124            let Some(b3) = src.get(2).copied() else {
125                src.reserve(2);
126                return Ok(None);
127            };
128            let Some(b4) = src.get(3).copied() else {
129                src.reserve(1);
130                return Ok(None);
131            };
132            if b2 & 0b1100_0000 != 0b1000_0000
133                || b3 & 0b1100_0000 != 0b1000_0000
134                || b4 & 0b1100_0000 != 0b1000_0000
135            {
136                return Err(invalid_utf8());
137            }
138            src.advance(4);
139            u32::from(b & 0b0000_0111) << 18
140                | u32::from(b2 & 0b0011_1111) << 12
141                | u32::from(b3 & 0b0011_1111) << 6
142                | u32::from(b4 & 0b0011_1111)
143        } else {
144            return Err(invalid_utf8());
145        };
146        let c = i
147            .try_into()
148            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
149        Ok(Some(c))
150    }
151}
152
153impl Encoder<char> for Utf8Codec {
154    type Error = std::io::Error;
155
156    fn encode(&mut self, x: char, dst: &mut BytesMut) -> Result<(), Self::Error> {
157        dst.extend_from_slice(x.encode_utf8(&mut [0; 4]).as_bytes());
158        Ok(())
159    }
160}
161
162impl Encoder<&char> for Utf8Codec {
163    type Error = std::io::Error;
164
165    fn encode(&mut self, x: &char, dst: &mut BytesMut) -> Result<(), Self::Error> {
166        self.encode(*x, dst)
167    }
168}
169
170impl Encoder<&&char> for Utf8Codec {
171    type Error = std::io::Error;
172
173    fn encode(&mut self, x: &&char, dst: &mut BytesMut) -> Result<(), Self::Error> {
174        self.encode(**x, dst)
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes characters of every UTF-8 length (1 to 4 bytes) with the
    /// standard library and checks that `read_char_utf8` decodes them back.
    #[test_log::test(tokio::test)]
    async fn codec() {
        for expected in ['$', '@', 'И', 'ह', '€', '한', '𐍈'] {
            let mut scratch = [0; 4];
            let mut encoded = expected.encode_utf8(&mut scratch).as_bytes();
            let v = match encoded.read_char_utf8().await {
                Ok(v) => v,
                Err(err) => panic!("failed to read `{expected}`: {err:?}"),
            };
            assert_eq!(v, expected);
        }
    }
}