1use std::{fmt, str::Utf8Error};
2
3use bytes::{BufMut, Bytes};
4use compact_str::CompactString;
5use enum_ordinalize::Ordinalize;
6
7use super::{
8 Encodable, FencingTokenTooLongError, MeteredSize, RecordDecodeError, fencing::FencingToken,
9};
10use crate::{deep_size::DeepSize, record::SeqNum};
11
/// Textual identifier of the fence command, as returned by [`CommandOp::to_id`].
pub const COMMAND_ID_FENCE: &[u8] = b"fence";
/// Textual identifier of the trim command, as returned by [`CommandOp::to_id`].
pub const COMMAND_ID_TRIM: &[u8] = b"trim";
14
15#[derive(Debug, PartialEq, Eq, Clone, Copy, Ordinalize)]
16#[repr(u8)]
17pub enum CommandOp {
18 Fence,
19 Trim,
20}
21
22impl CommandOp {
23 pub fn to_id(self) -> &'static [u8] {
24 match self {
25 Self::Fence => COMMAND_ID_FENCE,
26 Self::Trim => COMMAND_ID_TRIM,
27 }
28 }
29
30 pub fn from_id(name: &[u8]) -> Option<Self> {
31 match name {
32 COMMAND_ID_FENCE => Some(Self::Fence),
33 COMMAND_ID_TRIM => Some(Self::Trim),
34 _ => None,
35 }
36 }
37}
38
39impl fmt::Display for CommandOp {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 let name = std::str::from_utf8(self.to_id()).map_err(|_| fmt::Error)?;
42 f.write_str(name)
43 }
44}
45
46#[derive(Debug, PartialEq, Eq, Clone)]
47pub enum CommandRecord {
48 Fence(FencingToken),
49 Trim(SeqNum),
50}
51
52impl DeepSize for CommandRecord {
53 fn deep_size(&self) -> usize {
54 match self {
55 Self::Fence(token) => token.deep_size(),
56 Self::Trim(seq_num) => seq_num.deep_size(),
57 }
58 }
59}
60
61impl MeteredSize for CommandRecord {
62 fn metered_size(&self) -> usize {
63 8 + 2
64 + self.op().to_id().len()
65 + match self {
66 Self::Fence(token) => token.len(),
67 Self::Trim(trim_point) => size_of_val(trim_point),
68 }
69 }
70}
71
72impl CommandRecord {
73 pub fn op(&self) -> CommandOp {
74 match self {
75 CommandRecord::Fence(_) => CommandOp::Fence,
76 CommandRecord::Trim(_) => CommandOp::Trim,
77 }
78 }
79
80 pub fn payload(&self) -> Bytes {
81 match self {
82 Self::Fence(token) => Bytes::copy_from_slice(token.as_bytes()),
83 Self::Trim(trim_point) => Bytes::copy_from_slice(&trim_point.to_be_bytes()),
84 }
85 }
86
87 pub fn try_from_parts(op: CommandOp, payload: &[u8]) -> Result<Self, CommandPayloadError> {
88 match op {
89 CommandOp::Fence => {
90 let token = CompactString::from_utf8(payload)
91 .map_err(CommandPayloadError::InvalidUtf8)?
92 .try_into()?;
93 Ok(Self::Fence(token))
94 }
95 CommandOp::Trim => {
96 let trim_point = SeqNum::from_be_bytes(
97 payload
98 .try_into()
99 .map_err(|_| CommandPayloadError::TrimPointSize(payload.len()))?,
100 );
101 Ok(Self::Trim(trim_point))
102 }
103 }
104 }
105}
106
107impl TryFrom<&[u8]> for CommandRecord {
108 type Error = RecordDecodeError;
109
110 fn try_from(record: &[u8]) -> Result<Self, Self::Error> {
111 if record.is_empty() {
112 return Err(RecordDecodeError::Truncated("CommandOrdinal"));
113 }
114 let op = CommandOp::from_ordinal(record[0])
115 .ok_or(RecordDecodeError::InvalidValue("CommandOrdinal", "unknown"))?;
116 Self::try_from_parts(op, &record[1..]).map_err(Into::into)
117 }
118}
119
120impl Encodable for CommandRecord {
121 fn encoded_size(&self) -> usize {
122 1 + match self {
123 CommandRecord::Fence(token) => token.len(),
124 CommandRecord::Trim(trim_point) => size_of_val(trim_point),
125 }
126 }
127
128 fn encode_into(&self, buf: &mut impl BufMut) {
129 buf.put_u8(self.op().ordinal());
130 match self {
131 CommandRecord::Fence(token) => {
132 buf.put_slice(token.as_bytes());
133 }
134 CommandRecord::Trim(trim_point) => {
135 buf.put_u64(*trim_point);
136 }
137 }
138 }
139}
140
141#[derive(Debug, PartialEq, thiserror::Error)]
142pub enum CommandPayloadError {
143 #[error("invalid UTF-8")]
144 InvalidUtf8(Utf8Error),
145 #[error(transparent)]
146 FencingTokenTooLong(#[from] FencingTokenTooLongError),
147 #[error("earliest sequence number to trim to was {0} bytes, must be 8")]
148 TrimPointSize(usize),
149}
150
151impl From<CommandPayloadError> for RecordDecodeError {
152 fn from(e: CommandPayloadError) -> Self {
153 match e {
154 CommandPayloadError::InvalidUtf8(_) => {
155 RecordDecodeError::InvalidValue("CommandPayload", "fencing token not valid utf8")
156 }
157 CommandPayloadError::FencingTokenTooLong(_) => {
158 RecordDecodeError::InvalidValue("CommandPayload", "fencing token too long")
159 }
160 CommandPayloadError::TrimPointSize(_) => {
161 RecordDecodeError::InvalidValue("CommandPayload", "trim point size")
162 }
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use compact_str::ToCompactString;
    use enum_ordinalize::Ordinalize;
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// A 40-byte ASCII token, long enough to be rejected as a fencing token.
    const TOO_LONG_TOKEN: &[u8] = b"0123456789012345678901234567890123456789";

    /// Asserts that `cmd` encodes to exactly `expected_len` bytes and decodes back to itself.
    fn roundtrip(cmd: CommandRecord, expected_len: usize) {
        assert_eq!(cmd.encoded_size(), expected_len);
        let encoded = cmd.to_bytes();
        assert_eq!(encoded.len(), expected_len);
        assert_eq!(CommandRecord::try_from(encoded.as_ref()), Ok(cmd));
    }

    #[test]
    fn command_op_names() {
        for cmd in CommandOp::VARIANTS {
            let name = cmd.to_id();
            assert_eq!(CommandOp::from_id(name), Some(*cmd));
        }
        assert_eq!(CommandOp::from_id(b""), None);
        assert_eq!(CommandOp::from_id(b"invalid"), None);
    }

    #[test]
    fn command_op_display_is_id() {
        assert_eq!(CommandOp::Fence.to_string(), "fence");
        assert_eq!(CommandOp::Trim.to_string(), "trim");
        for op in CommandOp::VARIANTS {
            assert_eq!(op.to_string().as_bytes(), op.to_id());
        }
    }

    #[test]
    fn fencing_token_invalid_utf8() {
        assert!(matches!(
            CommandRecord::try_from_parts(CommandOp::Fence, &[0xff]),
            Err(CommandPayloadError::InvalidUtf8(_))
        ));
    }

    #[test]
    fn fencing_token_too_long() {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, TOO_LONG_TOKEN),
            Err(CommandPayloadError::FencingTokenTooLong(
                FencingTokenTooLongError(40)
            ))
        );
    }

    #[rstest]
    #[case::empty("")]
    #[case::arbit("arbitrary")]
    #[case::full("0123456789012345")]
    fn fence_roundtrip(#[case] token: &str) {
        let cmd = CommandRecord::Fence(FencingToken::try_from(token.to_compact_string()).unwrap());
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, token.as_bytes()),
            Ok(cmd.clone())
        );
        roundtrip(cmd, 1 + token.len());
    }

    #[rstest]
    #[case::empty(b"")]
    #[case::too_small(b"0123")]
    #[case::too_big(b"0123456789")]
    fn trim_point_size(#[case] payload: &[u8]) {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Trim, payload),
            Err(CommandPayloadError::TrimPointSize(payload.len()))
        );
    }

    #[test]
    fn metered_size_is_computed_without_materializing_payload() {
        let fence =
            CommandRecord::Fence(FencingToken::try_from("fence-me".to_compact_string()).unwrap());
        assert_eq!(
            fence.metered_size(),
            8 + 2 + CommandOp::Fence.to_id().len() + "fence-me".len()
        );

        let trim = CommandRecord::Trim(42);
        assert_eq!(
            trim.metered_size(),
            8 + 2 + CommandOp::Trim.to_id().len() + size_of_val(&42u64)
        );
    }

    proptest! {
        #[test]
        fn trim_roundtrip(trim_point in any::<SeqNum>()) {
            let cmd = CommandRecord::Trim(trim_point);
            assert_eq!(
                CommandRecord::try_from_parts(CommandOp::Trim, trim_point.to_be_bytes().as_slice()),
                Ok(cmd.clone())
            );
            roundtrip(cmd, 9);
        }
    }

    #[test]
    fn decode_invalid_command() {
        let try_convert = |raw: &[u8]| CommandRecord::try_from(raw);
        assert_eq!(
            try_convert(&[]),
            Err(RecordDecodeError::Truncated("CommandOrdinal"))
        );
        assert_eq!(
            try_convert(&[0xff]),
            Err(RecordDecodeError::InvalidValue("CommandOrdinal", "unknown"))
        );
        assert_eq!(
            try_convert(&[CommandOp::Fence.ordinal(), 0xff, 0xff]),
            Err(RecordDecodeError::InvalidValue(
                "CommandPayload",
                "fencing token not valid utf8"
            ))
        );

        // Fence ordinal followed by the over-long token.
        let mut too_long = vec![CommandOp::Fence.ordinal()];
        too_long.extend_from_slice(TOO_LONG_TOKEN);
        assert_eq!(
            try_convert(too_long.as_slice()),
            Err(CommandPayloadError::FencingTokenTooLong(FencingTokenTooLongError(40)).into())
        );

        assert_eq!(
            try_convert(&[CommandOp::Trim.ordinal(), 0xff]),
            Err(CommandPayloadError::TrimPointSize(1).into())
        );
    }
}