s2_common/record/
command.rs

1use std::{fmt, str::Utf8Error};
2
3use bytes::{BufMut, Bytes};
4use compact_str::CompactString;
5use enum_ordinalize::Ordinalize;
6
7use super::{Encodable, FencingTokenTooLongError, InternalRecordError, fencing::FencingToken};
8use crate::{deep_size::DeepSize, record::SeqNum};
9
/// Byte identifier of the fence command; `CommandOp::to_id` / `CommandOp::from_id` map it
/// to and from `CommandOp::Fence`.
pub const COMMAND_ID_FENCE: &[u8] = b"fence";
/// Byte identifier of the trim command; `CommandOp::to_id` / `CommandOp::from_id` map it
/// to and from `CommandOp::Trim`.
pub const COMMAND_ID_TRIM: &[u8] = b"trim";
12
/// The kind of operation a [`CommandRecord`] carries.
///
/// The ordinal obtained through `Ordinalize` is what `CommandRecord` writes as the leading
/// byte of its encoded form and reads back when decoding.
// NOTE(review): ordinals presumably follow declaration order, so reordering or inserting
// variants would change the encoding — confirm against `enum_ordinalize` before editing.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Ordinalize)]
#[repr(u8)]
pub enum CommandOp {
    /// Command whose payload is a fencing token.
    Fence,
    /// Command whose payload is a sequence number to trim to.
    Trim,
}
19
20impl CommandOp {
21    pub fn to_id(self) -> &'static [u8] {
22        match self {
23            Self::Fence => COMMAND_ID_FENCE,
24            Self::Trim => COMMAND_ID_TRIM,
25        }
26    }
27
28    pub fn from_id(name: &[u8]) -> Option<Self> {
29        match name {
30            COMMAND_ID_FENCE => Some(Self::Fence),
31            COMMAND_ID_TRIM => Some(Self::Trim),
32            _ => None,
33        }
34    }
35}
36
37impl fmt::Display for CommandOp {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        let name = std::str::from_utf8(self.to_id()).map_err(|_| fmt::Error)?;
40        f.write_str(name)
41    }
42}
43
/// A command together with its typed payload.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum CommandRecord {
    /// Fence command carrying a fencing token.
    Fence(FencingToken),
    /// Trim command carrying the sequence number to trim to.
    Trim(SeqNum),
}
49
50impl DeepSize for CommandRecord {
51    fn deep_size(&self) -> usize {
52        match self {
53            Self::Fence(token) => token.deep_size(),
54            Self::Trim(seq_num) => seq_num.deep_size(),
55        }
56    }
57}
58
59impl CommandRecord {
60    pub fn op(&self) -> CommandOp {
61        match self {
62            CommandRecord::Fence(_) => CommandOp::Fence,
63            CommandRecord::Trim(_) => CommandOp::Trim,
64        }
65    }
66
67    pub fn payload(&self) -> Bytes {
68        match self {
69            Self::Fence(token) => Bytes::copy_from_slice(token.as_bytes()),
70            Self::Trim(trim_point) => Bytes::copy_from_slice(&trim_point.to_be_bytes()),
71        }
72    }
73
74    pub fn try_from_parts(op: CommandOp, payload: &[u8]) -> Result<Self, CommandPayloadError> {
75        match op {
76            CommandOp::Fence => {
77                let token = CompactString::from_utf8(payload)
78                    .map_err(CommandPayloadError::InvalidUtf8)?
79                    .try_into()?;
80                Ok(Self::Fence(token))
81            }
82            CommandOp::Trim => {
83                let trim_point = SeqNum::from_be_bytes(
84                    payload
85                        .try_into()
86                        .map_err(|_| CommandPayloadError::TrimPointSize(payload.len()))?,
87                );
88                Ok(Self::Trim(trim_point))
89            }
90        }
91    }
92}
93
94impl TryFrom<&[u8]> for CommandRecord {
95    type Error = InternalRecordError;
96
97    fn try_from(record: &[u8]) -> Result<Self, Self::Error> {
98        if record.is_empty() {
99            return Err(InternalRecordError::Truncated("CommandOrdinal"));
100        }
101        let op = CommandOp::from_ordinal(record[0]).ok_or(InternalRecordError::InvalidValue(
102            "CommandOrdinal",
103            "unknown",
104        ))?;
105        Self::try_from_parts(op, &record[1..]).map_err(Into::into)
106    }
107}
108
109impl Encodable for CommandRecord {
110    fn encoded_size(&self) -> usize {
111        1 + match self {
112            CommandRecord::Fence(token) => token.len(),
113            CommandRecord::Trim(trim_point) => size_of_val(trim_point),
114        }
115    }
116
117    fn encode_into(&self, buf: &mut impl BufMut) {
118        buf.put_u8(self.op().ordinal());
119        match self {
120            CommandRecord::Fence(token) => {
121                buf.put_slice(token.as_bytes());
122            }
123            CommandRecord::Trim(trim_point) => {
124                buf.put_u64(*trim_point);
125            }
126        }
127    }
128}
129
130#[derive(Debug, PartialEq, thiserror::Error)]
131pub enum CommandPayloadError {
132    #[error("invalid UTF-8")]
133    InvalidUtf8(Utf8Error),
134    #[error(transparent)]
135    FencingTokenTooLong(#[from] FencingTokenTooLongError),
136    #[error("earliest sequence number to trim to was {0} bytes, must be 8")]
137    TrimPointSize(usize),
138}
139
140impl From<CommandPayloadError> for InternalRecordError {
141    fn from(e: CommandPayloadError) -> Self {
142        match e {
143            CommandPayloadError::InvalidUtf8(_) => {
144                InternalRecordError::InvalidValue("CommandPayload", "fencing token not valid utf8")
145            }
146            CommandPayloadError::FencingTokenTooLong(_) => {
147                InternalRecordError::InvalidValue("CommandPayload", "fencing token too long")
148            }
149            CommandPayloadError::TrimPointSize(_) => {
150                InternalRecordError::InvalidValue("CommandPayload", "trim point size")
151            }
152        }
153    }
154}
155
#[cfg(test)]
mod tests {
    use compact_str::ToCompactString;
    use enum_ordinalize::Ordinalize;
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// A 40-byte ASCII payload that the tests below expect to be rejected as too long.
    const OVERSIZED_TOKEN: &[u8] = b"0123456789012345678901234567890123456789";

    /// Checks the reported and actual encoded lengths, then that decoding returns `cmd`.
    fn roundtrip(cmd: CommandRecord, expected_len: usize) {
        assert_eq!(cmd.encoded_size(), expected_len);
        let bytes = cmd.to_bytes();
        assert_eq!(bytes.len(), expected_len);
        let decoded = CommandRecord::try_from(bytes.as_ref());
        assert_eq!(decoded, Ok(cmd));
    }

    #[test]
    fn command_op_names() {
        // Every variant must survive the id round trip.
        for &op in CommandOp::VARIANTS {
            assert_eq!(CommandOp::from_id(op.to_id()), Some(op));
        }
        for unknown in [b"".as_slice(), b"invalid".as_slice()] {
            assert_eq!(CommandOp::from_id(unknown), None);
        }
    }

    #[test]
    fn fencing_token_invalid_utf8() {
        let parsed = CommandRecord::try_from_parts(CommandOp::Fence, &[0xff]);
        assert!(matches!(parsed, Err(CommandPayloadError::InvalidUtf8(_))));
    }

    #[test]
    fn fencing_token_too_long() {
        let expected = CommandPayloadError::FencingTokenTooLong(FencingTokenTooLongError(40));
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, OVERSIZED_TOKEN),
            Err(expected)
        );
    }

    #[rstest]
    #[case::empty("")]
    #[case::arbit("arbitrary")]
    #[case::full("0123456789012345")]
    fn fence_roundtrip(#[case] token: &str) {
        let fencing_token = FencingToken::try_from(token.to_compact_string()).unwrap();
        let cmd = CommandRecord::Fence(fencing_token);
        let parsed = CommandRecord::try_from_parts(CommandOp::Fence, token.as_bytes());
        assert_eq!(parsed, Ok(cmd.clone()));
        roundtrip(cmd, token.len() + 1);
    }

    #[rstest]
    #[case::empty(b"")]
    #[case::too_small(b"0123")]
    #[case::too_big(b"0123456789")]
    fn trim_point_size(#[case] payload: &[u8]) {
        let parsed = CommandRecord::try_from_parts(CommandOp::Trim, payload);
        assert_eq!(
            parsed,
            Err(CommandPayloadError::TrimPointSize(payload.len()))
        );
    }

    proptest! {
        #[test]
        fn trim_roundtrip(trim_point in any::<SeqNum>()) {
            let cmd = CommandRecord::Trim(trim_point);
            let be_bytes = trim_point.to_be_bytes();
            let parsed = CommandRecord::try_from_parts(CommandOp::Trim, &be_bytes);
            assert_eq!(parsed, Ok(cmd.clone()));
            // One ordinal byte plus the eight payload bytes.
            roundtrip(cmd, 9);
        }
    }

    #[test]
    fn decode_invalid_command() {
        fn decode(raw: &[u8]) -> Result<CommandRecord, InternalRecordError> {
            CommandRecord::try_from(raw)
        }

        // No ordinal byte at all.
        assert_eq!(
            decode(&[]),
            Err(InternalRecordError::Truncated("CommandOrdinal"))
        );

        // Ordinal that matches no operation.
        assert_eq!(
            decode(&[0xff]),
            Err(InternalRecordError::InvalidValue(
                "CommandOrdinal",
                "unknown"
            ))
        );

        // Fence payload that is not UTF-8.
        assert_eq!(
            decode(&[CommandOp::Fence.ordinal(), 0xff, 0xff]),
            Err(InternalRecordError::InvalidValue(
                "CommandPayload",
                "fencing token not valid utf8"
            ))
        );

        // Fence payload that is too long.
        let mut oversized = vec![CommandOp::Fence.ordinal()];
        oversized.extend_from_slice(OVERSIZED_TOKEN);
        assert_eq!(
            decode(&oversized),
            Err(CommandPayloadError::FencingTokenTooLong(FencingTokenTooLongError(40)).into())
        );

        // Trim payload of the wrong size.
        assert_eq!(
            decode(&[CommandOp::Trim.ordinal(), 0xff]),
            Err(CommandPayloadError::TrimPointSize(1).into())
        );
    }
}
308}