swimos_agent_protocol/map/
mod.rs

1// Copyright 2015-2024 Swim Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use bytes::{Buf, BufMut, BytesMut};
16use swimos_api::error::{FrameIoError, InvalidFrame};
17use swimos_form::{read::RecognizerReadable, write::StructuralWritable};
18use swimos_model::Text;
19use swimos_recon::{
20    parser::{AsyncParseError, RecognizerDecoder},
21    write_recon,
22};
23use swimos_utilities::encoding::consume_bounded;
24use tokio_util::codec::{Decoder, Encoder};
25
26mod parser;
27#[cfg(test)]
28mod tests;
29
30pub use parser::{extract_header, extract_header_str};
31
32use crate::{MapMessage, MapOperation};
33
/// Encoder for [`MapOperation`]s whose keys and values are serialized as Recon.
#[derive(Clone, Copy, Debug, Default)]
pub struct MapOperationEncoder;

/// Encoder for [`MapOperation`]s whose keys and values are already raw bytes.
#[derive(Clone, Copy, Debug, Default)]
pub struct RawMapOperationEncoder;
38
39pub enum MapOperationDecoderState<K, V> {
40    ReadingHeader,
41    ReadingKey {
42        remaining: usize,
43        value_size: Option<usize>,
44    },
45    AfterKey {
46        key: Option<K>,
47        remaining: usize,
48        value_size: Option<usize>,
49    },
50    ReadingValue {
51        key: Option<K>,
52        remaining: usize,
53    },
54    AfterValue {
55        key_value: Option<(K, V)>,
56        remaining: usize,
57    },
58    Discarding {
59        error: Option<AsyncParseError>,
60        remaining: usize,
61    },
62}
63
64pub struct MapOperationDecoder<K: RecognizerReadable, V: RecognizerReadable> {
65    state: MapOperationDecoderState<K, V>,
66    key_recognizer: RecognizerDecoder<K::Rec>,
67    value_recognizer: RecognizerDecoder<V::Rec>,
68}
69
70impl<K: RecognizerReadable, V: RecognizerReadable> Default for MapOperationDecoder<K, V> {
71    fn default() -> Self {
72        Self {
73            state: MapOperationDecoderState::ReadingHeader,
74            key_recognizer: RecognizerDecoder::new(K::make_recognizer()),
75            value_recognizer: RecognizerDecoder::new(V::make_recognizer()),
76        }
77    }
78}
79
/// Decoder for [`MapOperation`]s that yields keys and values as raw, undecoded bytes.
///
/// A frame is only decoded once it is fully buffered.
#[derive(Clone, Copy, Debug, Default)]
pub struct RawMapOperationDecoder;
82
// Tags identifying the operation carried by a frame; the tag is the first byte of the body.

/// Tag for an update of a single entry (body: tag, key length, key, value).
const UPDATE: u8 = 0;
/// Tag for the removal of a single entry (body: tag, key).
const REMOVE: u8 = 1;
/// Tag for clearing the map (body: tag only).
const CLEAR: u8 = 2;
/// Tag for a take message (body: tag, count).
const TAKE: u8 = 3;
/// Tag for a drop message (body: tag, count).
const DROP: u8 = 4;

/// Size, in bytes, of a length (or count) field, written as a `u64`.
const LEN_SIZE: usize = std::mem::size_of::<u64>();
/// Size, in bytes, of the operation tag.
const TAG_SIZE: usize = std::mem::size_of::<u8>();

// Panic and error messages.

/// Panic message used when a key length does not fit in a `u64`.
const OVERSIZE_KEY: &str = "Key too large.";
/// Panic message used when a frame length does not fit in a `u64`.
const OVERSIZE_RECORD: &str = "Record too large.";
/// Prefix of the error reported for an unknown tag.
const BAD_TAG: &str = "Invalid map operation tag: ";
/// Prefix of the error reported for an inconsistent frame length.
const BAD_RECORD_SIZE: &str = "Invalid record size: ";
/// Prefix of the error reported for an inconsistent key length.
const BAD_KEY_SIZE: &str = "Invalid key size: ";
97
98impl<K: AsRef<[u8]>, V: AsRef<[u8]>> Encoder<MapOperation<K, V>> for RawMapOperationEncoder {
99    type Error = std::io::Error;
100
101    fn encode(&mut self, item: MapOperation<K, V>, dst: &mut BytesMut) -> Result<(), Self::Error> {
102        match item {
103            MapOperation::Update { key, value } => {
104                let key_bytes = key.as_ref();
105                let value_bytes = value.as_ref();
106                let total_len = key_bytes.len() + value_bytes.len() + LEN_SIZE + TAG_SIZE;
107                dst.reserve(total_len + LEN_SIZE);
108                dst.put_u64(u64::try_from(total_len).expect(OVERSIZE_RECORD));
109                dst.put_u8(UPDATE);
110                let key_len = u64::try_from(key_bytes.len()).expect(OVERSIZE_KEY);
111                dst.put_u64(key_len);
112                dst.put(key_bytes);
113                dst.put(value_bytes);
114            }
115            MapOperation::Remove { key } => {
116                let key_bytes = key.as_ref();
117                let total_len = key_bytes.len() + TAG_SIZE;
118                dst.reserve(total_len + LEN_SIZE);
119                dst.put_u64(u64::try_from(total_len).expect(OVERSIZE_RECORD));
120                dst.put_u8(REMOVE);
121                dst.put(key_bytes);
122            }
123            MapOperation::Clear => {
124                dst.reserve(LEN_SIZE + TAG_SIZE);
125                dst.put_u64(TAG_SIZE as u64);
126                dst.put_u8(CLEAR);
127            }
128        }
129        Ok(())
130    }
131}
132
133impl Decoder for RawMapOperationDecoder {
134    type Item = MapOperation<BytesMut, BytesMut>;
135
136    type Error = FrameIoError;
137
138    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
139        if src.remaining() < LEN_SIZE + TAG_SIZE {
140            src.reserve(LEN_SIZE + TAG_SIZE);
141            return Ok(None);
142        }
143        let mut header = src.as_ref();
144        let total_len = header.get_u64() as usize;
145        let tag = header.get_u8();
146        match tag {
147            UPDATE => {
148                if total_len < LEN_SIZE + TAG_SIZE {
149                    return Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
150                        problem: Text::from(format!("{}{}", BAD_RECORD_SIZE, total_len)),
151                    }));
152                }
153                let required = LEN_SIZE + total_len;
154                if src.remaining() < required {
155                    return Ok(None);
156                }
157                src.advance(LEN_SIZE);
158                let mut frame = src.split_to(total_len);
159                frame.advance(TAG_SIZE);
160                let key_len = frame.get_u64() as usize;
161
162                if key_len + LEN_SIZE + TAG_SIZE > total_len {
163                    return Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
164                        problem: Text::from(format!("{}{}", BAD_KEY_SIZE, key_len)),
165                    }));
166                }
167
168                let key = frame.split_to(key_len);
169
170                Ok(Some(MapOperation::Update { key, value: frame }))
171            }
172            REMOVE => {
173                if total_len < TAG_SIZE {
174                    return Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
175                        problem: Text::from(format!("{}{}", BAD_RECORD_SIZE, total_len)),
176                    }));
177                }
178                let required = LEN_SIZE + total_len;
179                if src.remaining() < required {
180                    return Ok(None);
181                }
182                src.advance(LEN_SIZE);
183                let mut frame = src.split_to(total_len);
184                frame.advance(TAG_SIZE);
185
186                Ok(Some(MapOperation::Remove { key: frame }))
187            }
188            CLEAR => {
189                if total_len == TAG_SIZE {
190                    src.advance(LEN_SIZE + TAG_SIZE);
191                    Ok(Some(MapOperation::Clear))
192                } else {
193                    Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
194                        problem: Text::from(format!("{}{}", BAD_RECORD_SIZE, total_len)),
195                    }))
196                }
197            }
198            ow => Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
199                problem: Text::from(format!("{}{}", BAD_TAG, ow)),
200            })),
201        }
202    }
203}
204
205impl<K: RecognizerReadable, V: RecognizerReadable> Decoder for MapOperationDecoder<K, V> {
206    type Item = MapOperation<K, V>;
207
208    type Error = FrameIoError;
209
210    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
211        let MapOperationDecoder {
212            state,
213            key_recognizer,
214            value_recognizer,
215        } = self;
216        loop {
217            match state {
218                MapOperationDecoderState::ReadingHeader => {
219                    if src.remaining() < LEN_SIZE + TAG_SIZE {
220                        src.reserve(LEN_SIZE + TAG_SIZE);
221                        break Ok(None);
222                    }
223                    let mut header = src.as_ref();
224                    let total_len = header.get_u64() as usize;
225                    let tag = header.get_u8();
226                    match tag {
227                        UPDATE => {
228                            let required = TAG_SIZE + 2 * LEN_SIZE;
229                            if src.remaining() < required {
230                                src.reserve(required - src.remaining());
231                                break Ok(None);
232                            }
233                            let key_len = header.get_u64() as usize;
234                            let value_len = if let Some(l) =
235                                total_len.checked_sub(key_len + LEN_SIZE + TAG_SIZE)
236                            {
237                                l
238                            } else {
239                                break Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
240                                    problem: Text::from(format!("{}{}", BAD_KEY_SIZE, key_len)),
241                                }));
242                            };
243
244                            src.advance(TAG_SIZE + 2 * LEN_SIZE);
245                            *state = MapOperationDecoderState::ReadingKey {
246                                remaining: key_len,
247                                value_size: Some(value_len),
248                            };
249                        }
250                        REMOVE => {
251                            let key_len = if let Some(l) = total_len.checked_sub(TAG_SIZE) {
252                                l
253                            } else {
254                                break Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
255                                    problem: Text::from(format!(
256                                        "{}{}",
257                                        BAD_RECORD_SIZE, total_len
258                                    )),
259                                }));
260                            };
261                            src.advance(LEN_SIZE + TAG_SIZE);
262                            *state = MapOperationDecoderState::ReadingKey {
263                                remaining: key_len,
264                                value_size: None,
265                            };
266                        }
267                        CLEAR => {
268                            src.advance(TAG_SIZE + LEN_SIZE);
269                            break Ok(Some(MapOperation::Clear));
270                        }
271                        ow => {
272                            break Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
273                                problem: Text::from(format!("Invalid map operation tag: {}", ow)),
274                            }));
275                        }
276                    }
277                }
278                MapOperationDecoderState::ReadingKey {
279                    remaining,
280                    value_size,
281                } => {
282                    let (consumed, decode_result) =
283                        consume_bounded(*remaining, src, key_recognizer);
284                    *remaining -= consumed;
285                    match decode_result {
286                        Ok(Some(result)) => {
287                            *state = MapOperationDecoderState::AfterKey {
288                                key: Some(result),
289                                remaining: *remaining,
290                                value_size: *value_size,
291                            }
292                        }
293                        Ok(None) => {
294                            break Ok(None);
295                        }
296                        Err(e) => {
297                            let rem = src.remaining();
298                            if rem >= *remaining {
299                                src.advance(*remaining);
300                                *state = MapOperationDecoderState::ReadingHeader;
301                                break Err(e.into());
302                            } else {
303                                src.clear();
304                                *state = MapOperationDecoderState::Discarding {
305                                    remaining: *remaining - rem,
306                                    error: Some(e),
307                                }
308                            }
309                        }
310                    }
311                }
312                MapOperationDecoderState::AfterKey {
313                    key,
314                    remaining,
315                    value_size,
316                } => {
317                    if src.remaining() >= *remaining {
318                        src.advance(*remaining);
319                        if let Some(value_size) = value_size.take() {
320                            *state = MapOperationDecoderState::ReadingValue {
321                                key: key.take(),
322                                remaining: value_size,
323                            };
324                        } else {
325                            let op = key.take().map(|key| MapOperation::Remove { key });
326                            *state = MapOperationDecoderState::ReadingHeader;
327                            break Ok(op);
328                        }
329                    } else {
330                        *remaining -= src.remaining();
331                        src.clear();
332                        break Ok(None);
333                    }
334                }
335                MapOperationDecoderState::ReadingValue { key, remaining } => {
336                    let (consumed, decode_result) =
337                        consume_bounded(*remaining, src, value_recognizer);
338                    *remaining -= consumed;
339                    match decode_result {
340                        Ok(Some(value)) => {
341                            *state = MapOperationDecoderState::AfterValue {
342                                key_value: key.take().map(move |k| (k, value)),
343                                remaining: *remaining,
344                            }
345                        }
346                        Ok(None) => {
347                            break Ok(None);
348                        }
349                        Err(e) => {
350                            let rem = src.remaining();
351                            if rem >= *remaining {
352                                src.advance(*remaining);
353                                *state = MapOperationDecoderState::ReadingHeader;
354                                break Err(e.into());
355                            } else {
356                                src.clear();
357                                *state = MapOperationDecoderState::Discarding {
358                                    remaining: *remaining - rem,
359                                    error: Some(e),
360                                }
361                            }
362                        }
363                    }
364                }
365                MapOperationDecoderState::AfterValue {
366                    key_value,
367                    remaining,
368                } => {
369                    if src.remaining() >= *remaining {
370                        src.advance(*remaining);
371                        let result = key_value
372                            .take()
373                            .map(|(key, value)| MapOperation::Update { key, value });
374                        *state = MapOperationDecoderState::ReadingHeader;
375                        break Ok(result);
376                    } else {
377                        *remaining -= src.remaining();
378                        src.clear();
379                        break Ok(None);
380                    }
381                }
382                MapOperationDecoderState::Discarding { error, remaining } => {
383                    if src.remaining() >= *remaining {
384                        src.advance(*remaining);
385                        let err = error.take().unwrap_or(AsyncParseError::UnconsumedInput);
386                        *state = MapOperationDecoderState::ReadingHeader;
387                        break Err(err.into());
388                    } else {
389                        *remaining -= src.remaining();
390                        src.clear();
391                        break Ok(None);
392                    }
393                }
394            }
395        }
396    }
397}
398
399impl<K: StructuralWritable, V: StructuralWritable> Encoder<MapOperation<K, V>>
400    for MapOperationEncoder
401{
402    type Error = std::io::Error;
403
404    fn encode(&mut self, item: MapOperation<K, V>, dst: &mut BytesMut) -> Result<(), Self::Error> {
405        dst.reserve(TAG_SIZE);
406        match item {
407            MapOperation::Update { key, value } => {
408                dst.reserve(2 * LEN_SIZE + TAG_SIZE);
409                let body_len_offset = dst.remaining();
410                dst.put_u64(0);
411                dst.put_u8(0);
412                dst.put_u64(0);
413                let key_len = write_recon(dst, &key);
414                let value_len = write_recon(dst, &value);
415                let total_len = key_len + value_len + LEN_SIZE + TAG_SIZE;
416                let mut rewound = &mut dst.as_mut()[body_len_offset..];
417                rewound.put_u64(u64::try_from(total_len).expect(OVERSIZE_KEY));
418                rewound.put_u8(UPDATE);
419                rewound.put_u64(u64::try_from(key_len).expect(OVERSIZE_RECORD));
420            }
421            MapOperation::Remove { key } => {
422                dst.reserve(LEN_SIZE + TAG_SIZE);
423                let body_len_offset = dst.remaining();
424                dst.put_u64(0);
425                dst.put_u8(REMOVE);
426                let key_len = write_recon(dst, &key);
427                let total_len = key_len + TAG_SIZE;
428                let mut rewound = &mut dst.as_mut()[body_len_offset..];
429                rewound.put_u64(u64::try_from(total_len).expect(OVERSIZE_KEY));
430            }
431            MapOperation::Clear => {
432                dst.put_u64(TAG_SIZE as u64);
433                dst.put_u8(CLEAR);
434            }
435        }
436        Ok(())
437    }
438}
439
440#[derive(Debug, Default, Clone, Copy)]
441struct MessageEncoder<Inner>(Inner);
442
443#[derive(Debug, Default, Clone, Copy)]
444struct MessageDecoder<Inner>(Inner);
445
446impl<K, V, Inner> Encoder<MapMessage<K, V>> for MessageEncoder<Inner>
447where
448    Inner: Encoder<MapOperation<K, V>>,
449{
450    type Error = Inner::Error;
451
452    fn encode(&mut self, item: MapMessage<K, V>, dst: &mut BytesMut) -> Result<(), Self::Error> {
453        let MessageEncoder(inner) = self;
454        match item {
455            MapMessage::Update { key, value } => {
456                inner.encode(MapOperation::Update { key, value }, dst)
457            }
458            MapMessage::Remove { key } => inner.encode(MapOperation::Remove { key }, dst),
459            MapMessage::Clear => inner.encode(MapOperation::Clear, dst),
460            MapMessage::Take(n) => {
461                dst.reserve(TAG_SIZE + 2 * LEN_SIZE);
462                dst.put_u64((TAG_SIZE + LEN_SIZE) as u64);
463                dst.put_u8(TAKE);
464                dst.put_u64(n);
465                Ok(())
466            }
467            MapMessage::Drop(n) => {
468                dst.reserve(TAG_SIZE + 2 * LEN_SIZE);
469                dst.put_u64((TAG_SIZE + LEN_SIZE) as u64);
470                dst.put_u8(DROP);
471                dst.put_u64(n);
472                Ok(())
473            }
474        }
475    }
476}
477
478impl<K, V, Inner> Decoder for MessageDecoder<Inner>
479where
480    Inner: Decoder<Item = MapOperation<K, V>, Error = FrameIoError>,
481{
482    type Item = MapMessage<K, V>;
483
484    type Error = FrameIoError;
485
486    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
487        let MessageDecoder(inner) = self;
488        if src.remaining() < TAG_SIZE + LEN_SIZE {
489            src.reserve(TAG_SIZE + LEN_SIZE);
490            return Ok(None);
491        }
492        let mut header = src.as_ref();
493        let total_len = header.get_u64() as usize;
494        match header.get_u8() {
495            tag @ (TAKE | DROP) => {
496                if total_len != TAG_SIZE + LEN_SIZE {
497                    return Err(FrameIoError::BadFrame(InvalidFrame::InvalidHeader {
498                        problem: Text::new(BAD_RECORD_SIZE),
499                    }));
500                }
501                let required = TAG_SIZE + 2 * LEN_SIZE;
502                if src.remaining() < required {
503                    src.reserve(required - src.remaining());
504                    return Ok(None);
505                }
506                src.advance(TAG_SIZE + LEN_SIZE);
507                let n = src.get_u64();
508                Ok(Some(if tag == TAKE {
509                    MapMessage::Take(n)
510                } else {
511                    MapMessage::Drop(n)
512                }))
513            }
514            _ => {
515                let result = inner.decode(src)?;
516                Ok(result.map(Into::into))
517            }
518        }
519    }
520}
521
522#[derive(Debug, Default, Clone, Copy)]
523pub struct RawMapMessageEncoder {
524    inner: MessageEncoder<RawMapOperationEncoder>,
525}
526
527impl<K: AsRef<[u8]>, V: AsRef<[u8]>> Encoder<MapMessage<K, V>> for RawMapMessageEncoder {
528    type Error = std::io::Error;
529
530    fn encode(&mut self, item: MapMessage<K, V>, dst: &mut BytesMut) -> Result<(), Self::Error> {
531        self.inner.encode(item, dst)
532    }
533}
534
535#[derive(Debug, Default, Clone, Copy)]
536pub struct RawMapMessageDecoder {
537    inner: MessageDecoder<RawMapOperationDecoder>,
538}
539
540impl Decoder for RawMapMessageDecoder {
541    type Item = MapMessage<BytesMut, BytesMut>;
542
543    type Error = FrameIoError;
544
545    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
546        self.inner.decode(src)
547    }
548}
549
550#[derive(Debug, Default, Clone, Copy)]
551pub struct MapMessageEncoder {
552    inner: MessageEncoder<MapOperationEncoder>,
553}
554
555impl<K: StructuralWritable, V: StructuralWritable> Encoder<MapMessage<K, V>> for MapMessageEncoder {
556    type Error = std::io::Error;
557
558    fn encode(&mut self, item: MapMessage<K, V>, dst: &mut BytesMut) -> Result<(), Self::Error> {
559        self.inner.encode(item, dst)
560    }
561}
562
563pub struct MapMessageDecoder<K: RecognizerReadable, V: RecognizerReadable> {
564    inner: MessageDecoder<MapOperationDecoder<K, V>>,
565}
566
567impl<K: RecognizerReadable, V: RecognizerReadable> Default for MapMessageDecoder<K, V> {
568    fn default() -> Self {
569        Self {
570            inner: Default::default(),
571        }
572    }
573}
574
575impl<K: RecognizerReadable, V: RecognizerReadable> Decoder for MapMessageDecoder<K, V> {
576    type Item = MapMessage<K, V>;
577
578    type Error = FrameIoError;
579
580    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
581        self.inner.decode(src)
582    }
583}