Skip to main content

s2_common/record/
command.rs

1use std::{fmt, str::Utf8Error};
2
3use bytes::{BufMut, Bytes};
4use compact_str::CompactString;
5use strum::FromRepr;
6
7use super::{
8    Encodable, FencingTokenTooLongError, MeteredSize, RecordDecodeError, fencing::FencingToken,
9};
10use crate::{deep_size::DeepSize, record::SeqNum};
11
/// Wire identifier of the fence command.
pub const COMMAND_ID_FENCE: &[u8] = "fence".as_bytes();
/// Wire identifier of the trim command.
pub const COMMAND_ID_TRIM: &[u8] = "trim".as_bytes();
14
15#[derive(Debug, PartialEq, Eq, Clone, Copy, FromRepr)]
16#[repr(u8)]
17pub enum CommandOp {
18    Fence,
19    Trim,
20}
21
22impl CommandOp {
23    pub fn to_id(self) -> &'static [u8] {
24        match self {
25            Self::Fence => COMMAND_ID_FENCE,
26            Self::Trim => COMMAND_ID_TRIM,
27        }
28    }
29
30    pub fn from_id(name: &[u8]) -> Option<Self> {
31        match name {
32            COMMAND_ID_FENCE => Some(Self::Fence),
33            COMMAND_ID_TRIM => Some(Self::Trim),
34            _ => None,
35        }
36    }
37}
38
39impl fmt::Display for CommandOp {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        let name = std::str::from_utf8(self.to_id()).map_err(|_| fmt::Error)?;
42        f.write_str(name)
43    }
44}
45
46#[derive(Debug, PartialEq, Eq, Clone)]
47pub enum CommandRecord {
48    Fence(FencingToken),
49    Trim(SeqNum),
50}
51
52impl DeepSize for CommandRecord {
53    fn deep_size(&self) -> usize {
54        match self {
55            Self::Fence(token) => token.deep_size(),
56            Self::Trim(seq_num) => seq_num.deep_size(),
57        }
58    }
59}
60
61impl MeteredSize for CommandRecord {
62    fn metered_size(&self) -> usize {
63        8 + 2
64            + self.op().to_id().len()
65            + match self {
66                Self::Fence(token) => token.len(),
67                Self::Trim(trim_point) => size_of_val(trim_point),
68            }
69    }
70}
71
72impl CommandRecord {
73    pub fn op(&self) -> CommandOp {
74        match self {
75            CommandRecord::Fence(_) => CommandOp::Fence,
76            CommandRecord::Trim(_) => CommandOp::Trim,
77        }
78    }
79
80    pub fn payload(&self) -> Bytes {
81        match self {
82            Self::Fence(token) => Bytes::copy_from_slice(token.as_bytes()),
83            Self::Trim(trim_point) => Bytes::copy_from_slice(&trim_point.to_be_bytes()),
84        }
85    }
86
87    pub fn try_from_parts(op: CommandOp, payload: &[u8]) -> Result<Self, CommandPayloadError> {
88        match op {
89            CommandOp::Fence => {
90                let token = CompactString::from_utf8(payload)
91                    .map_err(CommandPayloadError::InvalidUtf8)?
92                    .try_into()?;
93                Ok(Self::Fence(token))
94            }
95            CommandOp::Trim => {
96                let trim_point = SeqNum::from_be_bytes(
97                    payload
98                        .try_into()
99                        .map_err(|_| CommandPayloadError::TrimPointSize(payload.len()))?,
100                );
101                Ok(Self::Trim(trim_point))
102            }
103        }
104    }
105}
106
107impl TryFrom<&[u8]> for CommandRecord {
108    type Error = RecordDecodeError;
109
110    fn try_from(record: &[u8]) -> Result<Self, Self::Error> {
111        if record.is_empty() {
112            return Err(RecordDecodeError::Truncated("CommandOrdinal"));
113        }
114        let op = CommandOp::from_repr(record[0])
115            .ok_or(RecordDecodeError::InvalidValue("CommandOrdinal", "unknown"))?;
116        Self::try_from_parts(op, &record[1..]).map_err(Into::into)
117    }
118}
119
120impl Encodable for CommandRecord {
121    fn encoded_size(&self) -> usize {
122        1 + match self {
123            CommandRecord::Fence(token) => token.len(),
124            CommandRecord::Trim(trim_point) => size_of_val(trim_point),
125        }
126    }
127
128    fn encode_into(&self, buf: &mut impl BufMut) {
129        buf.put_u8(self.op() as u8);
130        match self {
131            CommandRecord::Fence(token) => {
132                buf.put_slice(token.as_bytes());
133            }
134            CommandRecord::Trim(trim_point) => {
135                buf.put_u64(*trim_point);
136            }
137        }
138    }
139}
140
141#[derive(Debug, PartialEq, thiserror::Error)]
142pub enum CommandPayloadError {
143    #[error("invalid UTF-8")]
144    InvalidUtf8(Utf8Error),
145    #[error(transparent)]
146    FencingTokenTooLong(#[from] FencingTokenTooLongError),
147    #[error("earliest sequence number to trim to was {0} bytes, must be 8")]
148    TrimPointSize(usize),
149}
150
151impl From<CommandPayloadError> for RecordDecodeError {
152    fn from(e: CommandPayloadError) -> Self {
153        match e {
154            CommandPayloadError::InvalidUtf8(_) => {
155                RecordDecodeError::InvalidValue("CommandPayload", "fencing token not valid utf8")
156            }
157            CommandPayloadError::FencingTokenTooLong(_) => {
158                RecordDecodeError::InvalidValue("CommandPayload", "fencing token too long")
159            }
160            CommandPayloadError::TrimPointSize(_) => {
161                RecordDecodeError::InvalidValue("CommandPayload", "trim point size")
162            }
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use compact_str::ToCompactString;
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// A 40-byte fencing token payload, long enough to be rejected as too long.
    const TOO_LONG_TOKEN: &[u8] = b"0123456789012345678901234567890123456789";

    /// Asserts that `cmd` encodes to exactly `expected_len` bytes and decodes back to itself.
    fn roundtrip(cmd: CommandRecord, expected_len: usize) {
        assert_eq!(cmd.encoded_size(), expected_len);
        let encoded = cmd.to_bytes();
        assert_eq!(encoded.len(), expected_len);
        assert_eq!(CommandRecord::try_from(encoded.as_ref()), Ok(cmd));
    }

    #[test]
    fn command_op_names() {
        for cmd in [CommandOp::Fence, CommandOp::Trim] {
            let name = cmd.to_id();
            assert_eq!(CommandOp::from_id(name), Some(cmd));
        }
        assert_eq!(CommandOp::from_id(b""), None);
        assert_eq!(CommandOp::from_id(b"invalid"), None);
    }

    #[test]
    fn command_op_display_matches_id() {
        for cmd in [CommandOp::Fence, CommandOp::Trim] {
            assert_eq!(cmd.to_string().as_bytes(), cmd.to_id());
        }
        assert_eq!(CommandOp::Fence.to_string(), "fence");
        assert_eq!(CommandOp::Trim.to_string(), "trim");
    }

    #[test]
    fn fencing_token_invalid_utf8() {
        assert!(matches!(
            CommandRecord::try_from_parts(CommandOp::Fence, &[0xff]),
            Err(CommandPayloadError::InvalidUtf8(_))
        ));
    }

    #[test]
    fn fencing_token_too_long() {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, TOO_LONG_TOKEN),
            Err(CommandPayloadError::FencingTokenTooLong(
                FencingTokenTooLongError(40)
            ))
        );
    }

    #[rstest]
    #[case::empty("")]
    #[case::arbit("arbitrary")]
    #[case::full("0123456789012345")]
    fn fence_roundtrip(#[case] token: &str) {
        let cmd = CommandRecord::Fence(FencingToken::try_from(token.to_compact_string()).unwrap());
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, token.as_bytes()),
            Ok(cmd.clone())
        );
        roundtrip(cmd, 1 + token.len());
    }

    #[rstest]
    #[case::empty(b"")]
    #[case::too_small(b"0123")]
    #[case::too_big(b"0123456789")]
    fn trim_point_size(#[case] payload: &[u8]) {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Trim, payload),
            Err(CommandPayloadError::TrimPointSize(payload.len()))
        );
    }

    #[test]
    fn metered_size_is_computed_without_materializing_payload() {
        let fence =
            CommandRecord::Fence(FencingToken::try_from("fence-me".to_compact_string()).unwrap());
        assert_eq!(
            fence.metered_size(),
            8 + 2 + CommandOp::Fence.to_id().len() + "fence-me".len()
        );

        let trim = CommandRecord::Trim(42);
        assert_eq!(
            trim.metered_size(),
            8 + 2 + CommandOp::Trim.to_id().len() + size_of_val(&42u64)
        );
    }

    proptest! {
        #[test]
        fn trim_roundtrip(trim_point in any::<SeqNum>()) {
            let cmd = CommandRecord::Trim(trim_point);
            assert_eq!(
                CommandRecord::try_from_parts(CommandOp::Trim, trim_point.to_be_bytes().as_slice()),
                Ok(cmd.clone())
            );
            roundtrip(cmd, 9);
        }
    }

    #[test]
    fn decode_invalid_command() {
        let try_convert = |raw: &[u8]| CommandRecord::try_from(raw);
        assert_eq!(
            try_convert(&[]),
            Err(RecordDecodeError::Truncated("CommandOrdinal"))
        );
        assert_eq!(
            try_convert(&[0xff]),
            Err(RecordDecodeError::InvalidValue("CommandOrdinal", "unknown"))
        );
        assert_eq!(
            try_convert(&[CommandOp::Fence as u8, 0xff, 0xff]),
            Err(RecordDecodeError::InvalidValue(
                "CommandPayload",
                "fencing token not valid utf8"
            ))
        );

        // Fence ordinal followed by a token that is too long.
        let mut too_long_fence = vec![CommandOp::Fence as u8];
        too_long_fence.extend_from_slice(TOO_LONG_TOKEN);
        assert_eq!(
            try_convert(too_long_fence.as_slice()),
            Err(CommandPayloadError::FencingTokenTooLong(FencingTokenTooLongError(40)).into())
        );

        assert_eq!(
            try_convert(&[CommandOp::Trim as u8, 0xff]),
            Err(CommandPayloadError::TrimPointSize(1).into())
        );
    }
}