stm32_usbd2/
bus.rs

//! USB peripheral driver.
2
3use core::mem::{self, MaybeUninit};
4use cortex_m::interrupt::{self, Mutex};
5use usb_device::bus::{PollResult, UsbBusAllocator};
6use usb_device::endpoint::{EndpointAddress, EndpointType};
7use usb_device::{Result, UsbDirection, UsbError};
8
9use crate::endpoint::{calculate_count_rx, Endpoint, EndpointStatus, NUM_ENDPOINTS};
10use crate::endpoint_memory::EndpointMemoryAllocator;
11use crate::registers::UsbRegisters;
12use crate::UsbPeripheral;
13
14/// USB peripheral driver for STM32 microcontrollers.
15pub struct UsbBus<USB> {
16    peripheral: USB,
17    regs: Mutex<UsbRegisters<USB>>,
18    endpoints: [Endpoint<USB>; NUM_ENDPOINTS],
19    ep_allocator: EndpointMemoryAllocator<USB>,
20    max_endpoint: usize,
21}
22
23impl<USB: UsbPeripheral> UsbBus<USB> {
24    /// Constructs a new USB peripheral driver.
25    pub fn new(peripheral: USB) -> UsbBusAllocator<Self> {
26        USB::enable();
27
28        let bus = UsbBus {
29            peripheral,
30            regs: Mutex::new(UsbRegisters::new()),
31            ep_allocator: EndpointMemoryAllocator::new(),
32            max_endpoint: 0,
33            endpoints: {
34                let mut endpoints: [MaybeUninit<Endpoint<USB>>; NUM_ENDPOINTS] =
35                    unsafe { MaybeUninit::uninit().assume_init() };
36
37                for i in 0..NUM_ENDPOINTS {
38                    endpoints[i] = MaybeUninit::new(Endpoint::new(i as u8));
39                }
40
41                unsafe { mem::transmute::<_, [Endpoint<USB>; NUM_ENDPOINTS]>(endpoints) }
42            },
43        };
44
45        UsbBusAllocator::new(bus)
46    }
47
48    pub fn free(self) -> USB {
49        self.peripheral
50    }
51
52    /// Simulates a disconnect from the USB bus, causing the host to reset and re-enumerate the
53    /// device.
54    ///
55    /// Mostly used for development. By calling this at the start of your program ensures that the
56    /// host re-enumerates your device after a new program has been flashed.
57    ///
58    /// `disconnect` parameter is used to provide a custom disconnect function.
59    /// This function will be called with USB peripheral powered down
60    /// and interrupts disabled.
61    /// It should perform disconnect in a platform-specific way.
62    pub fn force_reenumeration<F: FnOnce()>(&self, disconnect: F) {
63        interrupt::free(|cs| {
64            let regs = self.regs.borrow(cs);
65
66            let pdwn = regs.cntr.read().pdwn().bit_is_set();
67            regs.cntr.modify(|_, w| w.pdwn().set_bit());
68
69            disconnect();
70
71            regs.cntr.modify(|_, w| w.pdwn().bit(pdwn));
72        });
73    }
74}
75
76impl<USB: UsbPeripheral> usb_device::bus::UsbBus for UsbBus<USB> {
77    fn alloc_ep(
78        &mut self,
79        ep_dir: UsbDirection,
80        ep_addr: Option<EndpointAddress>,
81        ep_type: EndpointType,
82        max_packet_size: u16,
83        _interval: u8,
84    ) -> Result<EndpointAddress> {
85        for index in ep_addr.map(|a| a.index()..a.index() + 1).unwrap_or(1..NUM_ENDPOINTS) {
86            let ep = &mut self.endpoints[index];
87
88            match ep.ep_type() {
89                None => {
90                    ep.set_ep_type(ep_type);
91                }
92                Some(t) if t != ep_type => {
93                    continue;
94                }
95                _ => {}
96            };
97
98            match ep_dir {
99                UsbDirection::Out if !ep.is_out_buf_set() => {
100                    let (out_size, size_bits) = calculate_count_rx(max_packet_size as usize)?;
101
102                    let buffer = self.ep_allocator.allocate_buffer(out_size)?;
103
104                    ep.set_out_buf(buffer, size_bits);
105
106                    return Ok(EndpointAddress::from_parts(index, ep_dir));
107                }
108                UsbDirection::In if !ep.is_in_buf_set() => {
109                    let size = (max_packet_size as usize + 1) & !0x01;
110
111                    let buffer = self.ep_allocator.allocate_buffer(size)?;
112
113                    ep.set_in_buf(buffer);
114
115                    return Ok(EndpointAddress::from_parts(index, ep_dir));
116                }
117                _ => {}
118            }
119        }
120
121        Err(match ep_addr {
122            Some(_) => UsbError::InvalidEndpoint,
123            None => UsbError::EndpointOverflow,
124        })
125    }
126
127    fn enable(&mut self) {
128        let mut max = 0;
129        for (index, ep) in self.endpoints.iter().enumerate() {
130            if ep.is_out_buf_set() || ep.is_in_buf_set() {
131                max = index;
132            }
133        }
134
135        self.max_endpoint = max;
136
137        interrupt::free(|cs| {
138            let regs = self.regs.borrow(cs);
139
140            regs.cntr.modify(|_, w| w.pdwn().clear_bit());
141
142            USB::startup_delay();
143
144            regs.btable.modify(|_, w| w.btable().bits(0));
145            regs.cntr.modify(|_, w| {
146                w.fres().clear_bit();
147                w.resetm().set_bit();
148                w.suspm().set_bit();
149                w.wkupm().set_bit();
150                w.ctrm().set_bit()
151            });
152            regs.istr.modify(|_, w| unsafe { w.bits(0) });
153
154            if USB::DP_PULL_UP_FEATURE {
155                regs.bcdr.modify(|_, w| w.dppu().set_bit());
156            }
157        });
158    }
159
160    fn reset(&self) {
161        interrupt::free(|cs| {
162            let regs = self.regs.borrow(cs);
163
164            regs.istr.modify(|_, w| unsafe { w.bits(0) });
165            regs.daddr.modify(|_, w| w.ef().set_bit().add().bits(0));
166
167            for ep in self.endpoints.iter() {
168                ep.configure(cs);
169            }
170        });
171    }
172
173    fn set_device_address(&self, addr: u8) {
174        interrupt::free(|cs| {
175            self.regs.borrow(cs).daddr.modify(|_, w| w.add().bits(addr as u8));
176        });
177    }
178
179    fn poll(&self) -> PollResult {
180        interrupt::free(|cs| {
181            let regs = self.regs.borrow(cs);
182
183            let istr = regs.istr.read();
184
185            if istr.wkup().bit_is_set() {
186                // Interrupt flag bits are write-0-to-clear, other bits should be written as 1 to avoid
187                // race conditions
188                regs.istr.write(|w| unsafe { w.bits(0xffff) }.wkup().clear_bit());
189
190                // Required by datasheet
191                regs.cntr.modify(|_, w| w.fsusp().clear_bit());
192
193                PollResult::Resume
194            } else if istr.reset().bit_is_set() {
195                regs.istr.write(|w| unsafe { w.bits(0xffff) }.reset().clear_bit());
196
197                PollResult::Reset
198            } else if istr.susp().bit_is_set() {
199                regs.istr.write(|w| unsafe { w.bits(0xffff) }.susp().clear_bit());
200
201                PollResult::Suspend
202            } else if istr.ctr().bit_is_set() {
203                let mut ep_out = 0;
204                let mut ep_in_complete = 0;
205                let mut ep_setup = 0;
206                let mut bit = 1;
207
208                for ep in &self.endpoints[0..=self.max_endpoint] {
209                    let v = ep.read_reg();
210
211                    if v.ctr_rx().bit_is_set() {
212                        ep_out |= bit;
213
214                        if v.setup().bit_is_set() {
215                            ep_setup |= bit;
216                        }
217                    }
218
219                    if v.ctr_tx().bit_is_set() {
220                        ep_in_complete |= bit;
221
222                        interrupt::free(|cs| {
223                            ep.clear_ctr_tx(cs);
224                        });
225                    }
226
227                    bit <<= 1;
228                }
229
230                PollResult::Data {
231                    ep_out,
232                    ep_in_complete,
233                    ep_setup,
234                }
235            } else {
236                PollResult::None
237            }
238        })
239    }
240
241    fn write(&self, ep_addr: EndpointAddress, buf: &[u8]) -> Result<usize> {
242        if !ep_addr.is_in() {
243            return Err(UsbError::InvalidEndpoint);
244        }
245
246        self.endpoints[ep_addr.index()].write(buf)
247    }
248
249    fn read(&self, ep_addr: EndpointAddress, buf: &mut [u8]) -> Result<usize> {
250        if !ep_addr.is_out() {
251            return Err(UsbError::InvalidEndpoint);
252        }
253
254        self.endpoints[ep_addr.index()].read(buf)
255    }
256
257    fn set_stalled(&self, ep_addr: EndpointAddress, stalled: bool) {
258        interrupt::free(|cs| {
259            if self.is_stalled(ep_addr) == stalled {
260                return;
261            }
262
263            let ep = &self.endpoints[ep_addr.index()];
264
265            match (stalled, ep_addr.direction()) {
266                (true, UsbDirection::In) => ep.set_stat_tx(cs, EndpointStatus::Stall),
267                (true, UsbDirection::Out) => ep.set_stat_rx(cs, EndpointStatus::Stall),
268                (false, UsbDirection::In) => ep.set_stat_tx(cs, EndpointStatus::Nak),
269                (false, UsbDirection::Out) => ep.set_stat_rx(cs, EndpointStatus::Valid),
270            };
271        });
272    }
273
274    fn is_stalled(&self, ep_addr: EndpointAddress) -> bool {
275        let ep = &self.endpoints[ep_addr.index()];
276        let reg_v = ep.read_reg();
277
278        let status = match ep_addr.direction() {
279            UsbDirection::In => reg_v.stat_tx().bits(),
280            UsbDirection::Out => reg_v.stat_rx().bits(),
281        };
282
283        status == (EndpointStatus::Stall as u8)
284    }
285
286    fn suspend(&self) {
287        interrupt::free(|cs| {
288            self.regs
289                .borrow(cs)
290                .cntr
291                .modify(|_, w| w.fsusp().set_bit().lpmode().set_bit());
292        });
293    }
294
295    fn resume(&self) {
296        interrupt::free(|cs| {
297            self.regs
298                .borrow(cs)
299                .cntr
300                .modify(|_, w| w.fsusp().clear_bit().lpmode().clear_bit());
301        });
302    }
303}