1use bytes::{Buf, BufMut, BytesMut};
16use swimos_api::error::{FrameIoError, InvalidFrame};
17use swimos_form::{read::RecognizerReadable, write::StructuralWritable};
18use swimos_model::Text;
19use swimos_recon::{
20 parser::{AsyncParseError, RecognizerDecoder},
21 write_recon,
22};
23use swimos_utilities::encoding::consume_bounded;
24use tokio_util::codec::{Decoder, Encoder};
25
26mod parser;
27#[cfg(test)]
28mod tests;
29
30pub use parser::{extract_header, extract_header_str};
31
32use crate::{MapMessage, MapOperation};
33
/// Encoder for [`MapOperation`]s whose keys and values are serialized as Recon text.
#[derive(Clone, Copy, Debug, Default)]
pub struct MapOperationEncoder;

/// Encoder for [`MapOperation`]s whose keys and values are already byte buffers.
#[derive(Clone, Copy, Debug, Default)]
pub struct RawMapOperationEncoder;
38
39pub enum MapOperationDecoderState<K, V> {
40 ReadingHeader,
41 ReadingKey {
42 remaining: usize,
43 value_size: Option<usize>,
44 },
45 AfterKey {
46 key: Option<K>,
47 remaining: usize,
48 value_size: Option<usize>,
49 },
50 ReadingValue {
51 key: Option<K>,
52 remaining: usize,
53 },
54 AfterValue {
55 key_value: Option<(K, V)>,
56 remaining: usize,
57 },
58 Discarding {
59 error: Option<AsyncParseError>,
60 remaining: usize,
61 },
62}
63
64pub struct MapOperationDecoder<K: RecognizerReadable, V: RecognizerReadable> {
65 state: MapOperationDecoderState<K, V>,
66 key_recognizer: RecognizerDecoder<K::Rec>,
67 value_recognizer: RecognizerDecoder<V::Rec>,
68}
69
70impl<K: RecognizerReadable, V: RecognizerReadable> Default for MapOperationDecoder<K, V> {
71 fn default() -> Self {
72 Self {
73 state: MapOperationDecoderState::ReadingHeader,
74 key_recognizer: RecognizerDecoder::new(K::make_recognizer()),
75 value_recognizer: RecognizerDecoder::new(V::make_recognizer()),
76 }
77 }
78}
79
/// Decoder producing [`MapOperation`]s whose keys and values are left as raw byte buffers.
#[derive(Clone, Copy, Debug, Default)]
pub struct RawMapOperationDecoder;
82
// Tag byte written immediately after the length prefix of every frame.
const UPDATE: u8 = 0x00;
const REMOVE: u8 = 0x01;
const CLEAR: u8 = 0x02;
const TAKE: u8 = 0x03;
const DROP: u8 = 0x04;

// Width in bytes of a length field (a `u64`) and of the tag (a `u8`).
const LEN_SIZE: usize = (u64::BITS / 8) as usize;
const TAG_SIZE: usize = (u8::BITS / 8) as usize;

// Panic messages for lengths that cannot be represented in a `u64` length field.
const OVERSIZE_KEY: &str = "Key too large.";
const OVERSIZE_RECORD: &str = "Record too large.";

// Prefixes of decode error messages; the offending value is appended.
const BAD_TAG: &str = "Invalid map operation tag: ";
const BAD_RECORD_SIZE: &str = "Invalid record size: ";
const BAD_KEY_SIZE: &str = "Invalid key size: ";
97
98impl<K: AsRef<[u8]>, V: AsRef<[u8]>> Encoder<MapOperation<K, V>> for RawMapOperationEncoder {
99 type Error = std::io::Error;
100
101 fn encode(&mut self, item: MapOperation<K, V>, dst: &mut BytesMut) -> Result<(), Self::Error> {
102 match item {
103 MapOperation::Update { key, value } => {
104 let key_bytes = key.as_ref();
105 let value_bytes = value.as_ref();
106 let total_len = key_bytes.len() + value_bytes.len() + LEN_SIZE + TAG_SIZE;
107 dst.reserve(total_len + LEN_SIZE);
108 dst.put_u64(u64::try_from(total_len).expect(OVERSIZE_RECORD));
109 dst.put_u8(UPDATE);
110 let key_len = u64::try_from(key_bytes.len()).expect(OVERSIZE_KEY);
111 dst.put_u64(key_len);
112 dst.put(key_bytes);
113 dst.put(value_bytes);
114 }
115 MapOperation::Remove { key } => {
116 let key_bytes = key.as_ref();
117 let total_len = key_bytes.len() + TAG_SIZE;
118 dst.reserve(total_len + LEN_SIZE);
119 dst.put_u64(u64::try_from(total_len).expect(OVERSIZE_RECORD));
120 dst.put_u8(REMOVE);
121 dst.put(key_bytes);
122 }
123 MapOperation::Clear => {
124 dst.reserve(LEN_SIZE + TAG_SIZE);
125 dst.put_u64(TAG_SIZE as u64);
126 dst.put_u8(CLEAR);
127 }
128 }
129 Ok(())
130 }
131}
132
133impl Decoder for RawMapOperationDecoder {
134 type Item = MapOperation<BytesMut, BytesMut>;
135
136 type Error = FrameIoError;
137
138 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
139 if src.remaining() < LEN_SIZE + TAG_SIZE {
140 src.reserve(LEN_SIZE + TAG_SIZE);
141 return Ok(None);
142 }
143 let mut header = src.as_ref();
144 let total_len = header.get_u64() as usize;
145 let tag = header.get_u8();
146 match tag {
147 UPDATE => {
148 if total_len < LEN_SIZE + TAG_SIZE {
149 return Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
150 problem: Text::from(format!("{}{}", BAD_RECORD_SIZE, total_len)),
151 }));
152 }
153 let required = LEN_SIZE + total_len;
154 if src.remaining() < required {
155 return Ok(None);
156 }
157 src.advance(LEN_SIZE);
158 let mut frame = src.split_to(total_len);
159 frame.advance(TAG_SIZE);
160 let key_len = frame.get_u64() as usize;
161
162 if key_len + LEN_SIZE + TAG_SIZE > total_len {
163 return Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
164 problem: Text::from(format!("{}{}", BAD_KEY_SIZE, key_len)),
165 }));
166 }
167
168 let key = frame.split_to(key_len);
169
170 Ok(Some(MapOperation::Update { key, value: frame }))
171 }
172 REMOVE => {
173 if total_len < TAG_SIZE {
174 return Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
175 problem: Text::from(format!("{}{}", BAD_RECORD_SIZE, total_len)),
176 }));
177 }
178 let required = LEN_SIZE + total_len;
179 if src.remaining() < required {
180 return Ok(None);
181 }
182 src.advance(LEN_SIZE);
183 let mut frame = src.split_to(total_len);
184 frame.advance(TAG_SIZE);
185
186 Ok(Some(MapOperation::Remove { key: frame }))
187 }
188 CLEAR => {
189 if total_len == TAG_SIZE {
190 src.advance(LEN_SIZE + TAG_SIZE);
191 Ok(Some(MapOperation::Clear))
192 } else {
193 Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
194 problem: Text::from(format!("{}{}", BAD_RECORD_SIZE, total_len)),
195 }))
196 }
197 }
198 ow => Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
199 problem: Text::from(format!("{}{}", BAD_TAG, ow)),
200 })),
201 }
202 }
203}
204
205impl<K: RecognizerReadable, V: RecognizerReadable> Decoder for MapOperationDecoder<K, V> {
206 type Item = MapOperation<K, V>;
207
208 type Error = FrameIoError;
209
210 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
211 let MapOperationDecoder {
212 state,
213 key_recognizer,
214 value_recognizer,
215 } = self;
216 loop {
217 match state {
218 MapOperationDecoderState::ReadingHeader => {
219 if src.remaining() < LEN_SIZE + TAG_SIZE {
220 src.reserve(LEN_SIZE + TAG_SIZE);
221 break Ok(None);
222 }
223 let mut header = src.as_ref();
224 let total_len = header.get_u64() as usize;
225 let tag = header.get_u8();
226 match tag {
227 UPDATE => {
228 let required = TAG_SIZE + 2 * LEN_SIZE;
229 if src.remaining() < required {
230 src.reserve(required - src.remaining());
231 break Ok(None);
232 }
233 let key_len = header.get_u64() as usize;
234 let value_len = if let Some(l) =
235 total_len.checked_sub(key_len + LEN_SIZE + TAG_SIZE)
236 {
237 l
238 } else {
239 break Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
240 problem: Text::from(format!("{}{}", BAD_KEY_SIZE, key_len)),
241 }));
242 };
243
244 src.advance(TAG_SIZE + 2 * LEN_SIZE);
245 *state = MapOperationDecoderState::ReadingKey {
246 remaining: key_len,
247 value_size: Some(value_len),
248 };
249 }
250 REMOVE => {
251 let key_len = if let Some(l) = total_len.checked_sub(TAG_SIZE) {
252 l
253 } else {
254 break Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
255 problem: Text::from(format!(
256 "{}{}",
257 BAD_RECORD_SIZE, total_len
258 )),
259 }));
260 };
261 src.advance(LEN_SIZE + TAG_SIZE);
262 *state = MapOperationDecoderState::ReadingKey {
263 remaining: key_len,
264 value_size: None,
265 };
266 }
267 CLEAR => {
268 src.advance(TAG_SIZE + LEN_SIZE);
269 break Ok(Some(MapOperation::Clear));
270 }
271 ow => {
272 break Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
273 problem: Text::from(format!("Invalid map operation tag: {}", ow)),
274 }));
275 }
276 }
277 }
278 MapOperationDecoderState::ReadingKey {
279 remaining,
280 value_size,
281 } => {
282 let (consumed, decode_result) =
283 consume_bounded(*remaining, src, key_recognizer);
284 *remaining -= consumed;
285 match decode_result {
286 Ok(Some(result)) => {
287 *state = MapOperationDecoderState::AfterKey {
288 key: Some(result),
289 remaining: *remaining,
290 value_size: *value_size,
291 }
292 }
293 Ok(None) => {
294 break Ok(None);
295 }
296 Err(e) => {
297 let rem = src.remaining();
298 if rem >= *remaining {
299 src.advance(*remaining);
300 *state = MapOperationDecoderState::ReadingHeader;
301 break Err(e.into());
302 } else {
303 src.clear();
304 *state = MapOperationDecoderState::Discarding {
305 remaining: *remaining - rem,
306 error: Some(e),
307 }
308 }
309 }
310 }
311 }
312 MapOperationDecoderState::AfterKey {
313 key,
314 remaining,
315 value_size,
316 } => {
317 if src.remaining() >= *remaining {
318 src.advance(*remaining);
319 if let Some(value_size) = value_size.take() {
320 *state = MapOperationDecoderState::ReadingValue {
321 key: key.take(),
322 remaining: value_size,
323 };
324 } else {
325 let op = key.take().map(|key| MapOperation::Remove { key });
326 *state = MapOperationDecoderState::ReadingHeader;
327 break Ok(op);
328 }
329 } else {
330 *remaining -= src.remaining();
331 src.clear();
332 break Ok(None);
333 }
334 }
335 MapOperationDecoderState::ReadingValue { key, remaining } => {
336 let (consumed, decode_result) =
337 consume_bounded(*remaining, src, value_recognizer);
338 *remaining -= consumed;
339 match decode_result {
340 Ok(Some(value)) => {
341 *state = MapOperationDecoderState::AfterValue {
342 key_value: key.take().map(move |k| (k, value)),
343 remaining: *remaining,
344 }
345 }
346 Ok(None) => {
347 break Ok(None);
348 }
349 Err(e) => {
350 let rem = src.remaining();
351 if rem >= *remaining {
352 src.advance(*remaining);
353 *state = MapOperationDecoderState::ReadingHeader;
354 break Err(e.into());
355 } else {
356 src.clear();
357 *state = MapOperationDecoderState::Discarding {
358 remaining: *remaining - rem,
359 error: Some(e),
360 }
361 }
362 }
363 }
364 }
365 MapOperationDecoderState::AfterValue {
366 key_value,
367 remaining,
368 } => {
369 if src.remaining() >= *remaining {
370 src.advance(*remaining);
371 let result = key_value
372 .take()
373 .map(|(key, value)| MapOperation::Update { key, value });
374 *state = MapOperationDecoderState::ReadingHeader;
375 break Ok(result);
376 } else {
377 *remaining -= src.remaining();
378 src.clear();
379 break Ok(None);
380 }
381 }
382 MapOperationDecoderState::Discarding { error, remaining } => {
383 if src.remaining() >= *remaining {
384 src.advance(*remaining);
385 let err = error.take().unwrap_or(AsyncParseError::UnconsumedInput);
386 *state = MapOperationDecoderState::ReadingHeader;
387 break Err(err.into());
388 } else {
389 *remaining -= src.remaining();
390 src.clear();
391 break Ok(None);
392 }
393 }
394 }
395 }
396 }
397}
398
399impl<K: StructuralWritable, V: StructuralWritable> Encoder<MapOperation<K, V>>
400 for MapOperationEncoder
401{
402 type Error = std::io::Error;
403
404 fn encode(&mut self, item: MapOperation<K, V>, dst: &mut BytesMut) -> Result<(), Self::Error> {
405 dst.reserve(TAG_SIZE);
406 match item {
407 MapOperation::Update { key, value } => {
408 dst.reserve(2 * LEN_SIZE + TAG_SIZE);
409 let body_len_offset = dst.remaining();
410 dst.put_u64(0);
411 dst.put_u8(0);
412 dst.put_u64(0);
413 let key_len = write_recon(dst, &key);
414 let value_len = write_recon(dst, &value);
415 let total_len = key_len + value_len + LEN_SIZE + TAG_SIZE;
416 let mut rewound = &mut dst.as_mut()[body_len_offset..];
417 rewound.put_u64(u64::try_from(total_len).expect(OVERSIZE_KEY));
418 rewound.put_u8(UPDATE);
419 rewound.put_u64(u64::try_from(key_len).expect(OVERSIZE_RECORD));
420 }
421 MapOperation::Remove { key } => {
422 dst.reserve(LEN_SIZE + TAG_SIZE);
423 let body_len_offset = dst.remaining();
424 dst.put_u64(0);
425 dst.put_u8(REMOVE);
426 let key_len = write_recon(dst, &key);
427 let total_len = key_len + TAG_SIZE;
428 let mut rewound = &mut dst.as_mut()[body_len_offset..];
429 rewound.put_u64(u64::try_from(total_len).expect(OVERSIZE_KEY));
430 }
431 MapOperation::Clear => {
432 dst.put_u64(TAG_SIZE as u64);
433 dst.put_u8(CLEAR);
434 }
435 }
436 Ok(())
437 }
438}
439
/// Wraps a [`MapOperation`] encoder so that it can also write `Take` and `Drop` messages.
#[derive(Clone, Copy, Debug, Default)]
struct MessageEncoder<Inner>(Inner);
442
443#[derive(Debug, Default, Clone, Copy)]
444struct MessageDecoder<Inner>(Inner);
445
446impl<K, V, Inner> Encoder<MapMessage<K, V>> for MessageEncoder<Inner>
447where
448 Inner: Encoder<MapOperation<K, V>>,
449{
450 type Error = Inner::Error;
451
452 fn encode(&mut self, item: MapMessage<K, V>, dst: &mut BytesMut) -> Result<(), Self::Error> {
453 let MessageEncoder(inner) = self;
454 match item {
455 MapMessage::Update { key, value } => {
456 inner.encode(MapOperation::Update { key, value }, dst)
457 }
458 MapMessage::Remove { key } => inner.encode(MapOperation::Remove { key }, dst),
459 MapMessage::Clear => inner.encode(MapOperation::Clear, dst),
460 MapMessage::Take(n) => {
461 dst.reserve(TAG_SIZE + 2 * LEN_SIZE);
462 dst.put_u64((TAG_SIZE + LEN_SIZE) as u64);
463 dst.put_u8(TAKE);
464 dst.put_u64(n);
465 Ok(())
466 }
467 MapMessage::Drop(n) => {
468 dst.reserve(TAG_SIZE + 2 * LEN_SIZE);
469 dst.put_u64((TAG_SIZE + LEN_SIZE) as u64);
470 dst.put_u8(DROP);
471 dst.put_u64(n);
472 Ok(())
473 }
474 }
475 }
476}
477
478impl<K, V, Inner> Decoder for MessageDecoder<Inner>
479where
480 Inner: Decoder<Item = MapOperation<K, V>, Error = FrameIoError>,
481{
482 type Item = MapMessage<K, V>;
483
484 type Error = FrameIoError;
485
486 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
487 let MessageDecoder(inner) = self;
488 if src.remaining() < TAG_SIZE + LEN_SIZE {
489 src.reserve(TAG_SIZE + LEN_SIZE);
490 return Ok(None);
491 }
492 let mut header = src.as_ref();
493 let total_len = header.get_u64() as usize;
494 match header.get_u8() {
495 tag @ (TAKE | DROP) => {
496 if total_len != TAG_SIZE + LEN_SIZE {
497 return Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
498 problem: Text::new(BAD_RECORD_SIZE),
499 }));
500 }
501 let required = TAG_SIZE + 2 * LEN_SIZE;
502 if src.remaining() < required {
503 src.reserve(required - src.remaining());
504 return Ok(None);
505 }
506 src.advance(TAG_SIZE + LEN_SIZE);
507 let n = src.get_u64();
508 Ok(Some(if tag == TAKE {
509 MapMessage::Take(n)
510 } else {
511 MapMessage::Drop(n)
512 }))
513 }
514 _ => {
515 let result = inner.decode(src)?;
516 Ok(result.map(Into::into))
517 }
518 }
519 }
520}
521
522#[derive(Debug, Default, Clone, Copy)]
523pub struct RawMapMessageEncoder {
524 inner: MessageEncoder<RawMapOperationEncoder>,
525}
526
527impl<K: AsRef<[u8]>, V: AsRef<[u8]>> Encoder<MapMessage<K, V>> for RawMapMessageEncoder {
528 type Error = std::io::Error;
529
530 fn encode(&mut self, item: MapMessage<K, V>, dst: &mut BytesMut) -> Result<(), Self::Error> {
531 self.inner.encode(item, dst)
532 }
533}
534
535#[derive(Debug, Default, Clone, Copy)]
536pub struct RawMapMessageDecoder {
537 inner: MessageDecoder<RawMapOperationDecoder>,
538}
539
540impl Decoder for RawMapMessageDecoder {
541 type Item = MapMessage<BytesMut, BytesMut>;
542
543 type Error = FrameIoError;
544
545 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
546 self.inner.decode(src)
547 }
548}
549
550#[derive(Debug, Default, Clone, Copy)]
551pub struct MapMessageEncoder {
552 inner: MessageEncoder<MapOperationEncoder>,
553}
554
555impl<K: StructuralWritable, V: StructuralWritable> Encoder<MapMessage<K, V>> for MapMessageEncoder {
556 type Error = std::io::Error;
557
558 fn encode(&mut self, item: MapMessage<K, V>, dst: &mut BytesMut) -> Result<(), Self::Error> {
559 self.inner.encode(item, dst)
560 }
561}
562
563pub struct MapMessageDecoder<K: RecognizerReadable, V: RecognizerReadable> {
564 inner: MessageDecoder<MapOperationDecoder<K, V>>,
565}
566
567impl<K: RecognizerReadable, V: RecognizerReadable> Default for MapMessageDecoder<K, V> {
568 fn default() -> Self {
569 Self {
570 inner: Default::default(),
571 }
572 }
573}
574
575impl<K: RecognizerReadable, V: RecognizerReadable> Decoder for MapMessageDecoder<K, V> {
576 type Item = MapMessage<K, V>;
577
578 type Error = FrameIoError;
579
580 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
581 self.inner.decode(src)
582 }
583}