1use std::{fmt, str::Utf8Error};
2
3use bytes::{BufMut, Bytes};
4use compact_str::CompactString;
5use enum_ordinalize::Ordinalize;
6
7use super::{Encodable, FencingTokenTooLongError, InternalRecordError, fencing::FencingToken};
8use crate::{deep_size::DeepSize, record::SeqNum};
9
/// Byte-string identifier of [`CommandOp::Fence`], as returned by `CommandOp::to_id`,
/// accepted by `CommandOp::from_id`, and written by `CommandOp`'s `Display` impl.
pub const COMMAND_ID_FENCE: &[u8] = b"fence";
/// Byte-string identifier of [`CommandOp::Trim`], as returned by `CommandOp::to_id`,
/// accepted by `CommandOp::from_id`, and written by `CommandOp`'s `Display` impl.
pub const COMMAND_ID_TRIM: &[u8] = b"trim";
12
/// The kind of a [`CommandRecord`].
///
/// Each op has two identities:
/// - a byte-string id (`to_id` / `from_id`, also used by `Display`), and
/// - its `Ordinalize` ordinal, which `CommandRecord`'s `Encodable` impl writes as the first
///   byte of an encoded record and its `TryFrom<&[u8]>` impl reads back.
///
/// No explicit discriminants are given, so ordinals follow declaration order. Reordering or
/// inserting variants would therefore change how existing records decode.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Ordinalize)]
#[repr(u8)]
pub enum CommandOp {
    /// Op for [`CommandRecord::Fence`].
    Fence,
    /// Op for [`CommandRecord::Trim`].
    Trim,
}
19
20impl CommandOp {
21 pub fn to_id(self) -> &'static [u8] {
22 match self {
23 Self::Fence => COMMAND_ID_FENCE,
24 Self::Trim => COMMAND_ID_TRIM,
25 }
26 }
27
28 pub fn from_id(name: &[u8]) -> Option<Self> {
29 match name {
30 COMMAND_ID_FENCE => Some(Self::Fence),
31 COMMAND_ID_TRIM => Some(Self::Trim),
32 _ => None,
33 }
34 }
35}
36
37impl fmt::Display for CommandOp {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 let name = std::str::from_utf8(self.to_id()).map_err(|_| fmt::Error)?;
40 f.write_str(name)
41 }
42}
43
/// A decoded command record.
///
/// Its `Encodable` impl encodes it as one [`CommandOp`] ordinal byte followed by the
/// variant's payload.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum CommandRecord {
    /// Fence command. Its payload is the token's bytes, which must be valid UTF-8 and
    /// accepted by `FencingToken`'s `TryFrom<CompactString>` (length-limited).
    Fence(FencingToken),
    /// Trim command carrying a trim point (described in `CommandPayloadError::TrimPointSize`
    /// as the "earliest sequence number to trim to"). Its payload is the value's big-endian
    /// bytes.
    Trim(SeqNum),
}
49
50impl DeepSize for CommandRecord {
51 fn deep_size(&self) -> usize {
52 match self {
53 Self::Fence(token) => token.deep_size(),
54 Self::Trim(seq_num) => seq_num.deep_size(),
55 }
56 }
57}
58
59impl CommandRecord {
60 pub fn op(&self) -> CommandOp {
61 match self {
62 CommandRecord::Fence(_) => CommandOp::Fence,
63 CommandRecord::Trim(_) => CommandOp::Trim,
64 }
65 }
66
67 pub fn payload(&self) -> Bytes {
68 match self {
69 Self::Fence(token) => Bytes::copy_from_slice(token.as_bytes()),
70 Self::Trim(trim_point) => Bytes::copy_from_slice(&trim_point.to_be_bytes()),
71 }
72 }
73
74 pub fn try_from_parts(op: CommandOp, payload: &[u8]) -> Result<Self, CommandPayloadError> {
75 match op {
76 CommandOp::Fence => {
77 let token = CompactString::from_utf8(payload)
78 .map_err(CommandPayloadError::InvalidUtf8)?
79 .try_into()?;
80 Ok(Self::Fence(token))
81 }
82 CommandOp::Trim => {
83 let trim_point = SeqNum::from_be_bytes(
84 payload
85 .try_into()
86 .map_err(|_| CommandPayloadError::TrimPointSize(payload.len()))?,
87 );
88 Ok(Self::Trim(trim_point))
89 }
90 }
91 }
92}
93
94impl TryFrom<&[u8]> for CommandRecord {
95 type Error = InternalRecordError;
96
97 fn try_from(record: &[u8]) -> Result<Self, Self::Error> {
98 if record.is_empty() {
99 return Err(InternalRecordError::Truncated("CommandOrdinal"));
100 }
101 let op = CommandOp::from_ordinal(record[0]).ok_or(InternalRecordError::InvalidValue(
102 "CommandOrdinal",
103 "unknown",
104 ))?;
105 Self::try_from_parts(op, &record[1..]).map_err(Into::into)
106 }
107}
108
109impl Encodable for CommandRecord {
110 fn encoded_size(&self) -> usize {
111 1 + match self {
112 CommandRecord::Fence(token) => token.len(),
113 CommandRecord::Trim(trim_point) => size_of_val(trim_point),
114 }
115 }
116
117 fn encode_into(&self, buf: &mut impl BufMut) {
118 buf.put_u8(self.op().ordinal());
119 match self {
120 CommandRecord::Fence(token) => {
121 buf.put_slice(token.as_bytes());
122 }
123 CommandRecord::Trim(trim_point) => {
124 buf.put_u64(*trim_point);
125 }
126 }
127 }
128}
129
130#[derive(Debug, PartialEq, thiserror::Error)]
131pub enum CommandPayloadError {
132 #[error("invalid UTF-8")]
133 InvalidUtf8(Utf8Error),
134 #[error(transparent)]
135 FencingTokenTooLong(#[from] FencingTokenTooLongError),
136 #[error("earliest sequence number to trim to was {0} bytes, must be 8")]
137 TrimPointSize(usize),
138}
139
140impl From<CommandPayloadError> for InternalRecordError {
141 fn from(e: CommandPayloadError) -> Self {
142 match e {
143 CommandPayloadError::InvalidUtf8(_) => {
144 InternalRecordError::InvalidValue("CommandPayload", "fencing token not valid utf8")
145 }
146 CommandPayloadError::FencingTokenTooLong(_) => {
147 InternalRecordError::InvalidValue("CommandPayload", "fencing token too long")
148 }
149 CommandPayloadError::TrimPointSize(_) => {
150 InternalRecordError::InvalidValue("CommandPayload", "trim point size")
151 }
152 }
153 }
154}
155
#[cfg(test)]
mod tests {
    use compact_str::ToCompactString;
    use enum_ordinalize::Ordinalize;
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// A 40-byte ASCII token. The expectations below require it to be rejected with
    /// `FencingTokenTooLongError(40)`.
    const TOO_LONG_TOKEN: &[u8] = b"0123456789012345678901234567890123456789";

    /// Asserts that `cmd` reports and produces `expected_len` encoded bytes, and that
    /// decoding those bytes yields `cmd` again.
    fn assert_roundtrip(cmd: CommandRecord, expected_len: usize) {
        assert_eq!(cmd.encoded_size(), expected_len);
        let bytes = cmd.to_bytes();
        assert_eq!(bytes.len(), expected_len);
        assert_eq!(CommandRecord::try_from(bytes.as_ref()), Ok(cmd));
    }

    #[test]
    fn command_op_names() {
        for &op in CommandOp::VARIANTS {
            assert_eq!(CommandOp::from_id(op.to_id()), Some(op));
        }
        assert_eq!(CommandOp::from_id(b""), None);
        assert_eq!(CommandOp::from_id(b"invalid"), None);
    }

    #[test]
    fn fencing_token_invalid_utf8() {
        let result = CommandRecord::try_from_parts(CommandOp::Fence, &[0xff]);
        assert!(matches!(result, Err(CommandPayloadError::InvalidUtf8(_))));
    }

    #[test]
    fn fencing_token_too_long() {
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, TOO_LONG_TOKEN),
            Err(CommandPayloadError::FencingTokenTooLong(
                FencingTokenTooLongError(40)
            ))
        );
    }

    #[rstest]
    #[case::empty("")]
    #[case::arbit("arbitrary")]
    #[case::full("0123456789012345")]
    fn fence_roundtrip(#[case] token: &str) {
        let fencing_token = FencingToken::try_from(token.to_compact_string()).unwrap();
        let cmd = CommandRecord::Fence(fencing_token);
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Fence, token.as_bytes()),
            Ok(cmd.clone())
        );
        assert_roundtrip(cmd, token.len() + 1);
    }

    #[rstest]
    #[case::empty(b"")]
    #[case::too_small(b"0123")]
    #[case::too_big(b"0123456789")]
    fn trim_point_size(#[case] payload: &[u8]) {
        let expected = Err(CommandPayloadError::TrimPointSize(payload.len()));
        assert_eq!(
            CommandRecord::try_from_parts(CommandOp::Trim, payload),
            expected
        );
    }

    proptest! {
        #[test]
        fn trim_roundtrip(trim_point in any::<SeqNum>()) {
            let cmd = CommandRecord::Trim(trim_point);
            let be_bytes = trim_point.to_be_bytes();
            assert_eq!(
                CommandRecord::try_from_parts(CommandOp::Trim, be_bytes.as_slice()),
                Ok(cmd.clone())
            );
            assert_roundtrip(cmd, 9);
        }
    }

    #[test]
    fn decode_invalid_command() {
        fn decode(raw: &[u8]) -> Result<CommandRecord, InternalRecordError> {
            CommandRecord::try_from(raw)
        }

        assert_eq!(
            decode(&[]),
            Err(InternalRecordError::Truncated("CommandOrdinal"))
        );
        assert_eq!(
            decode(&[0xff]),
            Err(InternalRecordError::InvalidValue(
                "CommandOrdinal",
                "unknown"
            ))
        );
        assert_eq!(
            decode(&[CommandOp::Fence.ordinal(), 0xff, 0xff]),
            Err(InternalRecordError::InvalidValue(
                "CommandPayload",
                "fencing token not valid utf8"
            ))
        );

        // Op byte followed by the same 40-byte token used in `fencing_token_too_long`.
        let too_long = [&[CommandOp::Fence.ordinal()][..], TOO_LONG_TOKEN].concat();
        assert_eq!(
            decode(&too_long),
            Err(CommandPayloadError::FencingTokenTooLong(FencingTokenTooLongError(40)).into())
        );

        assert_eq!(
            decode(&[CommandOp::Trim.ordinal(), 0xff]),
            Err(CommandPayloadError::TrimPointSize(1).into())
        );
    }
}