//! `smoo_gadget_core` — FunctionFS gadget driver: ep0 control-plane parsing and
//! interrupt/bulk data-plane transfer for the smoo USB gadget.
1use anyhow::{Context, Result, anyhow, ensure};
2use dma_heap::HeapKind;
3use smoo_proto::{
4    CONFIG_EXPORTS_REQ_TYPE, CONFIG_EXPORTS_REQUEST, IDENT_LEN, IDENT_REQUEST, Ident, RESPONSE_LEN,
5    Request, Response, SMOO_STATUS_FLAG_EXPORT_ACTIVE, SMOO_STATUS_LEN, SMOO_STATUS_REQ_TYPE,
6    SMOO_STATUS_REQUEST, SmooStatusV0,
7};
8#[cfg(feature = "metrics")]
9use std::time::Instant;
10use std::{
11    cmp,
12    fs::File as StdFile,
13    io,
14    os::fd::{AsRawFd, OwnedFd, RawFd},
15    sync::Arc,
16};
17use tokio::{
18    fs::File,
19    io::{AsyncReadExt, AsyncWriteExt},
20    sync::Mutex,
21    task,
22};
23use tracing::trace;
24
25mod dma;
26mod link;
27#[cfg(feature = "metrics")]
28pub mod metrics;
29mod pump;
30mod runtime;
31mod state_store;
32
33use crate::dma::{BufferPool, dmabuf_transfer_blocking};
34pub use link::{LinkCommand, LinkController, LinkOfflineReason, LinkState};
35#[cfg(feature = "metrics")]
36pub use metrics::*;
37pub use pump::{IoPumpHandle, IoWork};
38pub use runtime::{
39    DeviceHandle, ExportController, ExportReconcileContext, ExportState, GadgetRuntime,
40    RuntimeTunables,
41};
42pub use smoo_gadget_ublk::{
43    SmooUblk, SmooUblkDevice, UblkBuffer, UblkIoRequest, UblkOp, UblkQueueRuntime,
44};
45pub use smoo_proto::{ConfigExport, ConfigExportsV0};
46pub use state_store::{ExportFlags, ExportSpec, PersistedExportRecord, StateStore};
47
// bmRequestType bit fields (USB 2.0 §9.3): bit 7 = direction, bits 6..5 = type,
// bits 4..0 = recipient.
const USB_DIR_IN: u8 = 0x80;
const USB_TYPE_VENDOR: u8 = 0x40;
const USB_RECIP_INTERFACE: u8 = 0x01;
/// bmRequestType for vendor IN requests addressed to the interface (0xC1).
const SMOO_REQ_TYPE: u8 = USB_DIR_IN | USB_TYPE_VENDOR | USB_RECIP_INTERFACE;

/// A USB control-transfer SETUP stage is always exactly 8 bytes.
const SETUP_STAGE_LEN: usize = 8;
54
/// File descriptor bundle for a FunctionFS interface (data-plane endpoints only).
pub struct FunctionfsEndpoints {
    pub interrupt_in: OwnedFd,
    pub interrupt_out: OwnedFd,
    pub bulk_in: OwnedFd,
    pub bulk_out: OwnedFd,
}

impl FunctionfsEndpoints {
    /// Bundle the four data-plane endpoint descriptors into one value.
    pub fn new(
        interrupt_in: OwnedFd,
        interrupt_out: OwnedFd,
        bulk_in: OwnedFd,
        bulk_out: OwnedFd,
    ) -> Self {
        FunctionfsEndpoints {
            interrupt_in,
            interrupt_out,
            bulk_in,
            bulk_out,
        }
    }
}
78
/// Decoded USB control request observed on ep0.
#[derive(Clone, Copy, Debug)]
pub struct SetupPacket {
    request_type: u8, // bmRequestType: direction | type | recipient bit fields
    request: u8,      // bRequest
    value: u16,       // wValue (little-endian on the wire)
    index: u16,       // wIndex
    length: u16,      // wLength: byte count of the data stage
}
88
89impl SetupPacket {
90    /// Construct a SetupPacket from raw USB control fields.
91    pub fn from_fields(request_type: u8, request: u8, value: u16, index: u16, length: u16) -> Self {
92        let bytes = [
93            request_type,
94            request,
95            value.to_le_bytes()[0],
96            value.to_le_bytes()[1],
97            index.to_le_bytes()[0],
98            index.to_le_bytes()[1],
99            length.to_le_bytes()[0],
100            length.to_le_bytes()[1],
101        ];
102        Self::from_bytes(bytes)
103    }
104
105    fn from_bytes(bytes: [u8; SETUP_STAGE_LEN]) -> Self {
106        Self {
107            request_type: bytes[0],
108            request: bytes[1],
109            value: u16::from_le_bytes([bytes[2], bytes[3]]),
110            index: u16::from_le_bytes([bytes[4], bytes[5]]),
111            length: u16::from_le_bytes([bytes[6], bytes[7]]),
112        }
113    }
114
115    /// bmRequestType
116    pub fn request_type(&self) -> u8 {
117        self.request_type
118    }
119
120    /// bRequest
121    pub fn request(&self) -> u8 {
122        self.request
123    }
124
125    /// wValue
126    pub fn value(&self) -> u16 {
127        self.value
128    }
129
130    /// wIndex
131    pub fn index(&self) -> u16 {
132        self.index
133    }
134
135    /// wLength
136    pub fn length(&self) -> u16 {
137        self.length
138    }
139
140    fn direction(&self) -> ControlDirection {
141        if self.request_type & USB_DIR_IN != 0 {
142            ControlDirection::In
143        } else {
144            ControlDirection::Out
145        }
146    }
147}
148
/// Gadget configuration parameters that stay constant while the device is active.
#[derive(Clone, Copy)]
pub struct GadgetConfig {
    // Identity blob advertised via GET_IDENT.
    pub ident: Ident,
    // ublk queue geometry; also sizes the DMA buffer pool (count * depth).
    pub queue_count: u16,
    pub queue_depth: u16,
    // Maximum size of a single bulk transfer buffer.
    pub max_io_bytes: usize,
    // When Some, bulk transfers use DMA-BUF buffers from this heap.
    pub dma_heap: Option<DmaHeap>,
}
158
159impl GadgetConfig {
160    pub fn new(
161        ident: Ident,
162        queue_count: u16,
163        queue_depth: u16,
164        max_io_bytes: usize,
165        dma_heap: Option<DmaHeap>,
166    ) -> Self {
167        Self {
168            ident,
169            queue_count,
170            queue_depth,
171            max_io_bytes,
172            dma_heap,
173        }
174    }
175}
176
/// Selects which Linux DMA-BUF heap backs the bulk transfer buffer pool.
#[derive(Clone, Copy, Debug)]
pub enum DmaHeap {
    /// The system DMA heap.
    System,
    /// The CMA (contiguous memory allocator) heap.
    Cma,
    /// A custom heap node at `/dev/dma_heap/reserved`.
    Reserved,
}
183
184impl DmaHeap {
185    fn to_heap_kind(self) -> HeapKind {
186        match self {
187            DmaHeap::System => HeapKind::System,
188            DmaHeap::Cma => HeapKind::Cma,
189            DmaHeap::Reserved => {
190                HeapKind::Custom(std::path::PathBuf::from("/dev/dma_heap/reserved"))
191            }
192        }
193    }
194}
195
/// High-level FunctionFS gadget driver.
pub struct SmooGadget {
    // Owns the interrupt/bulk endpoints and optional DMA buffer pool.
    data_plane: GadgetDataPlane,
    // Static identity returned to the host on GET_IDENT.
    ident: Ident,
}
201
/// Abstraction over ep0 control-transfer I/O so [`GadgetControl`] can be driven
/// by the real FunctionFS endpoint or by a test double.
#[async_trait::async_trait]
pub trait ControlIo {
    /// Write the data stage of an IN control transfer back to the host.
    async fn write_in(&mut self, data: &[u8]) -> Result<()>;
    /// Read the data stage of an OUT control transfer into `buf`
    /// (implementations are expected to fill it completely).
    async fn read_out(&mut self, buf: &mut [u8]) -> Result<()>;
    /// Signal a protocol STALL for an unsupported request.
    async fn stall(&mut self) -> Result<()>;
}
208
209impl SmooGadget {
210    pub fn new(endpoints: FunctionfsEndpoints, config: GadgetConfig) -> Result<Self> {
211        let FunctionfsEndpoints {
212            interrupt_in,
213            interrupt_out,
214            bulk_in,
215            bulk_out,
216        } = endpoints;
217        Ok(Self {
218            data_plane: GadgetDataPlane::new(
219                interrupt_in,
220                interrupt_out,
221                bulk_in,
222                bulk_out,
223                config.queue_count,
224                config.queue_depth,
225                config.max_io_bytes,
226                config.dma_heap,
227            )?,
228            ident: config.ident,
229        })
230    }
231
232    /// Send a Request message to the host over the interrupt IN endpoint.
233    pub async fn send_request(&self, request: Request) -> Result<()> {
234        self.data_plane.send_request(request).await
235    }
236
237    /// Receive a Response message from the host over the interrupt OUT endpoint.
238    pub async fn read_response(&self) -> Result<Response> {
239        self.data_plane.read_response().await
240    }
241
242    /// Read a bulk payload from the host (bulk OUT → gadget).
243    pub async fn read_bulk(&self, buf: &mut [u8]) -> Result<()> {
244        self.data_plane.read_bulk(buf).await
245    }
246
247    /// Write a bulk payload to the host (bulk IN → host).
248    pub async fn write_bulk(&self, buf: &[u8]) -> Result<()> {
249        self.data_plane.write_bulk(buf).await
250    }
251
252    /// Read a bulk payload directly into a buffer, using DMA-BUF when available.
253    pub async fn read_bulk_buffer(&self, buf: &mut [u8]) -> Result<()> {
254        self.data_plane.read_bulk_buffer(buf).await
255    }
256
257    /// Write a bulk payload from a buffer, using DMA-BUF when available.
258    pub async fn write_bulk_buffer(&self, buf: &mut [u8]) -> Result<()> {
259        self.data_plane.write_bulk_buffer(buf).await
260    }
261
262    /// Access the data-plane controller directly.
263    pub fn data_plane(&self) -> &GadgetDataPlane {
264        &self.data_plane
265    }
266
267    pub fn response_reader(&self) -> Arc<Mutex<File>> {
268        self.data_plane.response_reader()
269    }
270
271    /// Current IDENT response advertised by the gadget.
272    pub fn ident(&self) -> Ident {
273        self.ident
274    }
275
276    /// Create a control-plane helper for parsing vendor SETUP packets.
277    pub fn control_handler(&self) -> GadgetControl {
278        GadgetControl::new(self.ident)
279    }
280}
281
/// Snapshot of dynamic gadget status advertised via SMOO_STATUS.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GadgetStatusReport {
    pub session_id: u64,
    pub export_count: u32,
}

impl GadgetStatusReport {
    /// Capture the current session id and number of configured exports.
    pub fn new(session_id: u64, export_count: u32) -> Self {
        GadgetStatusReport {
            session_id,
            export_count,
        }
    }

    /// True when at least one export is currently configured.
    pub fn export_active(&self) -> bool {
        self.export_count != 0
    }
}
301
/// Control-plane helper that parses vendor SETUP packets and emits high-level commands.
#[derive(Clone, Copy, Debug)]
pub struct GadgetControl {
    // Identity returned for GET_IDENT; fixed for the gadget's lifetime.
    ident: Ident,
}
307
308impl GadgetControl {
309    fn new(ident: Ident) -> Self {
310        Self { ident }
311    }
312
313    /// Handle a vendor-specific SETUP packet.
314    ///
315    /// Returns [`SetupCommand`] when additional action is required (e.g. CONFIG_EXPORTS).
316    /// All control responses/ACKs are written through `io` internally.
317    pub async fn handle_setup_packet(
318        &self,
319        io: &mut (impl ControlIo + Send),
320        setup: SetupPacket,
321        status: &GadgetStatusReport,
322    ) -> Result<Option<SetupCommand>> {
323        if setup.request() == IDENT_REQUEST && setup.request_type() == SMOO_REQ_TYPE {
324            ensure!(
325                setup.direction() == ControlDirection::In,
326                "GET_IDENT must be an IN transfer"
327            );
328            ensure!(
329                setup.length() as usize >= IDENT_LEN,
330                "GET_IDENT length too small"
331            );
332            trace!("ep0: GET_IDENT");
333            let ident = self.ident.encode();
334            let len = cmp::min(setup.length() as usize, ident.len());
335            io.write_in(&ident[..len])
336                .await
337                .context("reply to GET_IDENT")?;
338            trace!(
339                len,
340                major = self.ident.major,
341                minor = self.ident.minor,
342                "ep0: GET_IDENT reply sent"
343            );
344            return Ok(None);
345        }
346
347        if setup.request() == SMOO_STATUS_REQUEST && setup.request_type() == SMOO_STATUS_REQ_TYPE {
348            ensure!(
349                setup.direction() == ControlDirection::In,
350                "SMOO_STATUS must be an IN transfer"
351            );
352            ensure!(
353                setup.length() as usize >= SMOO_STATUS_LEN,
354                "SMOO_STATUS buffer too small"
355            );
356            trace!(
357                current_exports = status.export_count,
358                session_id = status.session_id,
359                "ep0: SMOO_STATUS"
360            );
361            let mut flags = 0;
362            if status.export_active() {
363                flags |= SMOO_STATUS_FLAG_EXPORT_ACTIVE;
364            }
365            let payload = SmooStatusV0::new(flags, status.export_count, status.session_id);
366            let encoded = payload.encode();
367            let len = cmp::min(encoded.len(), setup.length() as usize);
368            io.write_in(&encoded[..len])
369                .await
370                .context("write SMOO_STATUS response")?;
371            trace!(len, "ep0: SMOO_STATUS reply sent");
372            return Ok(None);
373        }
374
375        if setup.request() == CONFIG_EXPORTS_REQUEST
376            && setup.request_type() == CONFIG_EXPORTS_REQ_TYPE
377        {
378            let len = setup.length() as usize;
379            ensure!(
380                len >= ConfigExportsV0::HEADER_LEN,
381                "CONFIG_EXPORTS payload too short"
382            );
383            trace!(len, "ep0: CONFIG_EXPORTS setup");
384            let mut buf = vec![0u8; len];
385            io.read_out(&mut buf).await.context("read CONFIG_EXPORTS")?;
386            trace!(len, "ep0: CONFIG_EXPORTS payload received");
387            let payload = ConfigExportsV0::try_from_slice(&buf)
388                .map_err(|err| anyhow!("parse CONFIG_EXPORTS payload: {err}"))?;
389            return Ok(Some(SetupCommand::Config(payload)));
390        }
391
392        io.stall()
393            .await
394            .context("stall unsupported control request")?;
395        Err(anyhow!(
396            "unsupported setup request {:#x} type {:#x}",
397            setup.request(),
398            setup.request_type()
399        ))
400    }
401}
402
/// Commands emitted by [`GadgetControl`] for the runtime to apply.
#[derive(Clone, Debug)]
pub enum SetupCommand {
    /// Host pushed a new export table via CONFIG_EXPORTS.
    Config(ConfigExportsV0),
}
408
/// Data-plane controller that owns the FunctionFS interrupt and bulk endpoints.
///
/// Today it still drives a single export, but the separation from control-plane handling
/// allows future work to multiplex multiple exports or schedule heavy work without
/// blocking EP0.
pub struct GadgetDataPlane {
    // Each endpoint is wrapped in Arc<Mutex<…>> so concurrent tasks serialize
    // access per endpoint (and `response_reader` can hand out a shared handle).
    interrupt_in: Arc<Mutex<File>>,
    interrupt_out: Arc<Mutex<File>>,
    bulk_in: Arc<Mutex<File>>,
    bulk_out: Arc<Mutex<File>>,
    // Raw fds cached for DMA-BUF ioctls; backed by the tokio Files above, so
    // they remain valid for the lifetime of this struct.
    bulk_in_fd: RawFd,
    bulk_out_fd: RawFd,
    // Coarse locks letting callers serialize whole read/write sequences
    // (see with_read_path / with_write_path).
    read_path_lock: Arc<Mutex<()>>,
    write_path_lock: Arc<Mutex<()>>,
    // DMA-BUF pool; None when no DMA heap was configured (plain read/write fallback).
    buffers: Option<Mutex<BufferPool>>,
}
425
impl GadgetDataPlane {
    /// Build the data plane: wrap the endpoint fds in async files and, when a
    /// DMA heap is configured, preallocate a DMA-BUF pool sized for the queues.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        interrupt_in: OwnedFd,
        interrupt_out: OwnedFd,
        bulk_in: OwnedFd,
        bulk_out: OwnedFd,
        queue_count: u16,
        queue_depth: u16,
        max_io_bytes: usize,
        dma_heap: Option<DmaHeap>,
    ) -> Result<Self> {
        // Cache raw fds before the OwnedFds are consumed below; the fds stay
        // valid because the same descriptors are stored on `self` as Files.
        let bulk_in_raw = bulk_in.as_raw_fd();
        let bulk_out_raw = bulk_out.as_raw_fd();
        let buffers = if let Some(heap) = dma_heap {
            // One buffer per potential in-flight request: queues × depth.
            let prealloc = queue_count as usize * queue_depth as usize;
            let cap = prealloc;
            Some(Mutex::new(
                BufferPool::new(
                    bulk_in_raw,
                    bulk_out_raw,
                    Some(heap.to_heap_kind()),
                    max_io_bytes,
                    prealloc,
                    cap,
                )
                .context("init DMA buffer pool")?,
            ))
        } else {
            None
        };
        Ok(Self {
            interrupt_in: Arc::new(Mutex::new(to_tokio_file(interrupt_in)?)),
            interrupt_out: Arc::new(Mutex::new(to_tokio_file(interrupt_out)?)),
            bulk_in: Arc::new(Mutex::new(to_tokio_file(bulk_in)?)),
            bulk_out: Arc::new(Mutex::new(to_tokio_file(bulk_out)?)),
            bulk_in_fd: bulk_in_raw,
            bulk_out_fd: bulk_out_raw,
            read_path_lock: Arc::new(Mutex::new(())),
            write_path_lock: Arc::new(Mutex::new(())),
            buffers,
        })
    }

    /// Send an encoded Request over interrupt IN; the endpoint mutex
    /// serializes concurrent senders.
    pub async fn send_request(&self, request: Request) -> Result<()> {
        let encoded = request.encode();
        #[cfg(feature = "metrics")]
        let start = Instant::now();
        trace!(bytes = encoded.len(), "interrupt IN: sending Request");
        let mut lock = self.interrupt_in.lock().await;
        lock.write_all(&encoded)
            .await
            .context("write request to interrupt IN")?;
        lock.flush().await.context("flush interrupt IN")?;
        #[cfg(feature = "metrics")]
        crate::metrics::observe_interrupt_in(encoded.len(), start.elapsed());
        trace!("interrupt IN: Request flushed");
        Ok(())
    }

    /// Read one fixed-size Response frame from interrupt OUT and decode it.
    pub async fn read_response(&self) -> Result<Response> {
        let mut buf = [0u8; RESPONSE_LEN];
        #[cfg(feature = "metrics")]
        let start = Instant::now();
        trace!(bytes = buf.len(), "interrupt OUT: reading Response");
        let mut lock = self.interrupt_out.lock().await;
        lock.read_exact(&mut buf)
            .await
            .context("read response from interrupt OUT")?;
        #[cfg(feature = "metrics")]
        crate::metrics::observe_interrupt_out(buf.len(), start.elapsed());
        trace!("interrupt OUT: Response received");
        Response::try_from(buf.as_slice()).map_err(|err| anyhow!("decode response: {err}"))
    }

    /// Read exactly `buf.len()` bytes from bulk OUT. No-op for an empty buffer.
    pub async fn read_bulk(&self, buf: &mut [u8]) -> Result<()> {
        if buf.is_empty() {
            return Ok(());
        }
        #[cfg(feature = "metrics")]
        let start = Instant::now();
        trace!(bytes = buf.len(), "bulk OUT: reading payload");
        let mut lock = self.bulk_out.lock().await;
        lock.read_exact(buf)
            .await
            .context("read payload from bulk OUT")?;
        #[cfg(feature = "metrics")]
        crate::metrics::observe_bulk_out(buf.len(), start.elapsed());
        trace!("bulk OUT: payload received");
        Ok(())
    }

    /// Write all of `buf` to bulk IN and flush. No-op for an empty buffer.
    pub async fn write_bulk(&self, buf: &[u8]) -> Result<()> {
        if buf.is_empty() {
            return Ok(());
        }
        #[cfg(feature = "metrics")]
        let start = Instant::now();
        trace!(bytes = buf.len(), "bulk IN: writing payload");
        let mut lock = self.bulk_in.lock().await;
        lock.write_all(buf)
            .await
            .context("write payload to bulk IN")?;
        lock.flush().await.context("flush bulk IN")?;
        #[cfg(feature = "metrics")]
        crate::metrics::observe_bulk_in(buf.len(), start.elapsed());
        Ok(())
    }

    /// Read a bulk payload into `buf`, staging through a pool buffer (DMA-BUF
    /// when the buffer has one, plain read otherwise).
    ///
    /// NOTE(review): the pool mutex is held for the entire transfer, so
    /// buffered bulk reads are fully serialized — presumably intentional while
    /// a single export is driven; revisit for multi-export multiplexing.
    pub async fn read_bulk_buffer(&self, buf: &mut [u8]) -> Result<()> {
        if buf.is_empty() {
            return Ok(());
        }
        let len = buf.len();
        match &self.buffers {
            Some(pool) => {
                let mut pool = pool.lock().await;
                trace!(bytes = len, "bulk OUT: reading payload via buffer pool");
                let mut handle = pool.checkout();
                debug_assert!(handle.len() >= len);
                #[cfg(feature = "metrics")]
                let start = Instant::now();
                let result = if let Some(buf_fd) = handle.dma_fd() {
                    // Zero-copy path: hand the dma-buf fd to the endpoint.
                    self.queue_dmabuf_transfer(self.bulk_out_fd, len, buf_fd)
                        .await
                } else {
                    // Fallback: ordinary read into the pool buffer's memory.
                    self.read_bulk(&mut handle.as_mut_slice()[..len]).await
                };
                if result.is_ok() {
                    #[cfg(feature = "metrics")]
                    crate::metrics::observe_bulk_out(len, start.elapsed());
                    // Make the device's writes visible to the CPU before copying out.
                    handle
                        .finish_device_write()
                        .context("invalidate buffer after device write")?;
                    buf.copy_from_slice(&handle.as_slice()[..len]);
                }
                // On error the handle is returned without finish_device_write;
                // assumes the pool tolerates that — TODO confirm in dma module.
                pool.checkin(handle);
                result.context("bulk OUT buffered transfer")
            }
            None => {
                trace!(bytes = len, "bulk OUT: reading payload via read()");
                self.read_bulk(buf).await.context("read bulk payload")
            }
        }
    }

    /// Write a bulk payload from `buf`, staging through a pool buffer (DMA-BUF
    /// when available, plain write otherwise).
    ///
    /// NOTE(review): `buf` is only read here, so `&mut` looks stronger than
    /// needed — kept for interface stability; confirm before relaxing.
    pub async fn write_bulk_buffer(&self, buf: &mut [u8]) -> Result<()> {
        if buf.is_empty() {
            return Ok(());
        }
        let len = buf.len();
        match &self.buffers {
            Some(pool) => {
                let mut pool = pool.lock().await;
                trace!(bytes = len, "bulk IN: writing payload via buffer pool");
                let mut handle = pool.checkout();
                debug_assert!(handle.len() >= len);
                handle.as_mut_slice()[..len].copy_from_slice(&buf[..len]);
                // Flush CPU writes so the device sees the staged payload.
                handle
                    .prepare_device_read()
                    .context("prepare buffer before device read")?;
                #[cfg(feature = "metrics")]
                let start = Instant::now();
                let result = if let Some(buf_fd) = handle.dma_fd() {
                    self.queue_dmabuf_transfer(self.bulk_in_fd, len, buf_fd)
                        .await
                        .context("FUNCTIONFS dmabuf transfer (IN)")
                } else {
                    self.write_bulk(&handle.as_slice()[..len]).await
                };
                #[cfg(feature = "metrics")]
                if result.is_ok() {
                    crate::metrics::observe_bulk_in(len, start.elapsed());
                }
                pool.checkin(handle);
                result.context("bulk IN buffered transfer")
            }
            None => {
                trace!(bytes = len, "bulk IN: writing payload via write()");
                self.write_bulk(buf).await.context("write bulk payload")
            }
        }
    }

    /// Run the blocking dma-buf ioctl on the blocking thread pool so it does
    /// not stall the async runtime. The outer `?` unwraps the JoinError; the
    /// inner Result from the transfer itself is the return value.
    async fn queue_dmabuf_transfer(
        &self,
        endpoint_fd: RawFd,
        len: usize,
        buf_fd: RawFd,
    ) -> Result<()> {
        task::spawn_blocking(move || dmabuf_transfer_blocking(endpoint_fd, buf_fd, len))
            .await
            .map_err(|err| anyhow!("dma-buf transfer task failed: {err}"))?
    }

    /// Shared handle to the interrupt OUT file (same mutex as read_response).
    pub fn response_reader(&self) -> Arc<Mutex<File>> {
        self.interrupt_out.clone()
    }

    /// Run `fut` while holding the coarse read-path lock, serializing whole
    /// multi-step read sequences against each other.
    // NOTE(review): bare `impl Future` relies on `Future` being in scope
    // (Rust 2024 prelude) — confirm crate edition.
    pub async fn with_read_path<T>(&self, fut: impl Future<Output = Result<T>>) -> Result<T> {
        let _guard = self.read_path_lock.lock().await;
        fut.await
    }

    /// Run `fut` while holding the coarse write-path lock.
    pub async fn with_write_path<T>(&self, fut: impl Future<Output = Result<T>>) -> Result<T> {
        let _guard = self.write_path_lock.lock().await;
        fut.await
    }
}
635
/// Direction of a control transfer's data stage, from bit 7 of bmRequestType.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ControlDirection {
    /// Device-to-host (IN).
    In,
    /// Host-to-device (OUT).
    Out,
}
641
642fn to_tokio_file(fd: OwnedFd) -> io::Result<File> {
643    let std_file = StdFile::from(fd);
644    Ok(File::from_std(std_file))
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    // An all-zero CONFIG_EXPORTS header must decode to an empty export table.
    #[test]
    fn config_exports_none() {
        let payload = [0u8; ConfigExportsV0::HEADER_LEN];
        let parsed = ConfigExportsV0::try_from_slice(&payload).expect("parse");
        assert!(parsed.entries().is_empty());
    }

    // A header announcing one entry plus one encoded entry round-trips the
    // entry's geometry fields. Offsets mirror smoo_proto's wire layout:
    // bytes 2..4 appear to be the LE entry count — NOTE(review): confirm
    // against ConfigExportsV0's encoder.
    #[test]
    fn config_exports_single() {
        let mut payload = [0u8; ConfigExportsV0::HEADER_LEN + ConfigExportsV0::ENTRY_LEN];
        payload[2..4].copy_from_slice(&1u16.to_le_bytes());
        payload[8..12].copy_from_slice(&1u32.to_le_bytes()); // export_id
        payload[12..16].copy_from_slice(&4096u32.to_le_bytes()); // block_size (LE u32)
        payload[16..24].copy_from_slice(&(4096u64 * 8).to_le_bytes()); // size_bytes (LE u64)
        let parsed = ConfigExportsV0::try_from_slice(&payload).expect("parse");
        let export = parsed.entries().first().expect("export");
        assert_eq!(export.block_size, 4096);
        assert_eq!(export.size_bytes, 4096 * 8);
    }

    // A set bit in header byte 4 (presumably a flags field) must be rejected
    // as an unknown flag — NOTE(review): confirm byte 4 is flags in smoo_proto.
    #[test]
    fn config_exports_invalid_flags() {
        let mut payload = [0u8; ConfigExportsV0::HEADER_LEN];
        payload[4] = 1;
        assert!(ConfigExportsV0::try_from_slice(&payload).is_err());
    }
}
677}