wasefire_protocol_usb/
device.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use alloc::boxed::Box;
16use alloc::collections::VecDeque;
17use alloc::vec::Vec;
18use core::marker::PhantomData;
19
20use usb_device::class_prelude::{
21    ControlIn, ControlOut, InterfaceNumber, StringIndex, UsbBus, UsbBusAllocator, UsbClass,
22};
23use usb_device::descriptor::{BosWriter, DescriptorWriter};
24use usb_device::endpoint::{EndpointAddress, EndpointIn, EndpointOut};
25use usb_device::{LangID, UsbError};
26use wasefire_board_api::Error;
27use wasefire_board_api::platform::protocol::{Api, Event};
28use wasefire_error::Code;
29use wasefire_logger as log;
30
31use crate::common::{Decoder, Encoder};
32
33pub struct Impl<'a, B: UsbBus, T: HasRpc<'a, B>> {
34    _never: !,
35    _phantom: PhantomData<(&'a (), B, T)>,
36}
37
38pub trait HasRpc<'a, B: UsbBus> {
39    fn with_rpc<R>(f: impl FnOnce(&mut Rpc<'a, B>) -> R) -> R;
40    fn vendor(request: &[u8]) -> Result<Box<[u8]>, Error>;
41}
42
43impl<'a, B: UsbBus, T: HasRpc<'a, B>> Api for Impl<'a, B, T> {
44    fn read() -> Result<Option<Box<[u8]>>, Error> {
45        T::with_rpc(|x| x.read())
46    }
47
48    fn write(response: &[u8]) -> Result<(), Error> {
49        T::with_rpc(|x| x.write(response))
50    }
51
52    fn enable() -> Result<(), Error> {
53        T::with_rpc(|x| x.enable())
54    }
55
56    fn vendor(request: &[u8]) -> Result<Box<[u8]>, Error> {
57        T::vendor(request)
58    }
59}
60
61pub struct Rpc<'a, B: UsbBus> {
62    interface: InterfaceNumber,
63    read_ep: EndpointOut<'a, B>,
64    write_ep: EndpointIn<'a, B>,
65    state: State,
66}
67
68impl<'a, B: UsbBus> Rpc<'a, B> {
69    pub fn new(usb_bus: &'a UsbBusAllocator<B>) -> Self {
70        let interface = usb_bus.interface();
71        let read_ep = usb_bus.bulk(MAX_PACKET_SIZE);
72        let write_ep = usb_bus.bulk(MAX_PACKET_SIZE);
73        Rpc { interface, read_ep, write_ep, state: State::Disabled }
74    }
75
76    pub fn read(&mut self) -> Result<Option<Box<[u8]>>, Error> {
77        let result = self.state.read()?;
78        match &result {
79            #[cfg(not(feature = "defmt"))]
80            Some(result) => log::debug!("Reading {:02x?}", result),
81            #[cfg(feature = "defmt")]
82            Some(result) => log::debug!("Reading {=[u8]:02x}", result),
83            None => log::debug!("Reading (no message)"),
84        }
85        Ok(result)
86    }
87
88    pub fn write(&mut self, response: &[u8]) -> Result<(), Error> {
89        #[cfg(not(feature = "defmt"))]
90        log::debug!("Writing {:02x?}", response);
91        #[cfg(feature = "defmt")]
92        log::debug!("Writing {=[u8]:02x}", response);
93        self.state.write(response, &self.write_ep)
94    }
95
96    pub fn enable(&mut self) -> Result<(), Error> {
97        match self.state {
98            State::Disabled => {
99                self.state = WaitRequest;
100                Ok(())
101            }
102            _ => Err(Error::user(Code::InvalidState)),
103        }
104    }
105
106    pub fn tick(&mut self, push: impl FnOnce(Event)) {
107        if self.state.notify() {
108            push(Event);
109        }
110    }
111}
112
/// Size in bytes of the bulk endpoints, and of every packet this class accepts or sends.
const MAX_PACKET_SIZE: u16 = 0x40;
114
115enum State {
116    Disabled,
117    WaitRequest,
118    ReceiveRequest { decoder: Decoder },
119    RequestReady { notified: bool, request: Vec<u8> },
120    WaitResponse,
121    SendResponse { packets: VecDeque<[u8; 64]> },
122}
123use State::*;
124
125impl State {
126    fn read(&mut self) -> Result<Option<Box<[u8]>>, Error> {
127        match self {
128            RequestReady { request, .. } => {
129                let request = core::mem::take(request);
130                log::debug!("Received a message of {} bytes.", request.len());
131                *self = WaitResponse;
132                Ok(Some(request.into_boxed_slice()))
133            }
134            WaitRequest | ReceiveRequest { .. } | SendResponse { .. } => Ok(None),
135            WaitResponse | Disabled => Err(Error::user(Code::InvalidState)),
136        }
137    }
138
139    fn write<B: UsbBus>(&mut self, response: &[u8], ep: &EndpointIn<B>) -> Result<(), Error> {
140        if !matches!(self, WaitResponse) {
141            return Err(Error::user(Code::InvalidState));
142        }
143        let packets: VecDeque<_> = Encoder::new(response).collect();
144        log::debug!("Sending a message of {} bytes in {} packets.", response.len(), packets.len());
145        *self = SendResponse { packets };
146        self.send(ep);
147        Ok(())
148    }
149
150    fn receive<B: UsbBus>(&mut self, ep: &EndpointOut<B>) {
151        let decoder = match self {
152            ReceiveRequest { decoder } => decoder,
153            Disabled => {
154                log::error!("Not receiving data while disabled.");
155                return;
156            }
157            _ => {
158                *self = ReceiveRequest { decoder: Decoder::default() };
159                match self {
160                    ReceiveRequest { decoder } => decoder,
161                    _ => unreachable!(),
162                }
163            }
164        };
165        let mut packet = [0; MAX_PACKET_SIZE as usize];
166        let len = ep.read(&mut packet).unwrap();
167        if len != MAX_PACKET_SIZE as usize {
168            log::warn!("Received a packet of {} bytes instead of 64.", len);
169            *self = WaitRequest;
170            return;
171        }
172        match core::mem::take(decoder).push(&packet) {
173            None => {
174                log::warn!("Received invalid packet 0x{:02x}", packet[0]);
175                *self = WaitRequest;
176            }
177            Some(Ok(request)) => {
178                log::trace!("Received a message of {} bytes.", request.len());
179                *self = RequestReady { notified: false, request };
180            }
181            Some(Err(x)) => {
182                log::trace!("Received a packet.");
183                *decoder = x;
184            }
185        }
186    }
187
188    fn send<B: UsbBus>(&mut self, ep: &EndpointIn<B>) {
189        let packets = match self {
190            Disabled => {
191                log::error!("Not sending data while disabled.");
192                return;
193            }
194            SendResponse { packets } => packets,
195            _ => return,
196        };
197        let packet = match packets.pop_front() {
198            Some(x) => x,
199            None => {
200                log::warn!("Invalid state: SendResponse with no packets.");
201                *self = WaitRequest;
202                return;
203            }
204        };
205        let len = match ep.write(&packet) {
206            Err(UsbError::WouldBlock) => {
207                log::warn!("Failed to send packet, retrying later.");
208                packets.push_front(packet);
209                return;
210            }
211            x => x.unwrap(),
212        };
213        if len != MAX_PACKET_SIZE as usize {
214            log::warn!("Sent a packet of {} bytes instead of 64.", len);
215            *self = WaitRequest;
216            return;
217        }
218        let remaining = packets.len();
219        if packets.is_empty() {
220            *self = WaitRequest;
221        }
222        log::trace!("Sent the next packet ({} remaining).", remaining);
223    }
224
225    fn notify(&mut self) -> bool {
226        match self {
227            RequestReady { notified, .. } => !core::mem::replace(notified, true),
228            _ => false,
229        }
230    }
231}
232
233impl<B: UsbBus> UsbClass<B> for Rpc<'_, B> {
234    fn get_configuration_descriptors(
235        &self, writer: &mut DescriptorWriter,
236    ) -> usb_device::Result<()> {
237        writer.iad(self.interface, 1, 0xff, 0x58, 0x01, None)?;
238        writer.interface(self.interface, 0xff, 0x58, 0x01)?;
239        writer.endpoint(&self.write_ep)?;
240        writer.endpoint(&self.read_ep)?;
241        Ok(())
242    }
243
244    fn get_bos_descriptors(&self, _: &mut BosWriter) -> usb_device::Result<()> {
245        // We don't have any capabilities.
246        Ok(())
247    }
248
249    fn get_string(&self, _: StringIndex, _id: LangID) -> Option<&str> {
250        // We don't have strings.
251        None
252    }
253
254    fn reset(&mut self) {
255        self.state = match self.state {
256            State::Disabled => State::Disabled,
257            _ => State::WaitRequest,
258        };
259    }
260
261    fn poll(&mut self) {
262        // We probably don't need to do anything here.
263    }
264
265    fn control_out(&mut self, _: ControlOut<B>) {
266        // We probably don't need to do anything here.
267    }
268
269    fn control_in(&mut self, _: ControlIn<B>) {
270        // We probably don't need to do anything here.
271    }
272
273    fn endpoint_setup(&mut self, _: EndpointAddress) {
274        // We probably don't need to do anything here.
275    }
276
277    fn endpoint_out(&mut self, addr: EndpointAddress) {
278        if self.read_ep.address() != addr {
279            return;
280        }
281        self.state.receive(&self.read_ep);
282    }
283
284    fn endpoint_in_complete(&mut self, addr: EndpointAddress) {
285        if self.write_ep.address() != addr {
286            return;
287        }
288        self.state.send(&self.write_ep);
289    }
290}