1use ::core::future::Future;
2
3use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _};
4use tokio_util::bytes::{Buf as _, BytesMut};
5use tokio_util::codec::{Decoder, Encoder};
6
/// Builds the I/O error reported whenever the input bytes do not form valid UTF-8.
///
/// The error has kind [`std::io::ErrorKind::InvalidInput`] and a fixed message.
fn invalid_utf8() -> std::io::Error {
    use std::io::{Error, ErrorKind};

    Error::new(ErrorKind::InvalidInput, "value is not valid UTF8")
}
10
11pub trait AsyncReadUtf8: AsyncRead {
12 #[cfg_attr(
13 feature = "tracing",
14 tracing::instrument(level = "trace", ret, skip_all)
15 )]
16 fn read_char_utf8(&mut self) -> impl Future<Output = std::io::Result<char>>
17 where
18 Self: Unpin,
19 {
20 async move {
21 let b = self.read_u8().await?;
22 let i = if b & 0x80 == 0 {
23 u32::from(b)
24 } else if b & 0b1110_0000 == 0b1100_0000 {
25 let b2 = self.read_u8().await?;
26 if b2 & 0b1100_0000 != 0b1000_0000 {
27 return Err(invalid_utf8());
28 }
29 u32::from(b & 0b0001_1111) << 6 | u32::from(b2 & 0b0011_1111)
30 } else if b & 0b1111_0000 == 0b1110_0000 {
31 let mut buf = [0; 2];
32 self.read_exact(&mut buf).await?;
33 if buf[0] & 0b1100_0000 != 0b1000_0000 || buf[1] & 0b1100_0000 != 0b1000_0000 {
34 return Err(invalid_utf8());
35 }
36 u32::from(b & 0b0000_1111) << 12
37 | u32::from(buf[0] & 0b0011_1111) << 6
38 | u32::from(buf[1] & 0b0011_1111)
39 } else if b & 0b1111_1000 == 0b1111_0000 {
40 let mut buf = [0; 3];
41 self.read_exact(&mut buf).await?;
42 if buf[0] & 0b1100_0000 != 0b1000_0000
43 || buf[1] & 0b1100_0000 != 0b1000_0000
44 || buf[2] & 0b1100_0000 != 0b1000_0000
45 {
46 return Err(invalid_utf8());
47 }
48 u32::from(b & 0b0000_0111) << 18
49 | u32::from(buf[0] & 0b0011_1111) << 12
50 | u32::from(buf[1] & 0b0011_1111) << 6
51 | u32::from(buf[2] & 0b0011_1111)
52 } else {
53 return Err(invalid_utf8());
54 };
55 i.try_into()
56 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
57 }
58 }
59}
60
61impl<T: AsyncRead> AsyncReadUtf8 for T {}
62
63pub trait AsyncWriteUtf8: AsyncWrite {
64 #[cfg_attr(
65 feature = "tracing",
66 tracing::instrument(level = "trace", ret, skip_all)
67 )]
68 fn write_char_utf8(&mut self, x: char) -> impl Future<Output = std::io::Result<()>>
69 where
70 Self: Unpin,
71 {
72 async move { self.write_all(x.encode_utf8(&mut [0; 4]).as_bytes()).await }
73 }
74}
75
76impl<T: AsyncWrite> AsyncWriteUtf8 for T {}
77
/// Stateless codec that frames a byte stream as individual UTF-8 encoded `char`s.
///
/// Implements `Decoder` (yielding `char`) and `Encoder` for `char`, `&char`
/// and `&&char`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Utf8Codec;
80
81impl Decoder for Utf8Codec {
82 type Item = char;
83 type Error = std::io::Error;
84
85 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
86 let Some(b) = src.first().copied() else {
87 src.reserve(1);
88 return Ok(None);
89 };
90 let i = if b & 0x80 == 0 {
91 src.advance(1);
92 u32::from(b)
93 } else if b & 0b1110_0000 == 0b1100_0000 {
94 let Some(b2) = src.get(1).copied() else {
95 src.reserve(1);
96 return Ok(None);
97 };
98 if b2 & 0b1100_0000 != 0b1000_0000 {
99 return Err(invalid_utf8());
100 }
101 src.advance(2);
102 u32::from(b & 0b0001_1111) << 6 | u32::from(b2 & 0b0011_1111)
103 } else if b & 0b1111_0000 == 0b1110_0000 {
104 let Some(b2) = src.get(1).copied() else {
105 src.reserve(2);
106 return Ok(None);
107 };
108 let Some(b3) = src.get(2).copied() else {
109 src.reserve(1);
110 return Ok(None);
111 };
112 if b2 & 0b1100_0000 != 0b1000_0000 || b3 & 0b1100_0000 != 0b1000_0000 {
113 return Err(invalid_utf8());
114 }
115 src.advance(3);
116 u32::from(b & 0b0000_1111) << 12
117 | u32::from(b2 & 0b0011_1111) << 6
118 | u32::from(b3 & 0b0011_1111)
119 } else if b & 0b1111_1000 == 0b1111_0000 {
120 let Some(b2) = src.get(1).copied() else {
121 src.reserve(3);
122 return Ok(None);
123 };
124 let Some(b3) = src.get(2).copied() else {
125 src.reserve(2);
126 return Ok(None);
127 };
128 let Some(b4) = src.get(3).copied() else {
129 src.reserve(1);
130 return Ok(None);
131 };
132 if b2 & 0b1100_0000 != 0b1000_0000
133 || b3 & 0b1100_0000 != 0b1000_0000
134 || b4 & 0b1100_0000 != 0b1000_0000
135 {
136 return Err(invalid_utf8());
137 }
138 src.advance(4);
139 u32::from(b & 0b0000_0111) << 18
140 | u32::from(b2 & 0b0011_1111) << 12
141 | u32::from(b3 & 0b0011_1111) << 6
142 | u32::from(b4 & 0b0011_1111)
143 } else {
144 return Err(invalid_utf8());
145 };
146 let c = i
147 .try_into()
148 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
149 Ok(Some(c))
150 }
151}
152
153impl Encoder<char> for Utf8Codec {
154 type Error = std::io::Error;
155
156 fn encode(&mut self, x: char, dst: &mut BytesMut) -> Result<(), Self::Error> {
157 dst.extend_from_slice(x.encode_utf8(&mut [0; 4]).as_bytes());
158 Ok(())
159 }
160}
161
162impl Encoder<&char> for Utf8Codec {
163 type Error = std::io::Error;
164
165 fn encode(&mut self, x: &char, dst: &mut BytesMut) -> Result<(), Self::Error> {
166 self.encode(*x, dst)
167 }
168}
169
170impl Encoder<&&char> for Utf8Codec {
171 type Error = std::io::Error;
172
173 fn encode(&mut self, x: &&char, dst: &mut BytesMut) -> Result<(), Self::Error> {
174 self.encode(**x, dst)
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes characters of every UTF-8 width (1 to 4 bytes) with
    /// `char::encode_utf8` and checks that `read_char_utf8` reads them back.
    #[test_log::test(tokio::test)]
    async fn codec() {
        // Each case pairs a character with the exact byte length of its encoding.
        // `encode_utf8` panics if the slice is too short, so the width is checked too.
        const CASES: [(char, usize); 7] = [
            ('$', 1),
            ('@', 1),
            ('И', 2),
            ('ह', 3),
            ('€', 3),
            ('한', 3),
            ('𐍈', 4),
        ];

        for (expected, width) in CASES {
            let mut scratch = [0u8; 4];
            let mut reader: &[u8] = expected.encode_utf8(&mut scratch[..width]).as_bytes();
            let decoded = match reader.read_char_utf8().await {
                Ok(c) => c,
                Err(err) => panic!("failed to read `{expected}`: {err:?}"),
            };
            assert_eq!(decoded, expected);
        }
    }
}