// trouble_host/channel_manager.rs

1use core::cell::RefCell;
2use core::future::poll_fn;
3use core::task::{Context, Poll};
4
5use bt_hci::controller::{blocking, Controller};
6use bt_hci::param::ConnHandle;
7use bt_hci::FromHciBytes;
8use embassy_sync::blocking_mutex::raw::NoopRawMutex;
9use embassy_sync::channel::Channel;
10use embassy_sync::waitqueue::WakerRegistration;
11
12use crate::connection_manager::ConnectionManager;
13use crate::cursor::WriteCursor;
14use crate::host::BleHost;
15#[cfg(not(feature = "l2cap-sdu-reassembly-optimization"))]
16use crate::l2cap::sar::PacketReassembly;
17use crate::l2cap::L2capChannel;
18use crate::pdu::{Pdu, Sdu};
19use crate::prelude::{ConnectionEvent, L2capChannelConfig};
20use crate::types::l2cap::{
21    CommandRejectRes, ConnParamUpdateReq, ConnParamUpdateRes, DisconnectionReq, DisconnectionRes, L2capSignalCode,
22    L2capSignalHeader, LeCreditConnReq, LeCreditConnRes, LeCreditConnResultCode, LeCreditFlowInd,
23};
24use crate::{config, BleHostError, Error, PacketPool};
25
/// CID assigned to slot 0 of the channel table; slot `i` gets CID `BASE_ID + i` (see `alloc`).
const BASE_ID: u16 = 0x40;
27
/// Mutable bookkeeping shared by all channels, kept behind the manager's `RefCell`.
struct State<'d, P> {
    /// Identifier for the next outgoing signalling command (0 is skipped).
    next_req_id: u8,
    /// Caller-provided channel table; an entry's index determines its local CID.
    channels: &'d mut [ChannelStorage<P>],
    /// Registered by `accept`; woken when a peer connection request is recorded or a link drops.
    accept_waker: WakerRegistration,
    /// Registered by `poll_created`; woken when a connect response arrives or a link drops.
    create_waker: WakerRegistration,
    /// Woken when a channel moves to a (peer-)disconnecting state.
    disconnect_waker: WakerRegistration,
}
35
/// Channel manager for L2CAP channels used directly by clients.
///
/// All mutable bookkeeping lives in a `RefCell`, so overlapping mutable
/// borrows of the state panic at runtime.
pub struct ChannelManager<'d, P: PacketPool> {
    state: RefCell<State<'d, P::Packet>>,
}
40
/// Bounded queue of inbound PDUs for one channel.
///
/// Entries are `Option`s: `Some` carries a PDU, `None` is the close marker
/// enqueued by `close()`.
pub(crate) struct PacketChannel<P, const QLEN: usize> {
    chan: Channel<NoopRawMutex, Option<Pdu<P>>, QLEN>,
}
44
/// Index of a channel's slot in the manager's channel table.
///
/// `Eq` is derived alongside `PartialEq` because equality on a plain index is
/// total (clippy `derive_partial_eq_without_eq`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ChannelIndex(u8);
48
49impl<P, const QLEN: usize> PacketChannel<P, QLEN> {
50    pub(crate) const fn new() -> Self {
51        Self { chan: Channel::new() }
52    }
53
54    pub fn close(&self) -> Result<(), ()> {
55        self.chan.try_send(None).map_err(|_| ())
56    }
57
58    pub async fn send(&self, pdu: Pdu<P>) {
59        self.chan.send(Some(pdu)).await;
60    }
61
62    pub fn try_send(&self, pdu: Pdu<P>) -> Result<(), Error> {
63        self.chan.try_send(Some(pdu)).map_err(|_| Error::OutOfMemory)
64    }
65
66    pub fn poll_receive(&self, cx: &mut Context<'_>) -> Poll<Option<Pdu<P>>> {
67        self.chan.poll_receive(cx)
68    }
69
70    pub fn clear(&self) {
71        self.chan.clear()
72    }
73}
74
75impl<P> State<'_, P> {
76    fn print(&self, verbose: bool) {
77        for (idx, storage) in self.channels.iter().enumerate() {
78            if verbose || storage.state != ChannelState::Disconnected {
79                debug!("[l2cap][idx = {}] {:?}", idx, storage);
80            }
81        }
82    }
83    fn next_request_id(&mut self) -> u8 {
84        // 0 is an invalid identifier
85        if self.next_req_id == 0 {
86            self.next_req_id += 1;
87        }
88        let next = self.next_req_id;
89        self.next_req_id = self.next_req_id.wrapping_add(1);
90        next
91    }
92
93    fn inc_ref(&mut self, index: ChannelIndex) {
94        let state = &mut self.channels[index.0 as usize];
95        state.refcount = unwrap!(state.refcount.checked_add(1), "Too many references to the same channel");
96    }
97}
98
99impl<'d, P: PacketPool> ChannelManager<'d, P> {
100    pub fn new(channels: &'d mut [ChannelStorage<P::Packet>]) -> Self {
101        Self {
102            state: RefCell::new(State {
103                next_req_id: 0,
104                channels,
105                accept_waker: WakerRegistration::new(),
106                create_waker: WakerRegistration::new(),
107                disconnect_waker: WakerRegistration::new(),
108            }),
109        }
110    }
111
112    fn next_request_id(&self) -> u8 {
113        self.state.borrow_mut().next_request_id()
114    }
115
116    pub(crate) fn psm(&self, index: ChannelIndex) -> u16 {
117        self.with_mut(|state| {
118            let chan = &mut state.channels[index.0 as usize];
119            chan.psm
120        })
121    }
122
123    pub(crate) fn disconnect(&self, index: ChannelIndex) {
124        self.with_mut(|state| {
125            let chan = &mut state.channels[index.0 as usize];
126            if chan.state == ChannelState::Connected {
127                chan.state = ChannelState::Disconnecting;
128                let _ = chan.inbound.close();
129                #[cfg(feature = "channel-metrics")]
130                chan.metrics.reset();
131                state.disconnect_waker.wake();
132            }
133        })
134    }
135
136    pub(crate) fn disconnected(&self, conn: ConnHandle) -> Result<(), Error> {
137        let mut state = self.state.borrow_mut();
138        for storage in state.channels.iter_mut() {
139            if Some(conn) == storage.conn {
140                let _ = storage.inbound.close();
141                #[cfg(not(feature = "l2cap-sdu-reassembly-optimization"))]
142                storage.reassembly.clear();
143                #[cfg(feature = "channel-metrics")]
144                storage.metrics.reset();
145                storage.close();
146            }
147        }
148        state.accept_waker.wake();
149        state.create_waker.wake();
150        Ok(())
151    }
152
153    fn alloc<F: FnOnce(&mut ChannelStorage<P::Packet>)>(&self, conn: ConnHandle, f: F) -> Result<ChannelIndex, Error> {
154        let mut state = self.state.borrow_mut();
155        for (idx, storage) in state.channels.iter_mut().enumerate() {
156            if ChannelState::Disconnected == storage.state && storage.refcount == 0 {
157                // Ensure inbound is empty.
158                storage.inbound.clear();
159                #[cfg(not(feature = "l2cap-sdu-reassembly-optimization"))]
160                storage.reassembly.clear();
161                let cid: u16 = BASE_ID + idx as u16;
162                storage.conn = Some(conn);
163                storage.cid = cid;
164                f(storage);
165                return Ok(ChannelIndex(idx as u8));
166            }
167        }
168        Err(Error::NoChannelAvailable)
169    }
170
171    pub(crate) async fn accept<T: Controller>(
172        &'d self,
173        conn: ConnHandle,
174        psm: &[u16],
175        config: &L2capChannelConfig,
176        ble: &BleHost<'d, T, P>,
177    ) -> Result<L2capChannel<'d, P>, BleHostError<T::Error>> {
178        let L2capChannelConfig {
179            mtu,
180            mps,
181            flow_policy,
182            initial_credits,
183        } = config;
184
185        let mtu = mtu.unwrap_or(P::MTU as u16 - 6);
186        let mps = mps.unwrap_or(P::MTU as u16 - 4);
187        if mps > P::MTU as u16 - 4 {
188            return Err(Error::InsufficientSpace.into());
189        }
190
191        // Wait until we find a channel for our connection in the connecting state matching our PSM.
192        let (channel, req_id, mps, mtu, cid, credits) = poll_fn(|cx| {
193            let mut state = self.state.borrow_mut();
194            state.accept_waker.register(cx.waker());
195            for (idx, chan) in state.channels.iter_mut().enumerate() {
196                match chan.state {
197                    ChannelState::PeerConnecting(req_id) if chan.conn == Some(conn) && psm.contains(&chan.psm) => {
198                        chan.mtu = chan.mtu.min(mtu);
199                        chan.mps = chan.mps.min(mps);
200                        chan.flow_control = CreditFlowControl::new(
201                            *flow_policy,
202                            initial_credits.unwrap_or(config::L2CAP_RX_QUEUE_SIZE.min(P::capacity()) as u16),
203                        );
204                        chan.state = ChannelState::Connected;
205                        let mps = chan.mps;
206                        let mtu = chan.mtu;
207                        let cid = chan.cid;
208                        let available = chan.flow_control.available();
209                        if chan.refcount != 0 {
210                            state.print(true);
211                            panic!("unexpected refcount");
212                        }
213                        assert_eq!(chan.refcount, 0);
214                        let index = ChannelIndex(idx as u8);
215
216                        state.inc_ref(index);
217                        return Poll::Ready((L2capChannel::new(index, self), req_id, mps, mtu, cid, available));
218                    }
219                    _ => {}
220                }
221            }
222            Poll::Pending
223        })
224        .await;
225
226        let mut tx = [0; 18];
227        // Respond that we accept the channel.
228        ble.l2cap_signal(
229            conn,
230            req_id,
231            &LeCreditConnRes {
232                mps,
233                dcid: cid,
234                mtu,
235                credits,
236                result: LeCreditConnResultCode::Success,
237            },
238            &mut tx[..],
239        )
240        .await?;
241        Ok(channel)
242    }
243
244    pub(crate) async fn create<T: Controller>(
245        &'d self,
246        conn: ConnHandle,
247        psm: u16,
248        config: &L2capChannelConfig,
249        ble: &BleHost<'_, T, P>,
250    ) -> Result<L2capChannel<'d, P>, BleHostError<T::Error>> {
251        let L2capChannelConfig {
252            mtu,
253            mps,
254            flow_policy,
255            initial_credits,
256        } = config;
257
258        let req_id = self.next_request_id();
259        let mut credits = 0;
260        let mut cid: u16 = 0;
261
262        let mtu = mtu.unwrap_or(P::MTU as u16 - 6);
263        let mps = mps.unwrap_or(P::MTU as u16 - 4);
264        if mps > P::MTU as u16 - 4 {
265            return Err(Error::InsufficientSpace.into());
266        }
267
268        // Allocate space for our new channel.
269        let idx = self.alloc(conn, |storage| {
270            cid = storage.cid;
271            credits = initial_credits.unwrap_or(config::L2CAP_RX_QUEUE_SIZE.min(P::capacity()) as u16);
272            storage.psm = psm;
273            storage.mtu = mtu;
274            storage.mps = mps;
275            storage.flow_control = CreditFlowControl::new(*flow_policy, credits);
276            storage.state = ChannelState::Connecting(req_id);
277        })?;
278
279        let mut tx = [0; 18];
280        // Send the initial connect request.
281        let command = LeCreditConnReq {
282            psm,
283            mps,
284            scid: cid,
285            mtu,
286            credits,
287        };
288        ble.l2cap_signal(conn, req_id, &command, &mut tx[..]).await?;
289
290        // Wait until a response is accepted.
291        poll_fn(|cx| self.poll_created(conn, idx, ble, Some(cx))).await
292    }
293
294    fn poll_created<T: Controller>(
295        &'d self,
296        conn: ConnHandle,
297        idx: ChannelIndex,
298        ble: &BleHost<'_, T, P>,
299        cx: Option<&mut Context<'_>>,
300    ) -> Poll<Result<L2capChannel<'d, P>, BleHostError<T::Error>>> {
301        let mut state = self.state.borrow_mut();
302        if let Some(cx) = cx {
303            state.create_waker.register(cx.waker());
304        }
305        let storage = &mut state.channels[idx.0 as usize];
306        // Check if we've been disconnected while waiting
307        if !ble.connections.is_handle_connected(conn) {
308            return Poll::Ready(Err(Error::Disconnected.into()));
309        }
310
311        //// Make sure something hasn't gone wrong
312        assert_eq!(Some(conn), storage.conn);
313
314        match storage.state {
315            ChannelState::Disconnecting | ChannelState::PeerDisconnecting => {
316                return Poll::Ready(Err(Error::Disconnected.into()));
317            }
318            ChannelState::Connected => {
319                if storage.refcount != 0 {
320                    state.print(true);
321                    panic!("unexpected refcount");
322                }
323                assert_eq!(storage.refcount, 0);
324                state.inc_ref(idx);
325                return Poll::Ready(Ok(L2capChannel::new(idx, self)));
326            }
327            _ => {}
328        }
329        Poll::Pending
330    }
331
332    pub(crate) fn received(&self, channel: u16, credits: u16) -> Result<(), Error> {
333        if channel < BASE_ID {
334            return Err(Error::InvalidChannelId);
335        }
336
337        let chan = (channel - BASE_ID) as usize;
338        self.with_mut(|state| {
339            if chan >= state.channels.len() {
340                return Err(Error::InvalidChannelId);
341            }
342
343            let storage = &mut state.channels[chan];
344            match storage.state {
345                ChannelState::Connected if channel == storage.cid => {
346                    if storage.flow_control.available() == 0 {
347                        #[cfg(feature = "channel-metrics")]
348                        storage.metrics.blocked_receive();
349                        // NOTE: This will trigger closing of the link, which might be a bit
350                        // too strict. But it should be controllable via the credits given,
351                        // which the remote should respect.
352                        debug!("[l2cap][cid = {}] no credits available", channel);
353                        return Err(Error::OutOfMemory);
354                    }
355                    storage.flow_control.confirm_received(1);
356                    #[cfg(feature = "channel-metrics")]
357                    storage.metrics.received(1);
358                    return Ok(());
359                }
360                _ => {}
361            }
362            Err(Error::NotFound)
363        })
364    }
365
366    pub(crate) fn dispatch(&self, channel: u16, pdu: Pdu<P::Packet>) -> Result<(), Error> {
367        if channel < BASE_ID {
368            return Err(Error::InvalidChannelId);
369        }
370
371        let chan = (channel - BASE_ID) as usize;
372        self.with_mut(|state| {
373            if chan >= state.channels.len() {
374                return Err(Error::InvalidChannelId);
375            }
376
377            let mut sdu = None;
378            let storage = &mut state.channels[chan];
379            match storage.state {
380                ChannelState::Connected if channel == storage.cid => {
381                    // Reassembly and accounting is already done
382                    #[cfg(feature = "l2cap-sdu-reassembly-optimization")]
383                    sdu.replace(pdu);
384
385                    // Reassembly is done in the channel
386                    #[cfg(not(feature = "l2cap-sdu-reassembly-optimization"))]
387                    {
388                        if storage.flow_control.available() == 0 {
389                            #[cfg(feature = "channel-metrics")]
390                            storage.metrics.blocked_receive();
391                            // NOTE: This will trigger closing of the link, which might be a bit
392                            // too strict. But it should be controllable via the credits given,
393                            // which the remote should respect.
394                            debug!("[l2cap][cid = {}] no credits available", channel);
395                            return Err(Error::OutOfMemory);
396                        }
397                        storage.flow_control.confirm_received(1);
398
399                        #[cfg(feature = "channel-metrics")]
400                        storage.metrics.received(1);
401                        if !storage.reassembly.in_progress() {
402                            let (first, _) = pdu.as_ref().split_at(2);
403                            let sdu_len: u16 = u16::from_le_bytes([first[0], first[1]]);
404                            let len = pdu.len() - 2;
405
406                            let mut packet = pdu.into_inner();
407                            packet.as_mut().rotate_left(2);
408
409                            // A complete fragment
410                            if sdu_len as usize == len {
411                                sdu.replace(Pdu::new(packet, sdu_len as usize));
412                            } else {
413                                // Need another fragment
414                                storage.reassembly.init_with_written(channel, sdu_len, packet, len)?;
415                            }
416                        } else if let Some((state, pdu)) = storage.reassembly.update(pdu.as_ref())? {
417                            sdu.replace(pdu);
418                        }
419                    }
420                }
421                _ => {}
422            }
423
424            if let Some(sdu) = sdu {
425                storage.inbound.try_send(sdu)?;
426            }
427
428            Ok(())
429        })
430    }
431
432    /// Handle incoming L2CAP signal
433    pub(crate) fn signal(
434        &self,
435        conn: ConnHandle,
436        data: &[u8],
437        manager: &ConnectionManager<'_, P>,
438    ) -> Result<(), Error> {
439        let (header, data) = L2capSignalHeader::from_hci_bytes(data)?;
440        //trace!(
441        //    "[l2cap][conn = {:?}] received signal (req {}) code {:?}",
442        //    conn,
443        //    header.identifier,
444        //    header.code
445        //);
446        match header.code {
447            L2capSignalCode::LeCreditConnReq => {
448                let req = LeCreditConnReq::from_hci_bytes_complete(data)?;
449                self.handle_connect_request(conn, header.identifier, &req)?;
450            }
451            L2capSignalCode::LeCreditConnRes => {
452                let res = LeCreditConnRes::from_hci_bytes_complete(data)?;
453                self.handle_connect_response(conn, header.identifier, &res)?;
454            }
455            L2capSignalCode::LeCreditFlowInd => {
456                let req = LeCreditFlowInd::from_hci_bytes_complete(data)?;
457                //trace!("[l2cap] credit flow: {:?}", req);
458                self.handle_credit_flow(conn, &req)?;
459            }
460            L2capSignalCode::CommandRejectRes => {
461                let (reject, _) = CommandRejectRes::from_hci_bytes(data)?;
462            }
463            L2capSignalCode::DisconnectionReq => {
464                let req = DisconnectionReq::from_hci_bytes_complete(data)?;
465                debug!("[l2cap][conn = {:?}, cid = {}] disconnect request", conn, req.dcid);
466                self.handle_disconnect_request(req.dcid)?;
467            }
468            L2capSignalCode::DisconnectionRes => {
469                let res = DisconnectionRes::from_hci_bytes_complete(data)?;
470                debug!("[l2cap][conn = {:?}, cid = {}] disconnect response", conn, res.scid);
471                self.handle_disconnect_response(res.scid)?;
472            }
473            L2capSignalCode::ConnParamUpdateReq => {
474                let req = ConnParamUpdateReq::from_hci_bytes_complete(data)?;
475                debug!("[l2cap][conn = {:?}] connection param update request: {:?}", conn, req);
476                let interval_min: bt_hci::param::Duration<1_250> = bt_hci::param::Duration::from_u16(req.interval_min);
477                let interva_max: bt_hci::param::Duration<1_250> = bt_hci::param::Duration::from_u16(req.interval_max);
478                let timeout: bt_hci::param::Duration<10_000> = bt_hci::param::Duration::from_u16(req.timeout);
479                use embassy_time::Duration;
480                let _ = manager.post_handle_event(
481                    conn,
482                    ConnectionEvent::RequestConnectionParams {
483                        min_connection_interval: Duration::from_micros(interval_min.as_micros()),
484                        max_connection_interval: Duration::from_micros(interval_min.as_micros()),
485                        max_latency: req.latency,
486                        supervision_timeout: Duration::from_micros(timeout.as_micros()),
487                    },
488                );
489            }
490            L2capSignalCode::ConnParamUpdateRes => {
491                let res = ConnParamUpdateRes::from_hci_bytes_complete(data)?;
492                debug!(
493                    "[l2cap][conn = {:?}] connection param update response: {}",
494                    conn, res.result,
495                );
496            }
497            r => {
498                warn!("[l2cap][conn = {:?}] unsupported signal: {:?}", conn, r);
499                return Err(Error::NotSupported);
500            }
501        }
502        Ok(())
503    }
504
505    fn handle_connect_request(&self, conn: ConnHandle, identifier: u8, req: &LeCreditConnReq) -> Result<(), Error> {
506        self.alloc(conn, |storage| {
507            storage.conn = Some(conn);
508            storage.psm = req.psm;
509            storage.peer_cid = req.scid;
510            storage.peer_credits = req.credits;
511            storage.mps = req.mps;
512            storage.mtu = req.mtu;
513            storage.state = ChannelState::PeerConnecting(identifier);
514        })?;
515        self.state.borrow_mut().accept_waker.wake();
516        Ok(())
517    }
518
519    fn handle_connect_response(&self, conn: ConnHandle, identifier: u8, res: &LeCreditConnRes) -> Result<(), Error> {
520        match res.result {
521            LeCreditConnResultCode::Success => {
522                // Must be a response of a previous request which should already by allocated a channel for
523                let mut state = self.state.borrow_mut();
524                for storage in state.channels.iter_mut() {
525                    match storage.state {
526                        ChannelState::Connecting(req_id) if identifier == req_id && Some(conn) == storage.conn => {
527                            storage.peer_cid = res.dcid;
528                            storage.peer_credits = res.credits;
529                            storage.mps = storage.mps.min(res.mps);
530                            storage.mtu = storage.mtu.min(res.mtu);
531                            storage.state = ChannelState::Connected;
532                            state.create_waker.wake();
533                            return Ok(());
534                        }
535                        _ => {}
536                    }
537                }
538                debug!(
539                    "[l2cap][handle_connect_response][link = {}] request with id {} not found",
540                    conn.raw(),
541                    identifier
542                );
543                Err(Error::NotFound)
544            }
545            other => {
546                warn!("Channel open request failed: {:?}", other);
547                Err(Error::NotSupported)
548            }
549        }
550    }
551
552    fn handle_credit_flow(&self, conn: ConnHandle, req: &LeCreditFlowInd) -> Result<(), Error> {
553        let mut state = self.state.borrow_mut();
554        for storage in state.channels.iter_mut() {
555            match storage.state {
556                ChannelState::Connected if storage.peer_cid == req.cid && Some(conn) == storage.conn => {
557                    trace!(
558                        "[l2cap][handle_credit_flow][cid = {}] {} += {} credits",
559                        req.cid,
560                        storage.peer_credits,
561                        req.credits
562                    );
563                    storage.peer_credits = storage.peer_credits.saturating_add(req.credits);
564                    storage.credit_waker.wake();
565                    return Ok(());
566                }
567                _ => {}
568            }
569        }
570        //    trace!("[l2cap][handle_credit_flow] peer channel {} not found", req.cid);
571        Err(Error::NotFound)
572    }
573
574    fn handle_disconnect_request(&self, cid: u16) -> Result<(), Error> {
575        let mut state = self.state.borrow_mut();
576        for (idx, storage) in state.channels.iter_mut().enumerate() {
577            if cid == storage.cid {
578                storage.state = ChannelState::PeerDisconnecting;
579                let _ = storage.inbound.close();
580                state.disconnect_waker.wake();
581                break;
582            }
583        }
584        Ok(())
585    }
586
587    fn handle_disconnect_response(&self, cid: u16) -> Result<(), Error> {
588        let mut state = self.state.borrow_mut();
589        for storage in state.channels.iter_mut() {
590            if storage.state == ChannelState::Disconnecting && cid == storage.cid {
591                storage.close();
592                break;
593            }
594        }
595        Ok(())
596    }
597
598    /// Receive SDU on a given channel.
599    ///
600    /// The MTU of the channel must be <= the MTU of the packet.
601    pub(crate) async fn receive_sdu<T: Controller>(
602        &self,
603        chan: ChannelIndex,
604        ble: &BleHost<'d, T, P>,
605    ) -> Result<Sdu<P::Packet>, BleHostError<T::Error>> {
606        let pdu = self.receive_pdu(&ble.connections, chan).await?;
607        let mut p_buf: [u8; 16] = [0; 16];
608        self.flow_control(chan, ble, &mut p_buf).await?;
609        Ok(Sdu::from_pdu(pdu))
610    }
611
612    /// Receive data on a given channel and copy it into the buffer.
613    ///
614    /// The length provided buffer slice must be equal or greater to the agreed MTU.
615    pub(crate) async fn receive<T: Controller>(
616        &self,
617        chan: ChannelIndex,
618        buf: &mut [u8],
619        ble: &BleHost<'d, T, P>,
620    ) -> Result<usize, BleHostError<T::Error>> {
621        let pdu = self.receive_pdu(&ble.connections, chan).await?;
622
623        let to_copy = pdu.len().min(buf.len());
624        // info!("[host] received a pdu of len {}, copying {} bytes", pdu.len(), to_copy);
625        buf[..to_copy].copy_from_slice(&pdu.as_ref()[..to_copy]);
626
627        let mut p_buf: [u8; 16] = [0; 16];
628        self.flow_control(chan, ble, &mut p_buf).await?;
629        Ok(to_copy)
630    }
631
632    async fn receive_pdu<'m>(
633        &self,
634        ble: &'m ConnectionManager<'d, P>,
635        chan: ChannelIndex,
636    ) -> Result<Pdu<P::Packet>, Error> {
637        poll_fn(|cx| {
638            let state = self.state.borrow();
639            let chan = &state.channels[chan.0 as usize];
640            if chan.state == ChannelState::Connected {
641                let conn = chan.conn.unwrap();
642                match chan.inbound.poll_receive(cx) {
643                    Poll::Ready(Some(pdu)) => Poll::Ready(Ok(pdu)),
644                    Poll::Ready(None) => Poll::Ready(Err(Error::ChannelClosed)),
645                    Poll::Pending => Poll::Pending,
646                }
647            } else {
648                Poll::Ready(Err(Error::ChannelClosed))
649            }
650        })
651        .await
652    }
653
654    /// Send the provided buffer over a given l2cap channel.
655    ///
656    /// The buffer will be segmented to the maximum payload size agreed in the opening handshake.
657    ///
658    /// If the channel has been closed or the channel id is not valid, an error is returned.
659    pub(crate) async fn send<T: Controller>(
660        &self,
661        index: ChannelIndex,
662        buf: &[u8],
663        p_buf: &mut [u8],
664        ble: &BleHost<'d, T, P>,
665    ) -> Result<(), BleHostError<T::Error>> {
666        let (conn, mps, mtu, peer_cid) = self.connected_channel_params(index)?;
667        if buf.len() > mtu as usize {
668            return Err(Error::InsufficientSpace.into());
669        }
670        // The number of packets we'll need to send for this payload
671        let len = (buf.len() as u16).saturating_add(2);
672        let n_packets = len.div_ceil(mps);
673        // info!("[host] sending {} LE K frames, len {}, mps {}", n_packets, len, mps);
674
675        let mut grant = poll_fn(|cx| self.poll_request_to_send(index, n_packets, Some(cx))).await?;
676
677        // Segment using mps
678        let (first, remaining) = buf.split_at(buf.len().min(mps as usize - 2));
679
680        let len = encode(first, &mut p_buf[..], peer_cid, Some(buf.len() as u16))?;
681        ble.l2cap(conn, (len - 4) as u16, 1).await?.send(&p_buf[..len]).await?;
682        grant.confirm(1);
683
684        let chunks = remaining.chunks(mps as usize);
685
686        for chunk in chunks {
687            let len = encode(chunk, &mut p_buf[..], peer_cid, None)?;
688            ble.l2cap(conn, (len - 4) as u16, 1).await?.send(&p_buf[..len]).await?;
689            grant.confirm(1);
690        }
691        Ok(())
692    }
693
694    /// Send the provided buffer over a given l2cap channel.
695    ///
696    /// The buffer must be equal to or smaller than the MTU agreed for the channel.
697    ///
698    /// If the channel has been closed or the channel id is not valid, an error is returned.
699    pub(crate) fn try_send<T: Controller + blocking::Controller>(
700        &self,
701        index: ChannelIndex,
702        buf: &[u8],
703        p_buf: &mut [u8],
704        ble: &BleHost<'d, T, P>,
705    ) -> Result<(), BleHostError<T::Error>> {
706        let (conn, mps, mtu, peer_cid) = self.connected_channel_params(index)?;
707        if buf.len() > mtu as usize {
708            return Err(Error::InsufficientSpace.into());
709        }
710
711        // The number of packets we'll need to send for this payload
712        let len = (buf.len() as u16).saturating_add(2);
713        let n_packets = len.div_ceil(mps);
714
715        let mut grant = match self.poll_request_to_send(index, n_packets, None) {
716            Poll::Ready(res) => res?,
717            Poll::Pending => {
718                return Err(Error::Busy.into());
719            }
720        };
721
722        // Pre-request
723        let mut sender = ble.try_l2cap(conn, len, n_packets)?;
724
725        // Segment using mps
726        let (first, remaining) = buf.split_at(buf.len().min(mps as usize - 2));
727
728        let len = encode(first, &mut p_buf[..], peer_cid, Some(buf.len() as u16))?;
729        sender.try_send(&p_buf[..len])?;
730        grant.confirm(1);
731
732        let chunks = remaining.chunks(mps as usize);
733        let num_chunks = chunks.len();
734
735        for (i, chunk) in chunks.enumerate() {
736            let len = encode(chunk, &mut p_buf[..], peer_cid, None)?;
737            sender.try_send(&p_buf[..len])?;
738            grant.confirm(1);
739        }
740        Ok(())
741    }
742
743    pub(crate) async fn send_conn_param_update_req<T: Controller>(
744        &self,
745        handle: ConnHandle,
746        host: &BleHost<'d, T, P>,
747        param: &ConnParamUpdateReq,
748    ) -> Result<(), BleHostError<T::Error>> {
749        let identifier = self.next_request_id();
750        let mut tx = [0; 16];
751        host.l2cap_signal(handle, identifier, param, &mut tx[..]).await
752    }
753
754    pub(crate) async fn send_conn_param_update_res<T: Controller>(
755        &self,
756        handle: ConnHandle,
757        host: &BleHost<'d, T, P>,
758        param: &ConnParamUpdateRes,
759    ) -> Result<(), BleHostError<T::Error>> {
760        let identifier = self.next_request_id();
761        let mut tx = [0; 16];
762        host.l2cap_signal(handle, identifier, param, &mut tx[..]).await
763    }
764
765    fn connected_channel_params(&self, index: ChannelIndex) -> Result<(ConnHandle, u16, u16, u16), Error> {
766        let state = self.state.borrow();
767        let chan = &state.channels[index.0 as usize];
768        if chan.state == ChannelState::Connected {
769            return Ok((chan.conn.unwrap(), chan.mps, chan.mtu, chan.peer_cid));
770        }
771        //trace!("[l2cap][connected_channel_params] channel {} closed", index);
772        Err(Error::ChannelClosed)
773    }
774
775    // Check the current state of flow control and send flow indications if
776    // our policy says so.
777    async fn flow_control<T: Controller>(
778        &self,
779        index: ChannelIndex,
780        ble: &BleHost<'d, T, P>,
781        p_buf: &mut [u8],
782    ) -> Result<(), BleHostError<T::Error>> {
783        let (conn, cid, credits) = self.with_mut(|state| {
784            let chan = &mut state.channels[index.0 as usize];
785            if chan.state == ChannelState::Connected {
786                return Ok((chan.conn.unwrap(), chan.cid, chan.flow_control.process()));
787            }
788            debug!("[l2cap][flow_control_process] channel {:?} not found", index);
789            Err(Error::NotFound)
790        })?;
791
792        if let Some(credits) = credits {
793            let identifier = self.next_request_id();
794            let signal = LeCreditFlowInd { cid, credits };
795            // info!("[host] sending credit flow {} credits on cid {}", credits, cid);
796
797            // Reuse packet buffer for signalling data to save the extra TX buffer
798            ble.l2cap_signal(conn, identifier, &signal, p_buf).await?;
799            self.with_mut(|state| {
800                let chan = &mut state.channels[index.0 as usize];
801                if chan.state == ChannelState::Connected {
802                    chan.flow_control.confirm_granted(credits);
803                    return Ok(());
804                }
805                debug!("[l2cap][flow_control_grant] channel {:?} not found", index);
806                Err(Error::NotFound)
807            })?;
808        }
809        Ok(())
810    }
811
812    fn with_mut<F: FnOnce(&mut State<'d, P::Packet>) -> R, R>(&self, f: F) -> R {
813        let mut state = self.state.borrow_mut();
814        f(&mut state)
815    }
816
817    fn poll_request_to_send(
818        &self,
819        index: ChannelIndex,
820        credits: u16,
821        cx: Option<&mut Context<'_>>,
822    ) -> Poll<Result<CreditGrant<'_, 'd, P::Packet>, Error>> {
823        let mut state = self.state.borrow_mut();
824        let chan = &mut state.channels[index.0 as usize];
825        if chan.state == ChannelState::Connected {
826            if let Some(cx) = cx {
827                chan.credit_waker.register(cx.waker());
828            }
829            if credits <= chan.peer_credits {
830                chan.peer_credits -= credits;
831                #[cfg(feature = "channel-metrics")]
832                chan.metrics.sent(credits as usize);
833                return Poll::Ready(Ok(CreditGrant::new(&self.state, index, credits)));
834            } else {
835                #[cfg(feature = "channel-metrics")]
836                chan.metrics.blocked_send();
837                return Poll::Pending;
838            }
839        }
840        debug!("[l2cap][pool_request_to_send] channel index {:?} not found", index);
841        Poll::Ready(Err(Error::NotFound))
842    }
843
844    pub(crate) fn poll_disconnecting<'m>(&'m self, cx: Option<&mut Context<'_>>) -> Poll<DisconnectRequest<'m, 'd, P>> {
845        let mut state = self.state.borrow_mut();
846        if let Some(cx) = cx {
847            state.disconnect_waker.register(cx.waker());
848        }
849        for (idx, storage) in state.channels.iter().enumerate() {
850            match storage.state {
851                ChannelState::Disconnecting | ChannelState::PeerDisconnecting => {
852                    return Poll::Ready(DisconnectRequest {
853                        index: ChannelIndex(idx as u8),
854                        handle: storage.conn.unwrap(),
855                        state: &self.state,
856                    });
857                }
858                _ => {}
859            }
860        }
861        Poll::Pending
862    }
863
864    pub(crate) fn inc_ref(&self, index: ChannelIndex) {
865        self.with_mut(|state| {
866            state.inc_ref(index);
867        });
868    }
869
870    pub(crate) fn dec_ref(&self, index: ChannelIndex) {
871        self.with_mut(|state| {
872            let state = &mut state.channels[index.0 as usize];
873            state.refcount = unwrap!(
874                state.refcount.checked_sub(1),
875                "bug: dropping a channel (i = {}) with refcount 0",
876                index.0
877            );
878            if state.refcount == 0 && state.state == ChannelState::Connected {
879                state.state = ChannelState::Disconnecting;
880            }
881        });
882    }
883
884    pub(crate) fn log_status(&self, verbose: bool) {
885        let state = self.state.borrow();
886        state.print(verbose);
887    }
888
889    #[cfg(feature = "defmt")]
890    pub(crate) fn print(&self, index: ChannelIndex, f: defmt::Formatter) {
891        use defmt::Format;
892        self.with_mut(|state| {
893            let chan = &mut state.channels[index.0 as usize];
894            chan.format(f);
895        })
896    }
897
898    #[cfg(feature = "channel-metrics")]
899    pub(crate) fn metrics<F: FnOnce(&Metrics) -> R, R>(&self, index: ChannelIndex, f: F) -> R {
900        self.with_mut(|state| {
901            let state = &state.channels[index.0 as usize];
902            f(&state.metrics)
903        })
904    }
905}
906
/// A pending disconnection of one channel, produced by `ChannelManager::poll_disconnecting`.
///
/// It records the channel index and connection handle, plus a reference to the shared channel
/// state. `send` transmits the matching signal and `confirm` marks the channel disconnected.
pub struct DisconnectRequest<'a, 'd, P: PacketPool> {
    // Slot of the channel being disconnected.
    index: ChannelIndex,
    // Connection handle read from the channel when the request was produced.
    handle: ConnHandle,
    // Shared channel-manager state the slot lives in.
    state: &'a RefCell<State<'d, P::Packet>>,
}
912
913impl<'a, 'd, P: PacketPool> DisconnectRequest<'a, 'd, P> {
914    pub fn handle(&self) -> ConnHandle {
915        self.handle
916    }
917
918    pub async fn send<T: Controller>(&self, host: &BleHost<'_, T, P>) -> Result<(), BleHostError<T::Error>> {
919        let (state, conn, identifier, dcid, scid) = {
920            let mut state = self.state.borrow_mut();
921            let identifier = state.next_request_id();
922            let chan = &state.channels[self.index.0 as usize];
923            (chan.state.clone(), chan.conn, identifier, chan.peer_cid, chan.cid)
924        };
925
926        let mut tx = [0; 18];
927        match state {
928            ChannelState::PeerDisconnecting => {
929                assert_eq!(Some(self.handle), conn);
930                host.l2cap_signal(self.handle, identifier, &DisconnectionRes { dcid, scid }, &mut tx[..])
931                    .await?;
932            }
933            ChannelState::Disconnecting => {
934                assert_eq!(Some(self.handle), conn);
935                host.l2cap_signal(self.handle, identifier, &DisconnectionReq { dcid, scid }, &mut tx[..])
936                    .await?;
937            }
938            _ => {}
939        }
940        Ok(())
941    }
942
943    pub fn confirm(self) {
944        self.state.borrow_mut().channels[self.index.0 as usize].state = ChannelState::Disconnected;
945    }
946}
947
948fn encode(data: &[u8], packet: &mut [u8], peer_cid: u16, header: Option<u16>) -> Result<usize, Error> {
949    let mut w = WriteCursor::new(packet);
950    if header.is_some() {
951        w.write(2 + data.len() as u16)?;
952    } else {
953        w.write(data.len() as u16)?;
954    }
955    w.write(peer_cid)?;
956
957    if let Some(len) = header {
958        w.write(len)?;
959    }
960
961    w.append(data)?;
962    Ok(w.len())
963}
964
/// Storage slot for one L2CAP channel managed by the channel manager.
///
/// `ChannelStorage::close` resets the connection-specific fields. It does not touch
/// `refcount`, `credit_waker`, `inbound`, the reassembly buffer or the metrics.
pub struct ChannelStorage<P> {
    /// Current lifecycle state of the channel.
    state: ChannelState,
    /// Connection the channel belongs to; `None` when the slot is unused.
    conn: Option<ConnHandle>,
    /// Local channel identifier (used as `cid` in credit indications and `scid` in
    /// disconnection signals).
    cid: u16,
    /// PSM of the channel.
    psm: u16,
    /// Maximum PDU payload size (MPS).
    mps: u16,
    /// Maximum transmission unit (MTU).
    mtu: u16,
    /// Bookkeeping for credits we grant to the peer.
    flow_control: CreditFlowControl,
    /// Number of live references; releasing the last one on a `Connected` channel moves it to
    /// `Disconnecting` (see `dec_ref`).
    refcount: u8,

    /// Peer's channel identifier, written as the channel id of outgoing PDUs by `encode`.
    peer_cid: u16,
    /// Credits the peer has given us, i.e. how many PDUs we may still send.
    peer_credits: u16,
    /// Waker for senders waiting on peer credits (woken e.g. when a `CreditGrant` returns
    /// unused credits).
    credit_waker: WakerRegistration,

    /// Queue of inbound PDUs for this channel.
    inbound: PacketChannel<P, { config::L2CAP_RX_QUEUE_SIZE }>,
    #[cfg(not(feature = "l2cap-sdu-reassembly-optimization"))]
    reassembly: PacketReassembly<P>,

    #[cfg(feature = "channel-metrics")]
    metrics: Metrics,
}
986
/// Metrics for this channel
///
/// Per-channel counters, available with the `channel-metrics` feature. All counters wrap on
/// overflow.
#[cfg(feature = "channel-metrics")]
#[derive(Debug)]
pub struct Metrics {
    /// Number of sent l2cap packets.
    pub num_sent: usize,
    /// Number of received l2cap packets.
    pub num_received: usize,
    /// Number of l2cap packets blocked from sending.
    pub blocked_send: usize,
    /// Number of l2cap packets blocked from receiving.
    pub blocked_receive: usize,
}
1000
#[cfg(feature = "channel-metrics")]
impl Metrics {
    /// Creates a metrics record with every counter at zero.
    pub(crate) const fn new() -> Self {
        Self {
            blocked_receive: 0,
            blocked_send: 0,
            num_received: 0,
            num_sent: 0,
        }
    }

    /// Adds `num` to the sent-packet counter (wrapping).
    pub(crate) fn sent(&mut self, num: usize) {
        self.num_sent = usize::wrapping_add(self.num_sent, num);
    }

    /// Adds `num` to the received-packet counter (wrapping).
    pub(crate) fn received(&mut self, num: usize) {
        self.num_received = usize::wrapping_add(self.num_received, num);
    }

    /// Counts one send that was blocked (wrapping).
    pub(crate) fn blocked_send(&mut self) {
        self.blocked_send = usize::wrapping_add(self.blocked_send, 1);
    }

    /// Counts one receive that was blocked (wrapping).
    pub(crate) fn blocked_receive(&mut self) {
        self.blocked_receive = usize::wrapping_add(self.blocked_receive, 1);
    }

    /// Resets every counter to zero.
    pub(crate) fn reset(&mut self) {
        let zeroed = Self::new();
        *self = zeroed;
    }
}
1031
#[cfg(feature = "channel-metrics")]
#[cfg(feature = "defmt")]
/// defmt formatting for `Metrics`: one line listing all four counters.
impl defmt::Format for Metrics {
    fn format(&self, f: defmt::Formatter<'_>) {
        defmt::write!(
            f,
            "sent = {}, recvd = {}, blocked send = {}, blocked receive = {}",
            self.num_sent,
            self.num_received,
            self.blocked_send,
            self.blocked_receive,
        );
    }
}
1046
1047impl<P> core::fmt::Debug for ChannelStorage<P> {
1048    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1049        let mut d = f.debug_struct("ChannelStorage");
1050        let d = d
1051            .field("state", &self.state)
1052            .field("conn", &self.conn)
1053            .field("cid", &self.cid)
1054            .field("peer_cid", &self.peer_cid)
1055            .field("mps", &self.mps)
1056            .field("mtu", &self.mtu)
1057            .field("peer_credits", &self.peer_credits)
1058            .field("available", &self.flow_control.available())
1059            .field("refcount", &self.refcount);
1060        #[cfg(feature = "channel-metrics")]
1061        let d = d.field("metrics", &self.metrics);
1062        d.finish()
1063    }
1064}
1065
#[cfg(feature = "defmt")]
/// defmt formatting for a channel slot.
///
/// "cred out" is `peer_credits` (credits we may spend); "cred in" is the flow-control
/// `available()` count. Metrics are appended when the `channel-metrics` feature is enabled.
impl<P> defmt::Format for ChannelStorage<P> {
    fn format(&self, f: defmt::Formatter<'_>) {
        defmt::write!(
            f,
            "state = {}, c = {}, cid = {}, peer = {}, mps = {}, mtu = {}, cred out {}, cred in = {}, ref = {}",
            self.state,
            self.conn,
            self.cid,
            self.peer_cid,
            self.mps,
            self.mtu,
            self.peer_credits,
            self.flow_control.available(),
            self.refcount,
        );
        #[cfg(feature = "channel-metrics")]
        defmt::write!(f, ", {}", self.metrics);
    }
}
1086
1087impl<P> ChannelStorage<P> {
1088    pub(crate) const fn new() -> ChannelStorage<P> {
1089        ChannelStorage {
1090            state: ChannelState::Disconnected,
1091            conn: None,
1092            cid: 0,
1093            mps: 0,
1094            mtu: 0,
1095            psm: 0,
1096
1097            flow_control: CreditFlowControl::new(CreditFlowPolicy::Every(1), 0),
1098            peer_cid: 0,
1099            peer_credits: 0,
1100            credit_waker: WakerRegistration::new(),
1101            refcount: 0,
1102            inbound: PacketChannel::new(),
1103            #[cfg(not(feature = "l2cap-sdu-reassembly-optimization"))]
1104            reassembly: PacketReassembly::new(),
1105            #[cfg(feature = "channel-metrics")]
1106            metrics: Metrics::new(),
1107        }
1108    }
1109
1110    fn close(&mut self) {
1111        self.state = ChannelState::Disconnected;
1112        self.cid = 0;
1113        self.conn = None;
1114        self.mps = 0;
1115        self.mtu = 0;
1116        self.psm = 0;
1117        self.peer_cid = 0;
1118        self.flow_control = CreditFlowControl::new(CreditFlowPolicy::Every(1), 0);
1119        self.peer_credits = 0;
1120    }
1121}
1122
/// Lifecycle state of a channel slot.
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ChannelState {
    /// Slot is unused or the channel has been closed.
    Disconnected,
    /// A locally initiated connection is in progress.
    /// NOTE(review): the `u8` is presumably the signalling request identifier; confirm.
    Connecting(u8),
    /// A peer-initiated connection is in progress (same `u8` caveat as `Connecting`).
    PeerConnecting(u8),
    /// Channel is open.
    Connected,
    /// The peer asked to disconnect; a `DisconnectionRes` still has to be sent.
    PeerDisconnecting,
    /// We are disconnecting; a `DisconnectionReq` still has to be sent.
    Disconnecting,
}
1133
/// Control how credits are issued by the receiving end.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum CreditFlowPolicy {
    /// Issue credits for every N messages received
    Every(u16),
    /// Issue credits when below a threshold
    MinThreshold(u16),
}

impl Default for CreditFlowPolicy {
    /// Grant credits back after every received packet (`Every(1)`).
    fn default() -> Self {
        Self::Every(1)
    }
}

/// Receive-side credit bookkeeping for one channel.
///
/// - `credits`: our view of how many credits the peer currently holds from us.
/// - `received`: packets consumed since the last grant, i.e. credits that may be handed back.
/// - `policy`: decides when handing those credits back is worthwhile.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub(crate) struct CreditFlowControl {
    policy: CreditFlowPolicy,
    credits: u16,
    received: u16,
}

impl CreditFlowControl {
    /// Creates bookkeeping with `initial_credits` outstanding and nothing received yet.
    const fn new(policy: CreditFlowPolicy, initial_credits: u16) -> Self {
        Self {
            policy,
            credits: initial_credits,
            received: 0,
        }
    }

    /// Credits the peer holds from us, as tracked here.
    fn available(&self) -> u16 {
        self.credits
    }

    /// Records `n` received packets: each one uses up a credit and becomes grantable again.
    fn confirm_received(&mut self, n: u16) {
        self.credits = self.credits.saturating_sub(n);
        self.received = self.received.saturating_add(n);
    }

    /// Confirm that we've granted `amount` credits back to the peer.
    fn confirm_granted(&mut self, amount: u16) {
        self.received = self.received.saturating_sub(amount);
        self.credits = self.credits.saturating_add(amount);
    }

    /// Returns how many credits to grant now, or `None` if the policy says to wait.
    ///
    /// Never returns `Some(0)`. With nothing received there is nothing to hand back, and a
    /// credit indication must carry a non-zero credit count. Without this check, a zero-credit
    /// grant could occur when:
    /// - under `MinThreshold`, credits are below the threshold but nothing has arrived; or
    /// - the policy is `Every(0)`.
    fn process(&mut self) -> Option<u16> {
        if self.received == 0 {
            return None;
        }
        let should_grant = match self.policy {
            CreditFlowPolicy::Every(count) => self.received >= count,
            CreditFlowPolicy::MinThreshold(threshold) => self.credits < threshold,
        };
        should_grant.then_some(self.received)
    }
}
1201
/// A reservation of peer credits for sending on one channel.
///
/// Created by `poll_request_to_send` after the credits have been deducted from the channel.
/// Credits not consumed via `confirm` are returned to the channel when the grant is dropped,
/// provided the channel is still `Connected`.
pub struct CreditGrant<'reference, 'state, P> {
    // Shared channel-manager state the credits were taken from.
    state: &'reference RefCell<State<'state, P>>,
    // Channel the credits belong to.
    index: ChannelIndex,
    // Credits reserved but not yet consumed.
    credits: u16,
}
1207
1208impl<'reference, 'state, P> CreditGrant<'reference, 'state, P> {
1209    fn new(state: &'reference RefCell<State<'state, P>>, index: ChannelIndex, credits: u16) -> Self {
1210        Self { state, index, credits }
1211    }
1212
1213    pub(crate) fn confirm(&mut self, sent: u16) {
1214        self.credits = self.credits.saturating_sub(sent);
1215    }
1216
1217    pub(crate) fn remaining(&self) -> u16 {
1218        self.credits
1219    }
1220
1221    fn done(&mut self) {
1222        self.credits = 0;
1223    }
1224}
1225
1226impl<P> Drop for CreditGrant<'_, '_, P> {
1227    fn drop(&mut self) {
1228        if self.credits > 0 {
1229            let mut state = self.state.borrow_mut();
1230            let chan = &mut state.channels[self.index.0 as usize];
1231            if chan.state == ChannelState::Connected {
1232                chan.peer_credits += self.credits;
1233                chan.credit_waker.wake();
1234            }
1235            // make it an assert?
1236            //        warn!("[l2cap][credit grant drop] channel {} not found", self.index);
1237        }
1238    }
1239}
1240
#[cfg(test)]
mod tests {
    extern crate std;

    use bt_hci::param::{AddrKind, BdAddr, LeConnRole, Status};

    use super::*;
    use crate::mock_controller::MockController;
    use crate::prelude::DefaultPacketPool;
    use crate::HostResources;

    /// A channel that is still connecting reports `Disconnected` once its connection is gone.
    #[test]
    fn channel_refcount() {
        let mut resources: HostResources<DefaultPacketPool, 2, 2> = HostResources::new();
        let controller = MockController::new();

        let builder = crate::new(controller, &mut resources);
        let host = builder.host;

        let handle = ConnHandle::new(33);
        host.connections
            .connect(handle, AddrKind::PUBLIC, BdAddr::new([0; 6]), LeConnRole::Central)
            .unwrap();

        // Allocate a slot that is still waiting for the connection to complete.
        let index = host
            .channels
            .alloc(handle, |storage| storage.state = ChannelState::Connecting(42))
            .unwrap();

        let before = host.channels.poll_created(handle, index, &host, None);
        assert!(before.is_pending());

        // Tear down the underlying connection.
        host.connections.disconnected(handle, Status::UNSPECIFIED).unwrap();
        host.channels.disconnected(handle).unwrap();

        let after = host.channels.poll_created(handle, index, &host, None);
        assert!(matches!(
            after,
            Poll::Ready(Err(BleHostError::BleHost(Error::Disconnected)))
        ));
    }
}