Skip to main content

smoo_proto/
lib.rs

1#![no_std]
2
3use core::{convert::TryFrom, fmt};
4
/// ASCII magic prefix for Ident handshake messages.
pub const IDENT_MAGIC: [u8; 4] = *b"SMOO";
/// Number of bytes in an encoded [`Ident`] message (4-byte magic + two `u16` fields).
pub const IDENT_LEN: usize = 8;
/// Vendor control request opcode used to fetch [`Ident`].
pub const IDENT_REQUEST: u8 = 0x01;
/// Number of bytes in an encoded [`Request`] control message.
pub const REQUEST_LEN: usize = 28;
/// Number of bytes in an encoded [`Response`] control message.
pub const RESPONSE_LEN: usize = 28;
/// bmRequestType for CONFIG_EXPORTS.
///
/// `0x41` = `0b0_10_00001`: host → device (OUT) | type vendor | recipient interface.
pub const CONFIG_EXPORTS_REQ_TYPE: u8 = 0x41;
/// Vendor control bRequest used to apply CONFIG_EXPORTS.
pub const CONFIG_EXPORTS_REQUEST: u8 = 0x02;
/// Control bRequest used to fetch [`SmooStatusV0`] (sent with [`SMOO_STATUS_REQ_TYPE`]).
pub const SMOO_STATUS_REQUEST: u8 = 0x03;
/// bmRequestType for SMOO status/heartbeat.
///
/// `0xA1` = `0b1_01_00001`: device → host (IN) | type **class** | recipient interface.
///
/// NOTE(review): this was previously documented as a *vendor* request, but bits 6..5 of
/// `0xA1` are `01` (class); a vendor IN request to an interface would be `0xC1`. The value
/// is left unchanged because both ends of the link must agree on it — confirm which request
/// type the gadget side actually expects.
pub const SMOO_STATUS_REQ_TYPE: u8 = 0xA1;
/// Number of bytes returned by [`SmooStatusV0`].
pub const SMOO_STATUS_LEN: usize = 16;
/// Supported status payload version.
pub const SMOO_STATUS_VERSION: u16 = 0;
/// Status flag indicating an active export.
pub const SMOO_STATUS_FLAG_EXPORT_ACTIVE: u16 = 1 << 0;
29
/// Errors surfaced while decoding protocol messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProtoError {
    /// Buffer length did not match the protocol expectation.
    InvalidLength {
        /// Length the decoder required.
        expected: usize,
        /// Length that was actually supplied.
        actual: usize,
    },
    /// Incoming opcode is unsupported.
    InvalidOpcode(u8),
    /// Ident magic prefix did not match `SMOO`.
    InvalidMagic,
    /// Payload or struct version mismatch.
    InvalidVersion {
        /// Version this crate understands.
        expected: u16,
        /// Version found in the payload.
        actual: u16,
    },
    /// Field value failed validation; the string describes the violated rule.
    InvalidValue(&'static str),
}

impl fmt::Display for ProtoError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProtoError::InvalidLength { expected, actual } => {
                write!(f, "invalid message length {actual}, expected {expected}")
            }
            ProtoError::InvalidOpcode(op) => write!(f, "invalid opcode {op}"),
            ProtoError::InvalidMagic => write!(f, "invalid ident magic"),
            ProtoError::InvalidVersion { expected, actual } => write!(
                f,
                "unsupported payload version {actual}, expected {expected}"
            ),
            ProtoError::InvalidValue(field) => write!(f, "invalid field value: {field}"),
        }
    }
}

// `core::error::Error` is usable from `no_std` (stable since Rust 1.81) and is the same trait
// as `std::error::Error`, so callers can box this error or propagate it with `?` into
// generic error types. `Display` + `Debug` already exist, so no methods need overriding.
impl core::error::Error for ProtoError {}

/// Result alias for protocol parsing operations.
pub type Result<T> = core::result::Result<T, ProtoError>;
64
65/// Control-plane operations issued by ublk.
66#[derive(Clone, Copy, Debug, PartialEq, Eq)]
67#[repr(u8)]
68pub enum OpCode {
69    Read = 0,
70    Write = 1,
71    Flush = 2,
72    Discard = 3,
73}
74
75impl TryFrom<u8> for OpCode {
76    type Error = ProtoError;
77
78    fn try_from(value: u8) -> Result<Self> {
79        match value {
80            0 => Ok(Self::Read),
81            1 => Ok(Self::Write),
82            2 => Ok(Self::Flush),
83            3 => Ok(Self::Discard),
84            other => Err(ProtoError::InvalidOpcode(other)),
85        }
86    }
87}
88
89impl From<OpCode> for u8 {
90    fn from(op: OpCode) -> Self {
91        op as u8
92    }
93}
94
95/// Ident handshake sent from the gadget to the host.
96#[derive(Clone, Copy, Debug, PartialEq, Eq)]
97pub struct Ident {
98    pub major: u16,
99    pub minor: u16,
100}
101
102impl Ident {
103    pub const fn new(major: u16, minor: u16) -> Self {
104        Self { major, minor }
105    }
106
107    pub fn encode(self) -> [u8; IDENT_LEN] {
108        let mut buf = [0u8; IDENT_LEN];
109        buf[0..4].copy_from_slice(&IDENT_MAGIC);
110        buf[4..6].copy_from_slice(&self.major.to_le_bytes());
111        buf[6..8].copy_from_slice(&self.minor.to_le_bytes());
112        buf
113    }
114
115    pub fn decode(bytes: [u8; IDENT_LEN]) -> Result<Self> {
116        if bytes[0..4] != IDENT_MAGIC {
117            return Err(ProtoError::InvalidMagic);
118        }
119        let major = u16::from_le_bytes([bytes[4], bytes[5]]);
120        let minor = u16::from_le_bytes([bytes[6], bytes[7]]);
121        Ok(Self { major, minor })
122    }
123}
124
125/// Request message emitted by the gadget.
126#[derive(Clone, Copy, Debug, PartialEq, Eq)]
127pub struct Request {
128    pub export_id: u32,
129    pub request_id: u32,
130    pub op: OpCode,
131    pub lba: u64,
132    pub num_blocks: u32,
133    pub flags: u32,
134}
135
136impl Request {
137    pub const fn new(
138        export_id: u32,
139        request_id: u32,
140        op: OpCode,
141        lba: u64,
142        num_blocks: u32,
143        flags: u32,
144    ) -> Self {
145        Self {
146            export_id,
147            request_id,
148            op,
149            lba,
150            num_blocks,
151            flags,
152        }
153    }
154
155    pub fn encode(self) -> [u8; REQUEST_LEN] {
156        let mut buf = [0u8; REQUEST_LEN];
157        buf[0] = self.op.into();
158        buf[4..8].copy_from_slice(&self.request_id.to_le_bytes());
159        buf[8..12].copy_from_slice(&self.export_id.to_le_bytes());
160        buf[12..20].copy_from_slice(&self.lba.to_le_bytes());
161        buf[20..24].copy_from_slice(&self.num_blocks.to_le_bytes());
162        buf[24..28].copy_from_slice(&self.flags.to_le_bytes());
163        buf
164    }
165
166    pub fn decode(bytes: [u8; REQUEST_LEN]) -> Result<Self> {
167        let request_id = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
168        let export_id = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
169        let lba = u64::from_le_bytes([
170            bytes[12], bytes[13], bytes[14], bytes[15], bytes[16], bytes[17], bytes[18], bytes[19],
171        ]);
172        let num_blocks = u32::from_le_bytes([bytes[20], bytes[21], bytes[22], bytes[23]]);
173        let flags = u32::from_le_bytes([bytes[24], bytes[25], bytes[26], bytes[27]]);
174        let op = OpCode::try_from(bytes[0])?;
175        Ok(Self {
176            export_id,
177            request_id,
178            op,
179            lba,
180            num_blocks,
181            flags,
182        })
183    }
184}
185
186impl TryFrom<&[u8]> for Request {
187    type Error = ProtoError;
188
189    fn try_from(value: &[u8]) -> Result<Self> {
190        if value.len() != REQUEST_LEN {
191            return Err(ProtoError::InvalidLength {
192                expected: REQUEST_LEN,
193                actual: value.len(),
194            });
195        }
196        let mut buf = [0u8; REQUEST_LEN];
197        buf.copy_from_slice(value);
198        Self::decode(buf)
199    }
200}
201
202/// Response message sent back by the host.
203#[derive(Clone, Copy, Debug, PartialEq, Eq)]
204pub struct Response {
205    pub export_id: u32,
206    pub request_id: u32,
207    pub op: OpCode,
208    pub status: u8,
209    pub lba: u64,
210    pub num_blocks: u32,
211    pub flags: u32,
212}
213
214impl Response {
215    pub const fn new(
216        export_id: u32,
217        request_id: u32,
218        op: OpCode,
219        status: u8,
220        lba: u64,
221        num_blocks: u32,
222        flags: u32,
223    ) -> Self {
224        Self {
225            export_id,
226            request_id,
227            op,
228            status,
229            lba,
230            num_blocks,
231            flags,
232        }
233    }
234
235    pub fn encode(self) -> [u8; RESPONSE_LEN] {
236        let mut buf = [0u8; RESPONSE_LEN];
237        buf[0] = self.op.into();
238        buf[1] = self.status;
239        buf[4..8].copy_from_slice(&self.request_id.to_le_bytes());
240        buf[8..12].copy_from_slice(&self.export_id.to_le_bytes());
241        buf[12..20].copy_from_slice(&self.lba.to_le_bytes());
242        buf[20..24].copy_from_slice(&self.num_blocks.to_le_bytes());
243        buf[24..28].copy_from_slice(&self.flags.to_le_bytes());
244        buf
245    }
246
247    pub fn decode(bytes: [u8; RESPONSE_LEN]) -> Result<Self> {
248        let request_id = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
249        let export_id = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
250        let lba = u64::from_le_bytes([
251            bytes[12], bytes[13], bytes[14], bytes[15], bytes[16], bytes[17], bytes[18], bytes[19],
252        ]);
253        let num_blocks = u32::from_le_bytes([bytes[20], bytes[21], bytes[22], bytes[23]]);
254        let flags = u32::from_le_bytes([bytes[24], bytes[25], bytes[26], bytes[27]]);
255        let op = OpCode::try_from(bytes[0])?;
256        let status = bytes[1];
257        Ok(Self {
258            export_id,
259            request_id,
260            op,
261            lba,
262            num_blocks,
263            flags,
264            status,
265        })
266    }
267}
268
269impl TryFrom<&[u8]> for Response {
270    type Error = ProtoError;
271
272    fn try_from(value: &[u8]) -> Result<Self> {
273        if value.len() != RESPONSE_LEN {
274            return Err(ProtoError::InvalidLength {
275                expected: RESPONSE_LEN,
276                actual: value.len(),
277            });
278        }
279        let mut buf = [0u8; RESPONSE_LEN];
280        buf.copy_from_slice(value);
281        Self::decode(buf)
282    }
283}
284
285/// Heartbeat/status payload returned by the gadget.
286#[derive(Clone, Copy, Debug, PartialEq, Eq)]
287pub struct SmooStatusV0 {
288    pub version: u16,
289    pub flags: u16,
290    pub export_count: u32,
291    pub session_id: u64,
292}
293
294impl SmooStatusV0 {
295    /// Create a v0 status payload with the provided flags/count/session.
296    pub const fn new(flags: u16, export_count: u32, session_id: u64) -> Self {
297        Self {
298            version: SMOO_STATUS_VERSION,
299            flags,
300            export_count,
301            session_id,
302        }
303    }
304
305    /// Serialize the status payload to its on-wire representation.
306    pub fn encode(self) -> [u8; SMOO_STATUS_LEN] {
307        let mut buf = [0u8; SMOO_STATUS_LEN];
308        buf[0..2].copy_from_slice(&self.version.to_le_bytes());
309        buf[2..4].copy_from_slice(&self.flags.to_le_bytes());
310        buf[4..8].copy_from_slice(&self.export_count.to_le_bytes());
311        buf[8..16].copy_from_slice(&self.session_id.to_le_bytes());
312        buf
313    }
314
315    /// Decode a status payload from a fixed-size buffer.
316    pub fn decode(bytes: [u8; SMOO_STATUS_LEN]) -> Result<Self> {
317        let version = u16::from_le_bytes([bytes[0], bytes[1]]);
318        if version != SMOO_STATUS_VERSION {
319            return Err(ProtoError::InvalidVersion {
320                expected: SMOO_STATUS_VERSION,
321                actual: version,
322            });
323        }
324        let flags = u16::from_le_bytes([bytes[2], bytes[3]]);
325        let export_count = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
326        let session_id = u64::from_le_bytes([
327            bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
328        ]);
329        Ok(Self {
330            version,
331            flags,
332            export_count,
333            session_id,
334        })
335    }
336
337    /// Decode a status payload from a borrowed slice.
338    pub fn try_from_slice(slice: &[u8]) -> Result<Self> {
339        if slice.len() != SMOO_STATUS_LEN {
340            return Err(ProtoError::InvalidLength {
341                expected: SMOO_STATUS_LEN,
342                actual: slice.len(),
343            });
344        }
345        let mut buf = [0u8; SMOO_STATUS_LEN];
346        buf.copy_from_slice(slice);
347        Self::decode(buf)
348    }
349
350    /// Returns true when the export_active flag is set.
351    pub fn export_active(&self) -> bool {
352        (self.flags & SMOO_STATUS_FLAG_EXPORT_ACTIVE) != 0
353    }
354}
355
356/// Encoded CONFIG_EXPORTS payload for protocol version 0.
357#[derive(Clone, Debug, PartialEq, Eq)]
358pub struct ConfigExportsV0 {
359    entries: heapless::Vec<ConfigExport, 32>,
360}
361
362impl ConfigExportsV0 {
363    /// Supported CONFIG_EXPORTS payload version.
364    pub const VERSION: u16 = 0;
365    /// Number of bytes in the payload header.
366    pub const HEADER_LEN: usize = 8;
367    /// Number of bytes in each entry.
368    pub const ENTRY_LEN: usize = 24;
369    /// Maximum exports supported in one payload.
370    pub const MAX_EXPORTS: usize = 32;
371
372    pub fn new(entries: heapless::Vec<ConfigExport, 32>) -> Result<Self> {
373        Ok(Self { entries })
374    }
375
376    pub fn entries(&self) -> &[ConfigExport] {
377        &self.entries
378    }
379
380    pub fn from_slice(entries: &[ConfigExport]) -> Result<Self> {
381        let mut vec: heapless::Vec<ConfigExport, 32> = heapless::Vec::new();
382        for entry in entries {
383            vec.push(*entry)
384                .map_err(|_| ProtoError::InvalidValue("too many exports"))?;
385        }
386        Self::new(vec)
387    }
388
389    /// Serialize the payload to its wire representation.
390    pub fn encode(
391        &self,
392    ) -> heapless::Vec<u8, { Self::HEADER_LEN + Self::ENTRY_LEN * Self::MAX_EXPORTS }> {
393        let mut buf: heapless::Vec<u8, { Self::HEADER_LEN + Self::ENTRY_LEN * Self::MAX_EXPORTS }> =
394            heapless::Vec::new();
395        buf.resize(Self::HEADER_LEN + self.entries.len() * Self::ENTRY_LEN, 0)
396            .unwrap();
397        buf[0..2].copy_from_slice(&Self::VERSION.to_le_bytes());
398        buf[2..4].copy_from_slice(&(self.entries.len() as u16).to_le_bytes());
399        for (idx, entry) in self.entries.iter().enumerate() {
400            let offset = Self::HEADER_LEN + idx * Self::ENTRY_LEN;
401            buf[offset..offset + 4].copy_from_slice(&entry.export_id.to_le_bytes());
402            buf[offset + 4..offset + 8].copy_from_slice(&entry.block_size.to_le_bytes());
403            buf[offset + 8..offset + 16].copy_from_slice(&entry.size_bytes.to_le_bytes());
404        }
405        buf
406    }
407
408    /// Decode a CONFIG_EXPORTS payload from a borrowed slice.
409    pub fn try_from_slice(bytes: &[u8]) -> Result<Self> {
410        if bytes.len() < Self::HEADER_LEN {
411            return Err(ProtoError::InvalidLength {
412                expected: Self::HEADER_LEN,
413                actual: bytes.len(),
414            });
415        }
416        let version = u16::from_le_bytes([bytes[0], bytes[1]]);
417        if version != Self::VERSION {
418            return Err(ProtoError::InvalidVersion {
419                expected: Self::VERSION,
420                actual: version,
421            });
422        }
423        let count = u16::from_le_bytes([bytes[2], bytes[3]]) as usize;
424        if count > Self::MAX_EXPORTS {
425            return Err(ProtoError::InvalidValue(
426                "CONFIG_EXPORTS count exceeds maximum",
427            ));
428        }
429        let flags = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
430        if flags != 0 {
431            return Err(ProtoError::InvalidValue(
432                "CONFIG_EXPORTS header flags must be zero",
433            ));
434        }
435        let expected_len = Self::HEADER_LEN + count * Self::ENTRY_LEN;
436        if bytes.len() != expected_len {
437            return Err(ProtoError::InvalidLength {
438                expected: expected_len,
439                actual: bytes.len(),
440            });
441        }
442        let mut entries: heapless::Vec<ConfigExport, 32> = heapless::Vec::new();
443        for idx in 0..count {
444            let offset = Self::HEADER_LEN + idx * Self::ENTRY_LEN;
445            let export_id = u32::from_le_bytes([
446                bytes[offset],
447                bytes[offset + 1],
448                bytes[offset + 2],
449                bytes[offset + 3],
450            ]);
451            let block_size = u32::from_le_bytes([
452                bytes[offset + 4],
453                bytes[offset + 5],
454                bytes[offset + 6],
455                bytes[offset + 7],
456            ]);
457            let size_bytes = u64::from_le_bytes([
458                bytes[offset + 8],
459                bytes[offset + 9],
460                bytes[offset + 10],
461                bytes[offset + 11],
462                bytes[offset + 12],
463                bytes[offset + 13],
464                bytes[offset + 14],
465                bytes[offset + 15],
466            ]);
467            entries
468                .push(validate_export(export_id, block_size, size_bytes)?)
469                .map_err(|_| ProtoError::InvalidValue("too many exports"))?;
470        }
471        Ok(Self { entries })
472    }
473}
474
/// One export entry inside a v0 CONFIG_EXPORTS payload.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ConfigExport {
    /// Export identifier; the decoder's validation rejects zero.
    pub export_id: u32,
    /// Block size in bytes; validation requires a power of two in `512..=65536`.
    pub block_size: u32,
    /// Export size in bytes; validation requires a multiple of `block_size`.
    pub size_bytes: u64,
}
482
483fn validate_export(export_id: u32, block_size: u32, size_bytes: u64) -> Result<ConfigExport> {
484    if export_id == 0 {
485        return Err(ProtoError::InvalidValue("export_id must be non-zero"));
486    }
487    if !block_size.is_power_of_two() {
488        return Err(ProtoError::InvalidValue("block size must be power-of-two"));
489    }
490    if !(512..=65536).contains(&block_size) {
491        return Err(ProtoError::InvalidValue(
492            "block size out of supported range",
493        ));
494    }
495    if size_bytes != 0 && !size_bytes.is_multiple_of(block_size as u64) {
496        return Err(ProtoError::InvalidValue(
497            "size_bytes must be multiple of block_size",
498        ));
499    }
500    Ok(ConfigExport {
501        export_id,
502        block_size,
503        size_bytes,
504    })
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ident_round_trip() {
        let ident = Ident::new(1, 2);
        let bytes = ident.encode();
        assert_eq!(Ident::decode(bytes).unwrap(), ident);
    }

    #[test]
    fn ident_magic_guard() {
        let mut bytes = Ident::new(1, 2).encode();
        bytes[0] = b'X';
        assert!(matches!(
            Ident::decode(bytes),
            Err(ProtoError::InvalidMagic)
        ));
    }

    #[test]
    fn request_round_trip() {
        let req = Request::new(2, 99, OpCode::Write, 42, 8, 0xAA55AA55);
        let bytes = req.encode();
        assert_eq!(Request::decode(bytes).unwrap(), req);
        assert_eq!(Request::try_from(bytes.as_slice()).unwrap(), req);
    }

    #[test]
    fn response_round_trip() {
        let resp = Response::new(3, 77, OpCode::Read, 0, 9001, 16, 0);
        let bytes = resp.encode();
        assert_eq!(Response::decode(bytes).unwrap(), resp);
        assert_eq!(Response::try_from(bytes.as_slice()).unwrap(), resp);
    }

    #[test]
    fn response_preserves_status_byte() {
        let resp = Response::new(3, 77, OpCode::Flush, 0xE5, 0, 0, 0);
        let bytes = resp.encode();
        assert_eq!(bytes[1], 0xE5);
        assert_eq!(Response::decode(bytes).unwrap().status, 0xE5);
    }

    #[test]
    fn status_round_trip() {
        let status = SmooStatusV0::new(SMOO_STATUS_FLAG_EXPORT_ACTIVE, 1, 0x0102_0304_0506_0708);
        let bytes = status.encode();
        assert_eq!(SmooStatusV0::try_from_slice(&bytes).unwrap(), status);
        assert!(SmooStatusV0::decode(bytes).unwrap().export_active());
    }

    #[test]
    fn status_version_guard() {
        let mut bytes = SmooStatusV0::new(0, 0, 0).encode();
        bytes[0] = 1;
        assert_eq!(
            SmooStatusV0::decode(bytes),
            Err(ProtoError::InvalidVersion {
                expected: SMOO_STATUS_VERSION,
                actual: 1
            })
        );
    }

    #[test]
    fn status_invalid_len() {
        assert_eq!(
            SmooStatusV0::try_from_slice(&[0u8; 15]),
            Err(ProtoError::InvalidLength {
                expected: SMOO_STATUS_LEN,
                actual: 15
            })
        );
    }

    #[test]
    fn bad_opcode() {
        let mut bytes = Request::new(1, 2, OpCode::Flush, 0, 0, 0).encode();
        bytes[0] = 0xFF;
        assert!(matches!(
            Request::decode(bytes),
            Err(ProtoError::InvalidOpcode(0xFF))
        ));
    }

    #[test]
    fn invalid_len() {
        assert!(matches!(
            Request::try_from(&[0u8; 27][..]),
            Err(ProtoError::InvalidLength {
                expected: 28,
                actual: 27
            })
        ));
    }

    #[test]
    fn config_exports_zero_round_trip() {
        let payload = ConfigExportsV0::new(heapless::Vec::new()).unwrap();
        let encoded = payload.encode();
        let decoded = ConfigExportsV0::try_from_slice(&encoded).unwrap();
        assert!(decoded.entries().is_empty());
    }

    #[test]
    fn config_exports_single_round_trip() {
        let mut entries = heapless::Vec::new();
        entries
            .push(ConfigExport {
                export_id: 7,
                block_size: 4096,
                size_bytes: 4096 * 8,
            })
            .unwrap();
        let payload = ConfigExportsV0::new(entries).unwrap();
        let encoded = payload.encode();
        let decoded = ConfigExportsV0::try_from_slice(&encoded).unwrap();
        let export = decoded.entries().first().unwrap();
        assert_eq!(export.export_id, 7);
        assert_eq!(export.block_size, 4096);
        assert_eq!(export.size_bytes, 4096 * 8);
    }

    #[test]
    fn config_exports_invalid_flags() {
        let mut encoded = ConfigExportsV0::new(heapless::Vec::new()).unwrap().encode();
        encoded[4] = 1;
        assert!(matches!(
            ConfigExportsV0::try_from_slice(&encoded),
            Err(ProtoError::InvalidValue(_))
        ));
    }

    #[test]
    fn config_exports_invalid_block_size() {
        let mut entries = heapless::Vec::new();
        entries
            .push(ConfigExport {
                export_id: 1,
                block_size: 1024,
                size_bytes: 0,
            })
            .unwrap();
        let mut encoded = ConfigExportsV0::new(entries).unwrap().encode();
        // Corrupt the first entry's block_size field (entry bytes 4..8) to a non-power-of-two.
        // The header (including its flags word) is left intact so that the entry validator,
        // not the header flags check, is what rejects the payload.
        let block_size_at = ConfigExportsV0::HEADER_LEN + 4;
        encoded[block_size_at..block_size_at + 4].copy_from_slice(&500u32.to_le_bytes());
        assert_eq!(
            ConfigExportsV0::try_from_slice(&encoded),
            Err(ProtoError::InvalidValue("block size must be power-of-two"))
        );
    }

    #[test]
    fn config_exports_misaligned_size() {
        let mut entries = heapless::Vec::new();
        entries
            .push(ConfigExport {
                export_id: 3,
                block_size: 512,
                size_bytes: 512,
            })
            .unwrap();
        let mut encoded = ConfigExportsV0::new(entries).unwrap().encode();
        let size_at = ConfigExportsV0::HEADER_LEN + 8;
        encoded[size_at..size_at + 8].copy_from_slice(&513u64.to_le_bytes());
        assert_eq!(
            ConfigExportsV0::try_from_slice(&encoded),
            Err(ProtoError::InvalidValue(
                "size_bytes must be multiple of block_size"
            ))
        );
    }

    #[test]
    fn config_exports_count_exceeds_max() {
        let mut encoded = ConfigExportsV0::new(heapless::Vec::new()).unwrap().encode();
        let too_many = (ConfigExportsV0::MAX_EXPORTS as u16 + 1).to_le_bytes();
        encoded[2..4].copy_from_slice(&too_many);
        assert!(matches!(
            ConfigExportsV0::try_from_slice(&encoded),
            Err(ProtoError::InvalidValue(_))
        ));
    }

    #[test]
    fn config_exports_length_mismatch() {
        let mut encoded = ConfigExportsV0::new(heapless::Vec::new()).unwrap().encode();
        assert_eq!(
            ConfigExportsV0::try_from_slice(&encoded[..4]),
            Err(ProtoError::InvalidLength {
                expected: ConfigExportsV0::HEADER_LEN,
                actual: 4
            })
        );
        // Header claims one entry but no entry bytes follow.
        encoded[2..4].copy_from_slice(&1u16.to_le_bytes());
        assert_eq!(
            ConfigExportsV0::try_from_slice(&encoded),
            Err(ProtoError::InvalidLength {
                expected: ConfigExportsV0::HEADER_LEN + ConfigExportsV0::ENTRY_LEN,
                actual: ConfigExportsV0::HEADER_LEN
            })
        );
    }
}