1use alloc::boxed::Box;
16use alloc::collections::VecDeque;
17use alloc::vec::Vec;
18use core::marker::PhantomData;
19
20use usb_device::class_prelude::{
21 ControlIn, ControlOut, InterfaceNumber, StringIndex, UsbBus, UsbBusAllocator, UsbClass,
22};
23use usb_device::descriptor::{BosWriter, DescriptorWriter};
24use usb_device::endpoint::{EndpointAddress, EndpointIn, EndpointOut};
25use usb_device::{LangID, UsbError};
26use wasefire_board_api::Error;
27use wasefire_board_api::platform::protocol::{Api, Event};
28use wasefire_error::Code;
29use wasefire_logger as log;
30
31use crate::common::{Decoder, Encoder};
32
33pub struct Impl<'a, B: UsbBus, T: HasRpc<'a, B>> {
34 _never: !,
35 _phantom: PhantomData<(&'a (), B, T)>,
36}
37
/// Glue between [`Impl`] and the place where the [`Rpc`] instance actually lives.
///
/// [`Impl`] has no state of its own: every `Api` function it provides is forwarded through this
/// trait.
pub trait HasRpc<'a, B: UsbBus> {
    /// Runs `f` with mutable access to the [`Rpc`] and returns whatever `f` returns.
    ///
    /// NOTE(review): how the instance is stored and protected is up to the implementor and is
    /// not visible from this file.
    fn with_rpc<R>(f: impl FnOnce(&mut Rpc<'a, B>) -> R) -> R;

    /// Handles a vendor-specific request and returns its response.
    ///
    /// [`Impl`] forwards `Api::vendor` to this function unchanged.
    fn vendor(request: &[u8]) -> Result<Box<[u8]>, Error>;
}
42
43impl<'a, B: UsbBus, T: HasRpc<'a, B>> Api for Impl<'a, B, T> {
44 fn read() -> Result<Option<Box<[u8]>>, Error> {
45 T::with_rpc(|x| x.read())
46 }
47
48 fn write(response: &[u8]) -> Result<(), Error> {
49 T::with_rpc(|x| x.write(response))
50 }
51
52 fn enable() -> Result<(), Error> {
53 T::with_rpc(|x| x.enable())
54 }
55
56 fn vendor(request: &[u8]) -> Result<Box<[u8]>, Error> {
57 T::vendor(request)
58 }
59}
60
61pub struct Rpc<'a, B: UsbBus> {
62 interface: InterfaceNumber,
63 read_ep: EndpointOut<'a, B>,
64 write_ep: EndpointIn<'a, B>,
65 state: State,
66}
67
68impl<'a, B: UsbBus> Rpc<'a, B> {
69 pub fn new(usb_bus: &'a UsbBusAllocator<B>) -> Self {
70 let interface = usb_bus.interface();
71 let read_ep = usb_bus.bulk(MAX_PACKET_SIZE);
72 let write_ep = usb_bus.bulk(MAX_PACKET_SIZE);
73 Rpc { interface, read_ep, write_ep, state: State::Disabled }
74 }
75
76 pub fn read(&mut self) -> Result<Option<Box<[u8]>>, Error> {
77 let result = self.state.read()?;
78 match &result {
79 #[cfg(not(feature = "defmt"))]
80 Some(result) => log::debug!("Reading {:02x?}", result),
81 #[cfg(feature = "defmt")]
82 Some(result) => log::debug!("Reading {=[u8]:02x}", result),
83 None => log::debug!("Reading (no message)"),
84 }
85 Ok(result)
86 }
87
88 pub fn write(&mut self, response: &[u8]) -> Result<(), Error> {
89 #[cfg(not(feature = "defmt"))]
90 log::debug!("Writing {:02x?}", response);
91 #[cfg(feature = "defmt")]
92 log::debug!("Writing {=[u8]:02x}", response);
93 self.state.write(response, &self.write_ep)
94 }
95
96 pub fn enable(&mut self) -> Result<(), Error> {
97 match self.state {
98 State::Disabled => {
99 self.state = WaitRequest;
100 Ok(())
101 }
102 _ => Err(Error::user(Code::InvalidState)),
103 }
104 }
105
106 pub fn tick(&mut self, push: impl FnOnce(Event)) {
107 if self.state.notify() {
108 push(Event);
109 }
110 }
111}
112
/// Size in bytes of every packet exchanged on the bulk endpoints.
///
/// It is passed as the maximum packet size when allocating both endpoints, and any packet
/// received or sent with a different length aborts the current exchange.
const MAX_PACKET_SIZE: u16 = 64;
114
/// State machine of a request/response exchange.
enum State {
    /// Initial state: incoming packets are not read, and reading or writing is an error until
    /// the interface is enabled.
    Disabled,
    /// Enabled and idle: no request is currently being received.
    WaitRequest,
    /// A request is partially received; `decoder` accumulates its packets.
    ReceiveRequest { decoder: Decoder },
    /// A complete request was received and has not been read yet. `notified` records whether
    /// `notify()` already reported it.
    RequestReady { notified: bool, request: Vec<u8> },
    /// The request was read; waiting for the response to be written.
    WaitResponse,
    /// A response is being sent; `packets` holds the packets not yet sent, next one at the front.
    SendResponse { packets: VecDeque<[u8; 64]> },
}
// Variants are used unqualified throughout the state machine implementation.
use State::*;
124
125impl State {
126 fn read(&mut self) -> Result<Option<Box<[u8]>>, Error> {
127 match self {
128 RequestReady { request, .. } => {
129 let request = core::mem::take(request);
130 log::debug!("Received a message of {} bytes.", request.len());
131 *self = WaitResponse;
132 Ok(Some(request.into_boxed_slice()))
133 }
134 WaitRequest | ReceiveRequest { .. } | SendResponse { .. } => Ok(None),
135 WaitResponse | Disabled => Err(Error::user(Code::InvalidState)),
136 }
137 }
138
139 fn write<B: UsbBus>(&mut self, response: &[u8], ep: &EndpointIn<B>) -> Result<(), Error> {
140 if !matches!(self, WaitResponse) {
141 return Err(Error::user(Code::InvalidState));
142 }
143 let packets: VecDeque<_> = Encoder::new(response).collect();
144 log::debug!("Sending a message of {} bytes in {} packets.", response.len(), packets.len());
145 *self = SendResponse { packets };
146 self.send(ep);
147 Ok(())
148 }
149
150 fn receive<B: UsbBus>(&mut self, ep: &EndpointOut<B>) {
151 let decoder = match self {
152 ReceiveRequest { decoder } => decoder,
153 Disabled => {
154 log::error!("Not receiving data while disabled.");
155 return;
156 }
157 _ => {
158 *self = ReceiveRequest { decoder: Decoder::default() };
159 match self {
160 ReceiveRequest { decoder } => decoder,
161 _ => unreachable!(),
162 }
163 }
164 };
165 let mut packet = [0; MAX_PACKET_SIZE as usize];
166 let len = ep.read(&mut packet).unwrap();
167 if len != MAX_PACKET_SIZE as usize {
168 log::warn!("Received a packet of {} bytes instead of 64.", len);
169 *self = WaitRequest;
170 return;
171 }
172 match core::mem::take(decoder).push(&packet) {
173 None => {
174 log::warn!("Received invalid packet 0x{:02x}", packet[0]);
175 *self = WaitRequest;
176 }
177 Some(Ok(request)) => {
178 log::trace!("Received a message of {} bytes.", request.len());
179 *self = RequestReady { notified: false, request };
180 }
181 Some(Err(x)) => {
182 log::trace!("Received a packet.");
183 *decoder = x;
184 }
185 }
186 }
187
188 fn send<B: UsbBus>(&mut self, ep: &EndpointIn<B>) {
189 let packets = match self {
190 Disabled => {
191 log::error!("Not sending data while disabled.");
192 return;
193 }
194 SendResponse { packets } => packets,
195 _ => return,
196 };
197 let packet = match packets.pop_front() {
198 Some(x) => x,
199 None => {
200 log::warn!("Invalid state: SendResponse with no packets.");
201 *self = WaitRequest;
202 return;
203 }
204 };
205 let len = match ep.write(&packet) {
206 Err(UsbError::WouldBlock) => {
207 log::warn!("Failed to send packet, retrying later.");
208 packets.push_front(packet);
209 return;
210 }
211 x => x.unwrap(),
212 };
213 if len != MAX_PACKET_SIZE as usize {
214 log::warn!("Sent a packet of {} bytes instead of 64.", len);
215 *self = WaitRequest;
216 return;
217 }
218 let remaining = packets.len();
219 if packets.is_empty() {
220 *self = WaitRequest;
221 }
222 log::trace!("Sent the next packet ({} remaining).", remaining);
223 }
224
225 fn notify(&mut self) -> bool {
226 match self {
227 RequestReady { notified, .. } => !core::mem::replace(notified, true),
228 _ => false,
229 }
230 }
231}
232
233impl<B: UsbBus> UsbClass<B> for Rpc<'_, B> {
234 fn get_configuration_descriptors(
235 &self, writer: &mut DescriptorWriter,
236 ) -> usb_device::Result<()> {
237 writer.iad(self.interface, 1, 0xff, 0x58, 0x01, None)?;
238 writer.interface(self.interface, 0xff, 0x58, 0x01)?;
239 writer.endpoint(&self.write_ep)?;
240 writer.endpoint(&self.read_ep)?;
241 Ok(())
242 }
243
244 fn get_bos_descriptors(&self, _: &mut BosWriter) -> usb_device::Result<()> {
245 Ok(())
247 }
248
249 fn get_string(&self, _: StringIndex, _id: LangID) -> Option<&str> {
250 None
252 }
253
254 fn reset(&mut self) {
255 self.state = match self.state {
256 State::Disabled => State::Disabled,
257 _ => State::WaitRequest,
258 };
259 }
260
261 fn poll(&mut self) {
262 }
264
265 fn control_out(&mut self, _: ControlOut<B>) {
266 }
268
269 fn control_in(&mut self, _: ControlIn<B>) {
270 }
272
273 fn endpoint_setup(&mut self, _: EndpointAddress) {
274 }
276
277 fn endpoint_out(&mut self, addr: EndpointAddress) {
278 if self.read_ep.address() != addr {
279 return;
280 }
281 self.state.receive(&self.read_ep);
282 }
283
284 fn endpoint_in_complete(&mut self, addr: EndpointAddress) {
285 if self.write_ep.address() != addr {
286 return;
287 }
288 self.state.send(&self.write_ep);
289 }
290}