1use super::common::Feature;
4use crate::config::{read_config, write_config, ReadOnly, WriteOnly};
5use crate::hal::Hal;
6use crate::queue::VirtQueue;
7use crate::transport::{InterruptStatus, Transport};
8use crate::Error;
9use alloc::{boxed::Box, string::String};
10use core::cmp::min;
11use core::mem::{offset_of, size_of};
12use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes, KnownLayout};
13
14pub struct VirtIOInput<H: Hal, T: Transport> {
20    transport: T,
21    event_queue: VirtQueue<H, QUEUE_SIZE>,
22    status_queue: VirtQueue<H, QUEUE_SIZE>,
23    event_buf: Box<[InputEvent; 32]>,
24}
25
26impl<H: Hal, T: Transport> VirtIOInput<H, T> {
27    pub fn new(mut transport: T) -> Result<Self, Error> {
29        let mut event_buf = Box::new([InputEvent::default(); QUEUE_SIZE]);
30
31        let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
32
33        let mut event_queue = VirtQueue::new(
34            &mut transport,
35            QUEUE_EVENT,
36            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
37            negotiated_features.contains(Feature::RING_EVENT_IDX),
38        )?;
39        let status_queue = VirtQueue::new(
40            &mut transport,
41            QUEUE_STATUS,
42            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
43            negotiated_features.contains(Feature::RING_EVENT_IDX),
44        )?;
45        for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
46            let token = unsafe { event_queue.add(&[], &mut [event.as_mut_bytes()])? };
48            assert_eq!(token, i as u16);
49        }
50        if event_queue.should_notify() {
51            transport.notify(QUEUE_EVENT);
52        }
53
54        transport.finish_init();
55
56        Ok(VirtIOInput {
57            transport,
58            event_queue,
59            status_queue,
60            event_buf,
61        })
62    }
63
64    pub fn ack_interrupt(&mut self) -> InterruptStatus {
66        self.transport.ack_interrupt()
67    }
68
69    pub fn pop_pending_event(&mut self) -> Option<InputEvent> {
71        if let Some(token) = self.event_queue.peek_used() {
72            let event = &mut self.event_buf[token as usize];
73            unsafe {
75                self.event_queue
76                    .pop_used(token, &[], &mut [event.as_mut_bytes()])
77                    .ok()?;
78            }
79            let event_saved = *event;
80            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &mut [event.as_mut_bytes()]) }
83            {
84                assert_eq!(new_token, token);
88                if self.event_queue.should_notify() {
89                    self.transport.notify(QUEUE_EVENT);
90                }
91                return Some(event_saved);
92            }
93        }
94        None
95    }
96
97    pub fn query_config_select(
100        &mut self,
101        select: InputConfigSelect,
102        subsel: u8,
103        out: &mut [u8],
104    ) -> Result<u8, Error> {
105        write_config!(self.transport, Config, select, select as u8)?;
106        write_config!(self.transport, Config, subsel, subsel)?;
107        let size: u8 = read_config!(self.transport, Config, size)?;
108        let size_to_copy = min(usize::from(size), out.len());
110        for (i, out_item) in out.iter_mut().take(size_to_copy).enumerate() {
111            *out_item = self
112                .transport
113                .read_config_space(offset_of!(Config, data) + i * size_of::<u8>())?;
114        }
115
116        Ok(size)
117    }
118
119    fn query_config_select_alloc(
122        &mut self,
123        select: InputConfigSelect,
124        subsel: u8,
125    ) -> Result<Box<[u8]>, Error> {
126        write_config!(self.transport, Config, select, select as u8)?;
127        write_config!(self.transport, Config, subsel, subsel)?;
128        let size = usize::from(read_config!(self.transport, Config, size)?);
129        if size > CONFIG_DATA_MAX_LENGTH {
130            return Err(Error::IoError);
131        }
132        let mut buf = <[u8]>::new_box_zeroed_with_elems(size).unwrap();
133        for i in 0..size {
134            buf[i] = self
135                .transport
136                .read_config_space(offset_of!(Config, data) + i * size_of::<u8>())?;
137        }
138        Ok(buf)
139    }
140
141    fn query_config_string(
146        &mut self,
147        select: InputConfigSelect,
148        subsel: u8,
149    ) -> Result<String, Error> {
150        Ok(String::from_utf8(
151            self.query_config_select_alloc(select, subsel)?.into(),
152        )?)
153    }
154
155    pub fn name(&mut self) -> Result<String, Error> {
157        self.query_config_string(InputConfigSelect::IdName, 0)
158    }
159
160    pub fn serial_number(&mut self) -> Result<String, Error> {
162        self.query_config_string(InputConfigSelect::IdSerial, 0)
163    }
164
165    pub fn ids(&mut self) -> Result<DevIDs, Error> {
167        let mut ids = DevIDs::default();
168        let size = self.query_config_select(InputConfigSelect::IdDevids, 0, ids.as_mut_bytes())?;
169        if usize::from(size) == size_of::<DevIDs>() {
170            Ok(ids)
171        } else {
172            Err(Error::IoError)
173        }
174    }
175
176    pub fn prop_bits(&mut self) -> Result<Box<[u8]>, Error> {
178        self.query_config_select_alloc(InputConfigSelect::PropBits, 0)
179    }
180
181    pub fn ev_bits(&mut self, event_type: u8) -> Result<Box<[u8]>, Error> {
185        self.query_config_select_alloc(InputConfigSelect::EvBits, event_type)
186    }
187
188    pub fn abs_info(&mut self, axis: u8) -> Result<AbsInfo, Error> {
190        let mut info = AbsInfo::default();
191        let size =
192            self.query_config_select(InputConfigSelect::AbsInfo, axis, info.as_mut_bytes())?;
193        if usize::from(size) == size_of::<AbsInfo>() {
194            Ok(info)
195        } else {
196            Err(Error::IoError)
197        }
198    }
199}
200
// SAFETY: `VirtIOInput` owns only three kinds of field:
// - `transport`, which is required to be `Send` here;
// - the two `VirtQueue`s, which the `where` clause requires to be `Send`;
// - a boxed array of `InputEvent`s, which hold only integer fields.
// NOTE(review): with these bounds this appears to grant nothing beyond what the compiler
// would derive automatically from the fields, so the explicit impl may be redundant —
// confirm before removing.
unsafe impl<H: Hal, T: Transport + Send> Send for VirtIOInput<H, T> where
    VirtQueue<H, QUEUE_SIZE>: Send
{
}
206
// SAFETY: `transport` and both `VirtQueue`s are required to be `Sync`, and the boxed
// `InputEvent` array holds only integer fields. All of `VirtIOInput`'s own methods in this
// file take `&mut self`, so a shared reference does not let callers drive the device.
// NOTE(review): like the `Send` impl, this appears to match the auto-derived bounds —
// confirm whether it is still needed.
unsafe impl<H: Hal, T: Transport + Sync> Sync for VirtIOInput<H, T> where
    VirtQueue<H, QUEUE_SIZE>: Sync
{
}
212
213impl<H: Hal, T: Transport> Drop for VirtIOInput<H, T> {
214    fn drop(&mut self) {
215        self.transport.queue_unset(QUEUE_EVENT);
218        self.transport.queue_unset(QUEUE_STATUS);
219    }
220}
221
/// Length in bytes of the `data` array in the input device's config space (`Config::data`).
const CONFIG_DATA_MAX_LENGTH: usize = 0x80;
223
/// Selector written to the `select` field of the input config space.
///
/// It chooses which item the device exposes through the `size` and `data` fields.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InputConfigSelect {
    /// Device name as a UTF-8 string; `subsel` is zero. Read by [`VirtIOInput::name`].
    IdName = 0x01,
    /// Device serial number as a UTF-8 string; `subsel` is zero. Read by
    /// [`VirtIOInput::serial_number`].
    IdSerial = 0x02,
    /// Device IDs as a [`DevIDs`] structure; `subsel` is zero. Read by [`VirtIOInput::ids`].
    IdDevids = 0x03,
    /// Device property bitmap; `subsel` is zero. Read by [`VirtIOInput::prop_bits`].
    PropBits = 0x10,
    /// Event-code bitmap for the event type given in `subsel`. Read by
    /// [`VirtIOInput::ev_bits`].
    EvBits = 0x11,
    /// Axis information as an [`AbsInfo`] structure for the axis given in `subsel`. Read by
    /// [`VirtIOInput::abs_info`].
    AbsInfo = 0x12,
}
248
/// Layout of the VirtIO input device's configuration space.
///
/// The driver writes `select`/`subsel` to choose an item. It then reads the item's length
/// from `size` and its contents from the first `size` bytes of `data`. Field order and
/// types define the on-device layout, so they must not be changed.
#[derive(FromBytes, Immutable, IntoBytes)]
#[repr(C)]
struct Config {
    /// Item selector; see [`InputConfigSelect`].
    select: WriteOnly<u8>,
    /// Sub-selector, e.g. the event type for `EvBits` or the axis for `AbsInfo`.
    subsel: WriteOnly<u8>,
    /// Number of valid bytes in `data` for the selected item, as reported by the device.
    size: ReadOnly<u8>,
    // Reserved bytes. Presumably these pad `data` to byte offset 8, assuming
    // `ReadOnly<u8>`/`WriteOnly<u8>` are one byte each — TODO confirm.
    _reserved: [ReadOnly<u8>; 5],
    /// Contents of the selected item.
    data: [ReadOnly<u8>; CONFIG_DATA_MAX_LENGTH],
}
258
/// Information about an absolute axis, as returned by [`VirtIOInput::abs_info`].
///
/// Filled directly from raw config-space bytes, so the field order defines the layout.
// NOTE(review): the VirtIO spec defines these fields as little-endian; they are copied
// byte-for-byte, so values may be byte-swapped on big-endian targets — confirm.
#[repr(C)]
#[derive(Clone, Debug, Default, Eq, FromBytes, Immutable, IntoBytes, KnownLayout, PartialEq)]
pub struct AbsInfo {
    /// Minimum value of the axis.
    pub min: u32,
    /// Maximum value of the axis.
    pub max: u32,
    /// Fuzz value of the axis (presumably a noise-filter threshold, as in the VirtIO
    /// `virtio_input_absinfo` structure — confirm).
    pub fuzz: u32,
    /// Flat value of the axis (presumably a dead-zone size, as in the VirtIO
    /// `virtio_input_absinfo` structure — confirm).
    pub flat: u32,
    /// Resolution of the axis (units per the VirtIO/evdev definition — confirm).
    pub res: u32,
}
274
/// Device identifiers, as returned by [`VirtIOInput::ids`].
///
/// Filled directly from raw config-space bytes, so the field order defines the layout.
// NOTE(review): the VirtIO spec defines these fields as little-endian; they are copied
// byte-for-byte, so values may be byte-swapped on big-endian targets — confirm.
#[repr(C)]
#[derive(Clone, Debug, Default, Eq, FromBytes, Immutable, IntoBytes, KnownLayout, PartialEq)]
pub struct DevIDs {
    /// Bus type identifier (presumably a Linux `BUS_*` value — confirm).
    pub bustype: u16,
    /// Vendor identifier.
    pub vendor: u16,
    /// Product identifier.
    pub product: u16,
    /// Version identifier.
    pub version: u16,
}
288
/// An input event written by the device into an event-queue buffer.
///
/// Returned by [`VirtIOInput::pop_pending_event`]. The device writes the raw bytes
/// directly, so the field order defines the layout.
// NOTE(review): the VirtIO spec defines these fields as little-endian; they are copied
// byte-for-byte, so values may be byte-swapped on big-endian targets — confirm.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
pub struct InputEvent {
    /// Event type (presumably a Linux evdev `EV_*` value — confirm).
    pub event_type: u16,
    /// Event code, interpreted according to `event_type`.
    pub code: u16,
    /// Event value, interpreted according to `event_type` and `code`.
    pub value: u32,
}
301
302const QUEUE_EVENT: u16 = 0;
303const QUEUE_STATUS: u16 = 1;
304const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX
305    .union(Feature::RING_INDIRECT_DESC)
306    .union(Feature::VERSION_1);
307
308const QUEUE_SIZE: usize = 32;
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
    };
    use alloc::{sync::Arc, vec};
    use core::convert::TryInto;
    use std::sync::Mutex;

    /// Exercises every config query against a fake transport. Each query is checked for
    /// both the decoded result and the `select`/`subsel` values written to config space.
    #[test]
    fn config() {
        const DEFAULT_DATA: ReadOnly<u8> = ReadOnly::new(0);
        let config_space = Config {
            select: WriteOnly::default(),
            subsel: WriteOnly::default(),
            size: ReadOnly::new(0),
            _reserved: Default::default(),
            data: [DEFAULT_DATA; CONFIG_DATA_MAX_LENGTH],
        };
        let state = Arc::new(Mutex::new(State::new(
            vec![QueueStatus::default(), QueueStatus::default()],
            config_space,
        )));
        let transport = FakeTransport {
            // The driver under test is an input driver, so advertise an input device.
            device_type: DeviceType::Input,
            max_queue_size: QUEUE_SIZE.try_into().unwrap(),
            device_features: 0,
            state: state.clone(),
        };
        let mut input = VirtIOInput::<FakeHal, FakeTransport<Config>>::new(transport).unwrap();

        set_data(
            &mut state.lock().unwrap().config_space,
            "Test input device".as_bytes(),
        );
        assert_eq!(input.name().unwrap(), "Test input device");
        assert_selected(
            &state.lock().unwrap().config_space,
            InputConfigSelect::IdName,
            0,
        );

        set_data(
            &mut state.lock().unwrap().config_space,
            "Serial number".as_bytes(),
        );
        assert_eq!(input.serial_number().unwrap(), "Serial number");
        assert_selected(
            &state.lock().unwrap().config_space,
            InputConfigSelect::IdSerial,
            0,
        );

        let ids = DevIDs {
            bustype: 0x4242,
            product: 0x0067,
            vendor: 0x1234,
            version: 0x4321,
        };
        set_data(&mut state.lock().unwrap().config_space, ids.as_bytes());
        assert_eq!(input.ids().unwrap(), ids);
        assert_selected(
            &state.lock().unwrap().config_space,
            InputConfigSelect::IdDevids,
            0,
        );

        set_data(&mut state.lock().unwrap().config_space, &[0x12, 0x34, 0x56]);
        assert_eq!(input.prop_bits().unwrap().as_ref(), &[0x12, 0x34, 0x56]);
        assert_selected(
            &state.lock().unwrap().config_space,
            InputConfigSelect::PropBits,
            0,
        );

        set_data(&mut state.lock().unwrap().config_space, &[0x42, 0x66]);
        assert_eq!(input.ev_bits(3).unwrap().as_ref(), &[0x42, 0x66]);
        assert_selected(
            &state.lock().unwrap().config_space,
            InputConfigSelect::EvBits,
            3,
        );

        let abs_info = AbsInfo {
            min: 12,
            max: 1234,
            fuzz: 4,
            flat: 10,
            res: 2,
        };
        set_data(&mut state.lock().unwrap().config_space, abs_info.as_bytes());
        assert_eq!(input.abs_info(5).unwrap(), abs_info);
        assert_selected(
            &state.lock().unwrap().config_space,
            InputConfigSelect::AbsInfo,
            5,
        );
    }

    /// Asserts that the driver last wrote `select` and `subsel` to `config_space`.
    fn assert_selected(config_space: &Config, select: InputConfigSelect, subsel: u8) {
        assert_eq!(config_space.select.0, select as u8);
        assert_eq!(config_space.subsel.0, subsel);
    }

    /// Makes `value` the current config item: sets `size` to its length and copies it into
    /// the start of `data`. Panics if `value` does not fit in `data`.
    fn set_data(config_space: &mut Config, value: &[u8]) {
        config_space.size.0 = value.len().try_into().unwrap();
        for (i, &byte) in value.iter().enumerate() {
            config_space.data[i].0 = byte;
        }
    }
}