webtrans_proto/
capsule.rs1use std::sync::Arc;
4
5use bytes::{Buf, BufMut, Bytes, BytesMut};
6use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
7
8use crate::grease::is_grease_value;
9use crate::io::read_incremental;
10use crate::{VarInt, VarIntUnexpectedEnd};
11
/// Wire type identifier of the CLOSE_WEBTRANSPORT_SESSION capsule.
const CLOSE_WEBTRANSPORT_SESSION_TYPE: u64 = 0x2843;

/// Size limit, in bytes, that [`Capsule::decode`] enforces on capsule payloads
/// and on the close-session reason string (1 KiB).
const MAX_MESSAGE_SIZE: usize = 1 << 10;
17
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub enum Capsule {
21 CloseWebTransportSession {
23 code: u32,
25 reason: String,
27 },
28 Unknown {
30 typ: VarInt,
32 payload: Bytes,
34 },
35}
36
37impl Capsule {
38 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, CapsuleError> {
40 loop {
41 let typ = VarInt::decode(buf)?;
42 let length = VarInt::decode(buf)?;
43
44 let mut payload = buf.take(length.into_inner() as usize);
45 if payload.remaining() > MAX_MESSAGE_SIZE {
46 return Err(CapsuleError::MessageTooLong);
47 }
48
49 if payload.remaining() < payload.limit() {
50 return Err(CapsuleError::UnexpectedEnd);
51 }
52
53 match typ.into_inner() {
54 CLOSE_WEBTRANSPORT_SESSION_TYPE => {
55 if payload.remaining() < 4 {
56 return Err(CapsuleError::UnexpectedEnd);
57 }
58
59 let error_code = payload.get_u32();
60
61 let message_len = payload.remaining();
62 if message_len > MAX_MESSAGE_SIZE {
63 return Err(CapsuleError::MessageTooLong);
64 }
65
66 let message_bytes = payload.copy_to_bytes(message_len);
67 let error_message = String::from_utf8(message_bytes.to_vec())
68 .map_err(|_| CapsuleError::InvalidUtf8)?;
69
70 return Ok(Self::CloseWebTransportSession {
71 code: error_code,
72 reason: error_message,
73 });
74 }
75 t if is_grease(t) => continue,
76 _ => {
77 let payload_bytes = payload.copy_to_bytes(payload.remaining());
78 return Ok(Self::Unknown {
79 typ,
80 payload: payload_bytes,
81 });
82 }
83 }
84 }
85 }
86
87 pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, CapsuleError> {
89 read_incremental(
90 stream,
91 |cursor| Self::decode(cursor),
92 |err| matches!(err, CapsuleError::UnexpectedEnd),
93 CapsuleError::UnexpectedEnd,
94 )
95 .await
96 }
97
98 pub fn encode<B: BufMut>(&self, buf: &mut B) {
100 match self {
101 Self::CloseWebTransportSession {
102 code: error_code,
103 reason: error_message,
104 } => {
105 VarInt::from_u64(CLOSE_WEBTRANSPORT_SESSION_TYPE)
107 .unwrap()
108 .encode(buf);
109
110 let length = 4 + error_message.len();
112 VarInt::from_u32(length as u32).encode(buf);
113
114 buf.put_u32(*error_code);
116
117 buf.put_slice(error_message.as_bytes());
119 }
120 Self::Unknown { typ, payload } => {
121 typ.encode(buf);
123
124 VarInt::try_from(payload.len()).unwrap().encode(buf);
126
127 buf.put_slice(payload);
129 }
130 }
131 }
132
133 pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), CapsuleError> {
135 let mut buf = BytesMut::new();
136 self.encode(&mut buf);
137 stream.write_all_buf(&mut buf).await?;
138 Ok(())
139 }
140}
141
142fn is_grease(val: u64) -> bool {
143 is_grease_value(val)
144}
145
146#[derive(Debug, Clone, thiserror::Error)]
147pub enum CapsuleError {
149 #[error("unexpected end of buffer")]
150 UnexpectedEnd,
152
153 #[error("invalid UTF-8")]
154 InvalidUtf8,
156
157 #[error("message too long")]
158 MessageTooLong,
160
161 #[error("unknown capsule type: {0:?}")]
162 UnknownType(VarInt),
164
165 #[error("varint decode error: {0:?}")]
166 VarInt(#[from] VarIntUnexpectedEnd),
168
169 #[error("io error: {0}")]
170 Io(Arc<std::io::Error>),
172}
173
174impl From<std::io::Error> for CapsuleError {
175 fn from(err: std::io::Error) -> Self {
176 CapsuleError::Io(Arc::new(err))
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Builds raw capsule bytes: varint `typ`, varint `declared_len`, then `body`.
    ///
    /// `declared_len` is separate from `body` so tests can craft headers that
    /// disagree with the bytes that follow.
    fn raw_capsule(typ: u64, declared_len: u32, body: &[u8]) -> Vec<u8> {
        let mut wire = Vec::new();
        VarInt::from_u64(typ).unwrap().encode(&mut wire);
        VarInt::from_u32(declared_len).encode(&mut wire);
        wire.extend_from_slice(body);
        wire
    }

    /// Encodes `capsule`, decodes it back, and returns the decoded value
    /// together with the number of bytes the decoder left unread.
    fn encode_then_decode(capsule: &Capsule) -> (Capsule, usize) {
        let mut wire = Vec::new();
        capsule.encode(&mut wire);
        let mut cursor = wire.as_slice();
        let decoded = Capsule::decode(&mut cursor).unwrap();
        (decoded, cursor.len())
    }

    #[test]
    fn test_close_webtransport_session_decode() {
        let wire = raw_capsule(0x2843, 8, b"\x00\x00\x01\xa4test");
        let mut cursor = wire.as_slice();

        if let Capsule::CloseWebTransportSession { code, reason } =
            Capsule::decode(&mut cursor).unwrap()
        {
            assert_eq!(code, 420);
            assert_eq!(reason, "test");
        } else {
            panic!("Expected CloseWebTransportSession");
        }

        assert!(cursor.is_empty());
    }

    #[test]
    fn test_close_webtransport_session_encode() {
        let mut wire = Vec::new();
        Capsule::CloseWebTransportSession {
            code: 420,
            reason: "test".to_string(),
        }
        .encode(&mut wire);

        assert_eq!(wire, b"\x68\x43\x08\x00\x00\x01\xa4test");
    }

    #[test]
    fn test_close_webtransport_session_roundtrip() {
        let original = Capsule::CloseWebTransportSession {
            code: 12345,
            reason: "Connection closed by application".to_string(),
        };

        let (decoded, leftover) = encode_then_decode(&original);
        assert_eq!(original, decoded);
        assert_eq!(leftover, 0);
    }

    #[test]
    fn test_empty_error_message() {
        let capsule = Capsule::CloseWebTransportSession {
            code: 0,
            reason: String::new(),
        };

        let mut wire = Vec::new();
        capsule.encode(&mut wire);
        assert_eq!(wire, b"\x68\x43\x04\x00\x00\x00\x00");

        let (decoded, _) = encode_then_decode(&capsule);
        assert_eq!(capsule, decoded);
    }

    #[test]
    fn test_invalid_utf8() {
        // Four-byte zero error code followed by a lone 0xFF, which is never valid UTF-8.
        let wire = raw_capsule(0x2843, 5, b"\x00\x00\x00\x00\xff");
        let mut cursor = wire.as_slice();

        assert!(matches!(
            Capsule::decode(&mut cursor),
            Err(CapsuleError::InvalidUtf8)
        ));
    }

    #[test]
    fn test_truncated_error_code() {
        // Only three bytes of payload: too short for the four-byte error code.
        let wire = raw_capsule(0x2843, 3, b"\x00\x00\x00");
        let mut cursor = wire.as_slice();

        assert!(matches!(
            Capsule::decode(&mut cursor),
            Err(CapsuleError::UnexpectedEnd)
        ));
    }

    #[test]
    fn test_unknown_capsule() {
        let unknown_type = 0x1234u64;
        let body = b"unknown payload";

        let wire = raw_capsule(unknown_type, body.len() as u32, body);
        let mut cursor = wire.as_slice();

        match Capsule::decode(&mut cursor).unwrap() {
            Capsule::Unknown { typ, payload } => {
                assert_eq!(typ.into_inner(), unknown_type);
                assert_eq!(payload.as_ref(), body);
            }
            _ => panic!("Expected Unknown capsule"),
        }
    }

    #[test]
    fn test_unknown_capsule_roundtrip() {
        let capsule = Capsule::Unknown {
            typ: VarInt::from_u64(0x9999).unwrap(),
            payload: Bytes::from("test payload"),
        };

        let (decoded, leftover) = encode_then_decode(&capsule);
        assert_eq!(capsule, decoded);
        assert_eq!(leftover, 0);
    }
}