Skip to main content

s2_common/record/
command.rs

1use std::{fmt, str::Utf8Error};
2
3use bytes::{BufMut, Bytes};
4use compact_str::CompactString;
5use enum_ordinalize::Ordinalize;
6
7use super::{
8    Encodable, FencingTokenTooLongError, MeteredSize, RecordDecodeError, fencing::FencingToken,
9};
10use crate::{deep_size::DeepSize, record::SeqNum};
11
/// Wire identifier of the fence command.
pub const COMMAND_ID_FENCE: &'static [u8] = b"fence";
/// Wire identifier of the trim command.
pub const COMMAND_ID_TRIM: &'static [u8] = b"trim";
14
15#[derive(Debug, PartialEq, Eq, Clone, Copy, Ordinalize)]
16#[repr(u8)]
17pub enum CommandOp {
18    Fence,
19    Trim,
20}
21
22impl CommandOp {
23    pub fn to_id(self) -> &'static [u8] {
24        match self {
25            Self::Fence => COMMAND_ID_FENCE,
26            Self::Trim => COMMAND_ID_TRIM,
27        }
28    }
29
30    pub fn from_id(name: &[u8]) -> Option<Self> {
31        match name {
32            COMMAND_ID_FENCE => Some(Self::Fence),
33            COMMAND_ID_TRIM => Some(Self::Trim),
34            _ => None,
35        }
36    }
37}
38
39impl fmt::Display for CommandOp {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        let name = std::str::from_utf8(self.to_id()).map_err(|_| fmt::Error)?;
42        f.write_str(name)
43    }
44}
45
46#[derive(Debug, PartialEq, Eq, Clone)]
47pub enum CommandRecord {
48    Fence(FencingToken),
49    Trim(SeqNum),
50}
51
52impl DeepSize for CommandRecord {
53    fn deep_size(&self) -> usize {
54        match self {
55            Self::Fence(token) => token.deep_size(),
56            Self::Trim(seq_num) => seq_num.deep_size(),
57        }
58    }
59}
60
61impl MeteredSize for CommandRecord {
62    fn metered_size(&self) -> usize {
63        8 + 2
64            + self.op().to_id().len()
65            + match self {
66                Self::Fence(token) => token.len(),
67                Self::Trim(trim_point) => size_of_val(trim_point),
68            }
69    }
70}
71
72impl CommandRecord {
73    pub fn op(&self) -> CommandOp {
74        match self {
75            CommandRecord::Fence(_) => CommandOp::Fence,
76            CommandRecord::Trim(_) => CommandOp::Trim,
77        }
78    }
79
80    pub fn payload(&self) -> Bytes {
81        match self {
82            Self::Fence(token) => Bytes::copy_from_slice(token.as_bytes()),
83            Self::Trim(trim_point) => Bytes::copy_from_slice(&trim_point.to_be_bytes()),
84        }
85    }
86
87    pub fn try_from_parts(op: CommandOp, payload: &[u8]) -> Result<Self, CommandPayloadError> {
88        match op {
89            CommandOp::Fence => {
90                let token = CompactString::from_utf8(payload)
91                    .map_err(CommandPayloadError::InvalidUtf8)?
92                    .try_into()?;
93                Ok(Self::Fence(token))
94            }
95            CommandOp::Trim => {
96                let trim_point = SeqNum::from_be_bytes(
97                    payload
98                        .try_into()
99                        .map_err(|_| CommandPayloadError::TrimPointSize(payload.len()))?,
100                );
101                Ok(Self::Trim(trim_point))
102            }
103        }
104    }
105}
106
107impl TryFrom<&[u8]> for CommandRecord {
108    type Error = RecordDecodeError;
109
110    fn try_from(record: &[u8]) -> Result<Self, Self::Error> {
111        if record.is_empty() {
112            return Err(RecordDecodeError::Truncated("CommandOrdinal"));
113        }
114        let op = CommandOp::from_ordinal(record[0])
115            .ok_or(RecordDecodeError::InvalidValue("CommandOrdinal", "unknown"))?;
116        Self::try_from_parts(op, &record[1..]).map_err(Into::into)
117    }
118}
119
120impl Encodable for CommandRecord {
121    fn encoded_size(&self) -> usize {
122        1 + match self {
123            CommandRecord::Fence(token) => token.len(),
124            CommandRecord::Trim(trim_point) => size_of_val(trim_point),
125        }
126    }
127
128    fn encode_into(&self, buf: &mut impl BufMut) {
129        buf.put_u8(self.op().ordinal());
130        match self {
131            CommandRecord::Fence(token) => {
132                buf.put_slice(token.as_bytes());
133            }
134            CommandRecord::Trim(trim_point) => {
135                buf.put_u64(*trim_point);
136            }
137        }
138    }
139}
140
141#[derive(Debug, PartialEq, thiserror::Error)]
142pub enum CommandPayloadError {
143    #[error("invalid UTF-8")]
144    InvalidUtf8(Utf8Error),
145    #[error(transparent)]
146    FencingTokenTooLong(#[from] FencingTokenTooLongError),
147    #[error("earliest sequence number to trim to was {0} bytes, must be 8")]
148    TrimPointSize(usize),
149}
150
151impl From<CommandPayloadError> for RecordDecodeError {
152    fn from(e: CommandPayloadError) -> Self {
153        match e {
154            CommandPayloadError::InvalidUtf8(_) => {
155                RecordDecodeError::InvalidValue("CommandPayload", "fencing token not valid utf8")
156            }
157            CommandPayloadError::FencingTokenTooLong(_) => {
158                RecordDecodeError::InvalidValue("CommandPayload", "fencing token too long")
159            }
160            CommandPayloadError::TrimPointSize(_) => {
161                RecordDecodeError::InvalidValue("CommandPayload", "trim point size")
162            }
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use compact_str::ToCompactString;
    use enum_ordinalize::Ordinalize;
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// 40 ASCII bytes: long enough to be rejected as a fencing token.
    const OVERLONG_TOKEN: &[u8] = b"0123456789012345678901234567890123456789";

    /// Checks that `cmd` reports `expected_len` as its encoded size, actually encodes to that
    /// many bytes, and decodes back to an equal record.
    fn roundtrip(cmd: CommandRecord, expected_len: usize) {
        assert_eq!(cmd.encoded_size(), expected_len);
        let encoded = cmd.to_bytes();
        assert_eq!(encoded.len(), expected_len);
        assert_eq!(CommandRecord::try_from(encoded.as_ref()), Ok(cmd));
    }

    #[test]
    fn command_op_names() {
        for cmd in CommandOp::VARIANTS {
            let name = cmd.to_id();
            assert_eq!(CommandOp::from_id(name), Some(*cmd));
        }
        assert_eq!(CommandOp::from_id(b""), None);
        assert_eq!(CommandOp::from_id(b"invalid"), None);
    }

    #[test]
    fn command_op_display_matches_id() {
        assert_eq!(CommandOp::Fence.to_string(), "fence");
        assert_eq!(CommandOp::Trim.to_string(), "trim");
    }

    #[test]
    fn fencing_token_invalid_utf8() {
        assert!(matches!(
            CommandRecord::try_from_parts(CommandOp::Fence, &[0xff]),
            Err(CommandPayloadError::InvalidUtf8(_))
        ));
    }

    #[test]
    fn fencing_token_too_long() {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, OVERLONG_TOKEN),
            Err(CommandPayloadError::FencingTokenTooLong(
                FencingTokenTooLongError(40)
            ))
        );
    }

    #[rstest]
    #[case::empty("")]
    #[case::arbit("arbitrary")]
    #[case::full("0123456789012345")]
    fn fence_roundtrip(#[case] token: &str) {
        let cmd = CommandRecord::Fence(FencingToken::try_from(token.to_compact_string()).unwrap());
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, token.as_bytes()),
            Ok(cmd.clone())
        );
        roundtrip(cmd, 1 + token.len());
    }

    #[rstest]
    #[case::empty(b"")]
    #[case::too_small(b"0123")]
    #[case::too_big(b"0123456789")]
    fn trim_point_size(#[case] payload: &[u8]) {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Trim, payload),
            Err(CommandPayloadError::TrimPointSize(payload.len()))
        );
    }

    #[test]
    fn metered_size_is_computed_without_materializing_payload() {
        let fence =
            CommandRecord::Fence(FencingToken::try_from("fence-me".to_compact_string()).unwrap());
        assert_eq!(
            fence.metered_size(),
            8 + 2 + CommandOp::Fence.to_id().len() + "fence-me".len()
        );

        let trim = CommandRecord::Trim(42);
        assert_eq!(
            trim.metered_size(),
            8 + 2 + CommandOp::Trim.to_id().len() + size_of_val(&42u64)
        );
    }

    proptest! {
        #[test]
        fn trim_roundtrip(trim_point in any::<SeqNum>()) {
            let cmd = CommandRecord::Trim(trim_point);
            assert_eq!(CommandRecord::try_from_parts(CommandOp::Trim, trim_point.to_be_bytes().as_slice()), Ok(cmd.clone()));
            roundtrip(cmd, 9);
        }
    }

    #[test]
    fn decode_invalid_command() {
        let try_convert = |raw: &[u8]| CommandRecord::try_from(raw);
        assert_eq!(
            try_convert(&[]),
            Err(RecordDecodeError::Truncated("CommandOrdinal"))
        );
        assert_eq!(
            try_convert(&[0xff]),
            Err(RecordDecodeError::InvalidValue("CommandOrdinal", "unknown"))
        );
        assert_eq!(
            try_convert(&[CommandOp::Fence.ordinal(), 0xff, 0xff]),
            Err(RecordDecodeError::InvalidValue(
                "CommandPayload",
                "fencing token not valid utf8"
            ))
        );

        let mut overlong = vec![CommandOp::Fence.ordinal()];
        overlong.extend_from_slice(OVERLONG_TOKEN);
        assert_eq!(
            try_convert(overlong.as_slice()),
            Err(CommandPayloadError::FencingTokenTooLong(FencingTokenTooLongError(40)).into())
        );

        assert_eq!(
            try_convert(&[CommandOp::Trim.ordinal(), 0xff]),
            Err(CommandPayloadError::TrimPointSize(1).into())
        );
    }
}