1use std::{fmt, str::Utf8Error};
2
3use bytes::{BufMut, Bytes};
4use compact_str::CompactString;
5use strum::FromRepr;
6
7use super::{
8 Encodable, FencingTokenTooLongError, MeteredSize, RecordDecodeError, fencing::FencingToken,
9};
10use crate::{deep_size::DeepSize, record::SeqNum};
11
/// Byte-string id naming the fence command (returned by `CommandOp::to_id`).
pub const COMMAND_ID_FENCE: &[u8] = "fence".as_bytes();
/// Byte-string id naming the trim command (returned by `CommandOp::to_id`).
pub const COMMAND_ID_TRIM: &[u8] = "trim".as_bytes();
14
15#[derive(Debug, PartialEq, Eq, Clone, Copy, FromRepr)]
16#[repr(u8)]
17pub enum CommandOp {
18 Fence,
19 Trim,
20}
21
22impl CommandOp {
23 pub fn to_id(self) -> &'static [u8] {
24 match self {
25 Self::Fence => COMMAND_ID_FENCE,
26 Self::Trim => COMMAND_ID_TRIM,
27 }
28 }
29
30 pub fn from_id(name: &[u8]) -> Option<Self> {
31 match name {
32 COMMAND_ID_FENCE => Some(Self::Fence),
33 COMMAND_ID_TRIM => Some(Self::Trim),
34 _ => None,
35 }
36 }
37}
38
39impl fmt::Display for CommandOp {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 let name = std::str::from_utf8(self.to_id()).map_err(|_| fmt::Error)?;
42 f.write_str(name)
43 }
44}
45
46#[derive(Debug, PartialEq, Eq, Clone)]
47pub enum CommandRecord {
48 Fence(FencingToken),
49 Trim(SeqNum),
50}
51
52impl DeepSize for CommandRecord {
53 fn deep_size(&self) -> usize {
54 match self {
55 Self::Fence(token) => token.deep_size(),
56 Self::Trim(seq_num) => seq_num.deep_size(),
57 }
58 }
59}
60
61impl MeteredSize for CommandRecord {
62 fn metered_size(&self) -> usize {
63 8 + 2
64 + self.op().to_id().len()
65 + match self {
66 Self::Fence(token) => token.len(),
67 Self::Trim(trim_point) => size_of_val(trim_point),
68 }
69 }
70}
71
72impl CommandRecord {
73 pub fn op(&self) -> CommandOp {
74 match self {
75 CommandRecord::Fence(_) => CommandOp::Fence,
76 CommandRecord::Trim(_) => CommandOp::Trim,
77 }
78 }
79
80 pub fn payload(&self) -> Bytes {
81 match self {
82 Self::Fence(token) => Bytes::copy_from_slice(token.as_bytes()),
83 Self::Trim(trim_point) => Bytes::copy_from_slice(&trim_point.to_be_bytes()),
84 }
85 }
86
87 pub fn try_from_parts(op: CommandOp, payload: &[u8]) -> Result<Self, CommandPayloadError> {
88 match op {
89 CommandOp::Fence => {
90 let token = CompactString::from_utf8(payload)
91 .map_err(CommandPayloadError::InvalidUtf8)?
92 .try_into()?;
93 Ok(Self::Fence(token))
94 }
95 CommandOp::Trim => {
96 let trim_point = SeqNum::from_be_bytes(
97 payload
98 .try_into()
99 .map_err(|_| CommandPayloadError::TrimPointSize(payload.len()))?,
100 );
101 Ok(Self::Trim(trim_point))
102 }
103 }
104 }
105}
106
107impl TryFrom<&[u8]> for CommandRecord {
108 type Error = RecordDecodeError;
109
110 fn try_from(record: &[u8]) -> Result<Self, Self::Error> {
111 if record.is_empty() {
112 return Err(RecordDecodeError::Truncated("CommandOrdinal"));
113 }
114 let op = CommandOp::from_repr(record[0])
115 .ok_or(RecordDecodeError::InvalidValue("CommandOrdinal", "unknown"))?;
116 Self::try_from_parts(op, &record[1..]).map_err(Into::into)
117 }
118}
119
120impl Encodable for CommandRecord {
121 fn encoded_size(&self) -> usize {
122 1 + match self {
123 CommandRecord::Fence(token) => token.len(),
124 CommandRecord::Trim(trim_point) => size_of_val(trim_point),
125 }
126 }
127
128 fn encode_into(&self, buf: &mut impl BufMut) {
129 buf.put_u8(self.op() as u8);
130 match self {
131 CommandRecord::Fence(token) => {
132 buf.put_slice(token.as_bytes());
133 }
134 CommandRecord::Trim(trim_point) => {
135 buf.put_u64(*trim_point);
136 }
137 }
138 }
139}
140
141#[derive(Debug, PartialEq, thiserror::Error)]
142pub enum CommandPayloadError {
143 #[error("invalid UTF-8")]
144 InvalidUtf8(Utf8Error),
145 #[error(transparent)]
146 FencingTokenTooLong(#[from] FencingTokenTooLongError),
147 #[error("earliest sequence number to trim to was {0} bytes, must be 8")]
148 TrimPointSize(usize),
149}
150
151impl From<CommandPayloadError> for RecordDecodeError {
152 fn from(e: CommandPayloadError) -> Self {
153 match e {
154 CommandPayloadError::InvalidUtf8(_) => {
155 RecordDecodeError::InvalidValue("CommandPayload", "fencing token not valid utf8")
156 }
157 CommandPayloadError::FencingTokenTooLong(_) => {
158 RecordDecodeError::InvalidValue("CommandPayload", "fencing token too long")
159 }
160 CommandPayloadError::TrimPointSize(_) => {
161 RecordDecodeError::InvalidValue("CommandPayload", "trim point size")
162 }
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use compact_str::ToCompactString;
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// A 40-byte ASCII payload that `FencingToken` rejects as too long; the expected
    /// `FencingTokenTooLongError(40)` values below carry this length.
    const TOO_LONG_TOKEN: &[u8] = b"0123456789012345678901234567890123456789";

    /// Checks `encoded_size`, the actual encoded length, and that decoding the encoded
    /// bytes reproduces `cmd`.
    fn roundtrip(cmd: CommandRecord, expected_len: usize) {
        assert_eq!(cmd.encoded_size(), expected_len);
        let encoded = cmd.to_bytes();
        assert_eq!(encoded.len(), expected_len);
        assert_eq!(CommandRecord::try_from(encoded.as_ref()), Ok(cmd));
    }

    #[test]
    fn command_op_names() {
        for cmd in [CommandOp::Fence, CommandOp::Trim] {
            let name = cmd.to_id();
            assert_eq!(CommandOp::from_id(name), Some(cmd));
        }
        assert_eq!(CommandOp::from_id(b""), None);
        assert_eq!(CommandOp::from_id(b"invalid"), None);
    }

    #[test]
    fn command_op_display_matches_id() {
        for cmd in [CommandOp::Fence, CommandOp::Trim] {
            assert_eq!(cmd.to_string().as_bytes(), cmd.to_id());
        }
    }

    #[test]
    fn command_op_ordinals_are_stable() {
        // The ordinal is the first encoded byte, so these values are part of the wire format.
        assert_eq!(CommandOp::Fence as u8, 0);
        assert_eq!(CommandOp::Trim as u8, 1);
        assert_eq!(CommandOp::from_repr(0), Some(CommandOp::Fence));
        assert_eq!(CommandOp::from_repr(1), Some(CommandOp::Trim));
    }

    #[test]
    fn fencing_token_invalid_utf8() {
        assert!(matches!(
            CommandRecord::try_from_parts(CommandOp::Fence, &[0xff]),
            Err(CommandPayloadError::InvalidUtf8(_))
        ));
    }

    #[test]
    fn fencing_token_too_long() {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, TOO_LONG_TOKEN),
            Err(CommandPayloadError::FencingTokenTooLong(
                FencingTokenTooLongError(40)
            ))
        );
    }

    #[rstest]
    #[case::empty("")]
    #[case::arbit("arbitrary")]
    #[case::full("0123456789012345")]
    fn fence_roundtrip(#[case] token: &str) {
        let cmd = CommandRecord::Fence(FencingToken::try_from(token.to_compact_string()).unwrap());
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, token.as_bytes()),
            Ok(cmd.clone())
        );
        roundtrip(cmd, 1 + token.len());
    }

    #[rstest]
    #[case::empty(b"")]
    #[case::too_small(b"0123")]
    #[case::too_big(b"0123456789")]
    fn trim_point_size(#[case] payload: &[u8]) {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Trim, payload),
            Err(CommandPayloadError::TrimPointSize(payload.len()))
        );
    }

    #[test]
    fn metered_size_is_computed_without_materializing_payload() {
        let fence =
            CommandRecord::Fence(FencingToken::try_from("fence-me".to_compact_string()).unwrap());
        assert_eq!(
            fence.metered_size(),
            8 + 2 + CommandOp::Fence.to_id().len() + "fence-me".len()
        );

        let trim = CommandRecord::Trim(42);
        assert_eq!(
            trim.metered_size(),
            8 + 2 + CommandOp::Trim.to_id().len() + size_of_val(&42u64)
        );
    }

    proptest! {
        #[test]
        fn trim_roundtrip(trim_point in any::<SeqNum>()) {
            let cmd = CommandRecord::Trim(trim_point);
            assert_eq!(CommandRecord::try_from_parts(CommandOp::Trim, trim_point.to_be_bytes().as_slice()), Ok(cmd.clone()));
            roundtrip(cmd, 9);
        }
    }

    #[test]
    fn decode_invalid_command() {
        let try_convert = |raw: &[u8]| CommandRecord::try_from(raw);
        assert_eq!(
            try_convert(&[]),
            Err(RecordDecodeError::Truncated("CommandOrdinal"))
        );
        assert_eq!(
            try_convert(&[0xff]),
            Err(RecordDecodeError::InvalidValue("CommandOrdinal", "unknown"))
        );
        assert_eq!(
            try_convert(&[CommandOp::Fence as u8, 0xff, 0xff]),
            Err(RecordDecodeError::InvalidValue(
                "CommandPayload",
                "fencing token not valid utf8"
            ))
        );

        // Ordinal byte followed by an over-long token payload.
        let mut too_long = vec![CommandOp::Fence as u8];
        too_long.extend_from_slice(TOO_LONG_TOKEN);
        assert_eq!(
            try_convert(too_long.as_slice()),
            Err(CommandPayloadError::FencingTokenTooLong(FencingTokenTooLongError(40)).into())
        );

        assert_eq!(
            try_convert(&[CommandOp::Trim as u8, 0xff]),
            Err(CommandPayloadError::TrimPointSize(1).into())
        );
    }
}