trouble_host/
attribute_server.rs

1use core::cell::RefCell;
2use core::marker::PhantomData;
3
4use embassy_sync::blocking_mutex::raw::RawMutex;
5use embassy_sync::blocking_mutex::Mutex;
6
7use crate::att::{self, AttClient, AttCmd, AttErrorCode, AttReq};
8use crate::attribute::{Attribute, AttributeData, AttributeTable, CCCD};
9use crate::cursor::WriteCursor;
10use crate::prelude::Connection;
11use crate::types::uuid::Uuid;
12use crate::{codec, Error, Identity, PacketPool};
13
/// Bookkeeping for the peer occupying one connection slot of [`CccdTables`].
#[derive(Default)]
struct Client {
    /// Identity of the peer assigned to this slot. `CccdTables::connect` treats a
    /// value equal to `Identity::default()` as an unused slot.
    identity: Identity,
    /// Set by `CccdTables::connect` and cleared by `CccdTables::disconnect`; a slot
    /// whose flag is cleared may be handed to a new peer once every slot is taken.
    is_connected: bool,
}
19
20impl Client {
21    fn set_identity(&mut self, identity: Identity) {
22        self.identity = identity;
23    }
24}
25
/// A table of CCCD values.
///
/// Each entry pairs an attribute handle with the [`CCCD`] value stored for it.
/// An entry whose handle is `0` is treated as unused when handles are registered.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct CccdTable<const ENTRIES: usize> {
    /// `(handle, cccd)` pairs; the array length is fixed by `ENTRIES`.
    inner: [(u16, CCCD); ENTRIES],
}
32
33impl<const ENTRIES: usize> Default for CccdTable<ENTRIES> {
34    fn default() -> Self {
35        Self {
36            inner: [(0, CCCD(0)); ENTRIES],
37        }
38    }
39}
40
41impl<const ENTRIES: usize> CccdTable<ENTRIES> {
42    /// Create a new CCCD table from an array of (handle, cccd) pairs.
43    pub fn new(cccd_values: [(u16, CCCD); ENTRIES]) -> Self {
44        Self { inner: cccd_values }
45    }
46
47    /// Get the inner array of (handle, cccd) pairs.
48    pub fn inner(&self) -> &[(u16, CCCD); ENTRIES] {
49        &self.inner
50    }
51
52    fn add_handle(&mut self, cccd_handle: u16) {
53        for (handle, _) in self.inner.iter_mut() {
54            if *handle == 0 {
55                *handle = cccd_handle;
56                break;
57            }
58        }
59    }
60
61    fn disable_all(&mut self) {
62        for (_, value) in self.inner.iter_mut() {
63            value.disable();
64        }
65    }
66
67    fn get_raw(&self, cccd_handle: u16) -> Option<[u8; 2]> {
68        for (handle, value) in self.inner.iter() {
69            if *handle == cccd_handle {
70                return Some(value.raw().to_le_bytes());
71            }
72        }
73        None
74    }
75
76    fn set_notify(&mut self, cccd_handle: u16, is_enabled: bool) {
77        for (handle, value) in self.inner.iter_mut() {
78            if *handle == cccd_handle {
79                trace!("[cccd] set_notify({}) = {}", cccd_handle, is_enabled);
80                value.set_notify(is_enabled);
81                break;
82            }
83        }
84    }
85
86    fn should_notify(&self, cccd_handle: u16) -> bool {
87        for (handle, value) in self.inner.iter() {
88            if *handle == cccd_handle {
89                return value.should_notify();
90            }
91        }
92        false
93    }
94
95    fn set_indicate(&mut self, cccd_handle: u16, is_enabled: bool) {
96        for (handle, value) in self.inner.iter_mut() {
97            if *handle == cccd_handle {
98                trace!("\n\n\n[cccd] set_indicate({}) = {}", cccd_handle, is_enabled);
99                value.set_indicate(is_enabled);
100                break;
101            }
102        }
103    }
104    fn should_indicate(&self, cccd_handle: u16) -> bool {
105        for (handle, value) in self.inner.iter() {
106            if *handle == cccd_handle {
107                return value.should_indicate();
108            }
109        }
110        false
111    }
112}
113
/// A table of CCCD values for each connected client.
///
/// Holds `CONN_MAX` slots, each pairing a [`Client`] with its own [`CccdTable`] of
/// `CCCD_MAX` entries. The slot array sits in a `RefCell` behind a blocking mutex,
/// so all methods take `&self`.
struct CccdTables<M: RawMutex, const CCCD_MAX: usize, const CONN_MAX: usize> {
    /// One `(client, table)` slot per potential connection.
    state: Mutex<M, RefCell<[(Client, CccdTable<CCCD_MAX>); CONN_MAX]>>,
}
118
119impl<M: RawMutex, const CCCD_MAX: usize, const CONN_MAX: usize> CccdTables<M, CCCD_MAX, CONN_MAX> {
120    fn new<const ATT_MAX: usize>(att_table: &AttributeTable<'_, M, ATT_MAX>) -> Self {
121        let mut values: [(Client, CccdTable<CCCD_MAX>); CONN_MAX] =
122            core::array::from_fn(|_| (Client::default(), CccdTable::default()));
123        let mut base_cccd_table = CccdTable::default();
124        att_table.iterate(|mut at| {
125            while let Some(att) = at.next() {
126                if let AttributeData::Cccd { .. } = att.data {
127                    base_cccd_table.add_handle(att.handle);
128                }
129            }
130        });
131        // add the base CCCD table for each potential connected client
132        for (_, table) in values.iter_mut() {
133            *table = base_cccd_table.clone();
134        }
135        Self {
136            state: Mutex::new(RefCell::new(values)),
137        }
138    }
139
140    fn connect(&self, peer_identity: &Identity) -> Result<(), Error> {
141        self.state.lock(|n| {
142            trace!("[server] searching for peer {:?}", peer_identity);
143            let mut n = n.borrow_mut();
144            let empty_slot = Identity::default();
145            for (client, table) in n.iter_mut() {
146                if client.identity.match_identity(peer_identity) {
147                    // trace!("[server] found! table = {:?}", *table);
148                    client.is_connected = true;
149                    return Ok(());
150                } else if client.identity == empty_slot {
151                    //  trace!("[server] empty slot: connecting");
152                    client.is_connected = true;
153                    client.set_identity(*peer_identity);
154                    return Ok(());
155                }
156            }
157            trace!("[server] all slots full...");
158            // if we got here all slots are full; replace the first disconnected client
159            for (client, table) in n.iter_mut() {
160                if !client.is_connected {
161                    trace!("[server] booting disconnected peer {:?}", client.identity);
162                    client.is_connected = true;
163                    client.set_identity(*peer_identity);
164                    // erase the previous client's config
165                    table.disable_all();
166                    return Ok(());
167                }
168            }
169            // Should be unreachable if the max connections (CONN_MAX) matches that defined
170            // in HostResources...
171            warn!("[server] unable to obtain CCCD slot");
172            Err(Error::ConnectionLimitReached)
173        })
174    }
175
176    fn disconnect(&self, peer_identity: &Identity) {
177        self.state.lock(|n| {
178            let mut n = n.borrow_mut();
179            for (client, _) in n.iter_mut() {
180                if client.identity.match_identity(peer_identity) {
181                    client.is_connected = false;
182                    break;
183                }
184            }
185        })
186    }
187
188    fn get_value(&self, peer_identity: &Identity, cccd_handle: u16) -> Option<[u8; 2]> {
189        self.state.lock(|n| {
190            let n = n.borrow();
191            for (client, table) in n.iter() {
192                if client.identity.match_identity(peer_identity) {
193                    return table.get_raw(cccd_handle);
194                }
195            }
196            None
197        })
198    }
199
200    fn set_notify(&self, peer_identity: &Identity, cccd_handle: u16, is_enabled: bool) {
201        self.state.lock(|n| {
202            let mut n = n.borrow_mut();
203            for (client, table) in n.iter_mut() {
204                if client.identity.match_identity(peer_identity) {
205                    table.set_notify(cccd_handle, is_enabled);
206                    break;
207                }
208            }
209        })
210    }
211
212    fn should_notify(&self, peer_identity: &Identity, cccd_handle: u16) -> bool {
213        self.state.lock(|n| {
214            let n = n.borrow();
215            for (client, table) in n.iter() {
216                if client.identity.match_identity(peer_identity) {
217                    return table.should_notify(cccd_handle);
218                }
219            }
220            false
221        })
222    }
223
224    fn set_indicate(&self, peer_identity: &Identity, cccd_handle: u16, is_enabled: bool) {
225        self.state.lock(|n| {
226            let mut n = n.borrow_mut();
227            for (client, table) in n.iter_mut() {
228                if client.identity.match_identity(peer_identity) {
229                    table.set_indicate(cccd_handle, is_enabled);
230                    break;
231                }
232            }
233        })
234    }
235
236    fn should_indicate(&self, peer_identity: &Identity, cccd_handle: u16) -> bool {
237        self.state.lock(|n| {
238            let n = n.borrow();
239            for (client, table) in n.iter() {
240                if client.identity.match_identity(peer_identity) {
241                    return table.should_indicate(cccd_handle);
242                }
243            }
244            false
245        })
246    }
247
248    fn get_cccd_table(&self, peer_identity: &Identity) -> Option<CccdTable<CCCD_MAX>> {
249        self.state.lock(|n| {
250            let n = n.borrow();
251            for (client, table) in n.iter() {
252                if client.identity.match_identity(peer_identity) {
253                    return Some(table.clone());
254                }
255            }
256            None
257        })
258    }
259
260    fn set_cccd_table(&self, peer_identity: &Identity, table: CccdTable<CCCD_MAX>) {
261        self.state.lock(|n| {
262            let mut n = n.borrow_mut();
263            for (client, t) in n.iter_mut() {
264                if client.identity.match_identity(peer_identity) {
265                    trace!("Setting cccd table {:?} for {:?}", table, peer_identity);
266                    *t = table;
267                    break;
268                }
269            }
270        })
271    }
272
273    fn update_identity(&self, identity: Identity) -> Result<(), Error> {
274        self.state.lock(|n| {
275            let mut n = n.borrow_mut();
276            for (client, _) in n.iter_mut() {
277                if identity.match_identity(&client.identity) {
278                    client.set_identity(identity);
279                    return Ok(());
280                }
281            }
282            Err(Error::NotFound)
283        })
284    }
285}
286
/// A GATT server capable of processing the GATT protocol using the provided table of attributes.
///
/// `ATT_MAX` is forwarded to the [`AttributeTable`]; `CCCD_MAX` and `CONN_MAX` size the
/// per-client CCCD storage (entries per client and number of client slots).
pub struct AttributeServer<
    'values,
    M: RawMutex,
    P: PacketPool,
    const ATT_MAX: usize,
    const CCCD_MAX: usize,
    const CONN_MAX: usize,
> {
    /// The attributes served to clients.
    att_table: AttributeTable<'values, M, ATT_MAX>,
    /// CCCD state kept separately for each client slot.
    cccd_tables: CccdTables<M, CCCD_MAX, CONN_MAX>,
    /// Ties the server to the packet pool type `P` without storing a value of it.
    _p: PhantomData<P>,
}
300
/// Crate-private home of the real server interface: the public
/// `DynamicAttributeServer` trait requires this one as a supertrait, so its
/// methods are not nameable outside the crate.
pub(crate) mod sealed {
    use super::*;

    /// Operations the host needs from an attribute server, independent of the
    /// server's const generic parameters.
    pub trait DynamicAttributeServer<P: PacketPool> {
        /// Registers the peer of `connection` with the server.
        fn connect(&self, connection: &Connection<'_, P>) -> Result<(), Error>;
        /// Marks the peer of `connection` as disconnected.
        fn disconnect(&self, connection: &Connection<'_, P>);
        /// Processes one client ATT packet, using `rx` as the response buffer.
        ///
        /// NOTE(review): `Ok(Some(n))` presumably means an `n`-byte response was
        /// written to `rx` and `Ok(None)` that no response is due — confirm
        /// against `AttributeServer::process`.
        fn process(
            &self,
            connection: &Connection<'_, P>,
            packet: &AttClient,
            rx: &mut [u8],
        ) -> Result<Option<usize>, Error>;
        /// Whether the peer of `connection` enabled notifications on `cccd_handle`.
        fn should_notify(&self, connection: &Connection<'_, P>, cccd_handle: u16) -> bool;
        /// Whether the peer of `connection` enabled indications on `cccd_handle`.
        fn should_indicate(&self, connection: &Connection<'_, P>, cccd_handle: u16) -> bool;
        /// Stores `input` as the raw value of the attribute with handle `characteristic`.
        fn set(&self, characteristic: u16, input: &[u8]) -> Result<(), Error>;
        /// Updates the stored identity of a known peer to `identity`.
        fn update_identity(&self, identity: Identity) -> Result<(), Error>;
    }
}
319
/// Type erased attribute server
///
/// The trait has no methods of its own; all functionality comes from the
/// crate-private supertrait `sealed::DynamicAttributeServer`.
pub trait DynamicAttributeServer<P: PacketPool>: sealed::DynamicAttributeServer<P> {}
322
// Marker impl: every `AttributeServer` is usable through the type-erased trait.
// The body is empty because the trait declares no methods of its own.
impl<M: RawMutex, P: PacketPool, const ATT_MAX: usize, const CCCD_MAX: usize, const CONN_MAX: usize>
    DynamicAttributeServer<P> for AttributeServer<'_, M, P, ATT_MAX, CCCD_MAX, CONN_MAX>
{
}
327impl<M: RawMutex, P: PacketPool, const ATT_MAX: usize, const CCCD_MAX: usize, const CONN_MAX: usize>
328    sealed::DynamicAttributeServer<P> for AttributeServer<'_, M, P, ATT_MAX, CCCD_MAX, CONN_MAX>
329{
330    fn connect(&self, connection: &Connection<'_, P>) -> Result<(), Error> {
331        AttributeServer::connect(self, connection)
332    }
333
334    fn disconnect(&self, connection: &Connection<'_, P>) {
335        self.cccd_tables.disconnect(&connection.peer_identity());
336    }
337
338    fn process(
339        &self,
340        connection: &Connection<'_, P>,
341        packet: &AttClient,
342        rx: &mut [u8],
343    ) -> Result<Option<usize>, Error> {
344        let res = AttributeServer::process(self, connection, packet, rx)?;
345        Ok(res)
346    }
347
348    fn should_notify(&self, connection: &Connection<'_, P>, cccd_handle: u16) -> bool {
349        AttributeServer::should_notify(self, connection, cccd_handle)
350    }
351    fn should_indicate(&self, connection: &Connection<'_, P>, cccd_handle: u16) -> bool {
352        AttributeServer::should_indicate(self, connection, cccd_handle)
353    }
354
355    fn set(&self, characteristic: u16, input: &[u8]) -> Result<(), Error> {
356        self.att_table.set_raw(characteristic, input)
357    }
358
359    fn update_identity(&self, identity: Identity) -> Result<(), Error> {
360        self.cccd_tables.update_identity(identity)
361    }
362}
363
364impl<'values, M: RawMutex, P: PacketPool, const ATT_MAX: usize, const CCCD_MAX: usize, const CONN_MAX: usize>
365    AttributeServer<'values, M, P, ATT_MAX, CCCD_MAX, CONN_MAX>
366{
367    /// Create a new instance of the AttributeServer
368    pub fn new(
369        att_table: AttributeTable<'values, M, ATT_MAX>,
370    ) -> AttributeServer<'values, M, P, ATT_MAX, CCCD_MAX, CONN_MAX> {
371        let cccd_tables = CccdTables::new(&att_table);
372        AttributeServer {
373            att_table,
374            cccd_tables,
375            _p: PhantomData,
376        }
377    }
378
379    pub(crate) fn connect(&self, connection: &Connection<'_, P>) -> Result<(), Error> {
380        self.cccd_tables.connect(&connection.peer_identity())
381    }
382
383    pub(crate) fn should_notify(&self, connection: &Connection<'_, P>, cccd_handle: u16) -> bool {
384        self.cccd_tables.should_notify(&connection.peer_identity(), cccd_handle)
385    }
386
387    pub(crate) fn should_indicate(&self, connection: &Connection<'_, P>, cccd_handle: u16) -> bool {
388        self.cccd_tables
389            .should_indicate(&connection.peer_identity(), cccd_handle)
390    }
391
392    fn read_attribute_data(
393        &self,
394        connection: &Connection<'_, P>,
395        offset: usize,
396        att: &mut Attribute<'values>,
397        data: &mut [u8],
398    ) -> Result<usize, AttErrorCode> {
399        if let AttributeData::Cccd { .. } = att.data {
400            // CCCD values for each connected client are held in the CCCD tables:
401            // the value is written back into att.data so att.read() has the final
402            // say when parsing at the requested offset.
403            if let Some(value) = self.cccd_tables.get_value(&connection.peer_identity(), att.handle) {
404                let _ = att.write(0, value.as_slice());
405            }
406        }
407        att.read(offset, data)
408    }
409
410    fn write_attribute_data(
411        &self,
412        connection: &Connection<'_, P>,
413        offset: usize,
414        att: &mut Attribute<'values>,
415        data: &[u8],
416    ) -> Result<(), AttErrorCode> {
417        let err = att.write(offset, data);
418        if err.is_ok() {
419            if let AttributeData::Cccd {
420                notifications,
421                indications,
422            } = att.data
423            {
424                self.cccd_tables
425                    .set_notify(&connection.peer_identity(), att.handle, notifications);
426                self.cccd_tables
427                    .set_indicate(&connection.peer_identity(), att.handle, indications);
428            }
429        }
430        err
431    }
432
433    fn handle_read_by_type_req(
434        &self,
435        connection: &Connection<'_, P>,
436        buf: &mut [u8],
437        start: u16,
438        end: u16,
439        attribute_type: &Uuid,
440    ) -> Result<usize, codec::Error> {
441        let mut handle = start;
442        let mut data = WriteCursor::new(buf);
443
444        let (mut header, mut body) = data.split(2)?;
445        let err = self.att_table.iterate(|mut it| {
446            let mut ret = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
447            while let Some(att) = it.next() {
448                // trace!("[read_by_type] Check attribute {:?} {}", att.uuid, att.handle);
449                if &att.uuid == attribute_type && att.handle >= start && att.handle <= end {
450                    body.write(att.handle)?;
451                    handle = att.handle;
452
453                    let new_ret = self.read_attribute_data(connection, 0, att, body.write_buf());
454                    match (new_ret, ret) {
455                        (Ok(first_length), Err(_)) => {
456                            // First successful read, store this length, all subsequent ones must match it.
457                            // debug!("[read_by_type] found first entry {:x?}, handle {}", att.uuid, handle);
458                            ret = new_ret;
459                            body.commit(first_length)?;
460                        }
461                        (Ok(new_length), Ok(old_length)) => {
462                            // Any matching attribute after the first, verify the lengths are identical, if not break.
463                            if new_length == old_length {
464                                // debug!("[read_by_type] found equal length {}, handle {}", new_length, handle);
465                                body.commit(new_length)?;
466                            } else {
467                                // We encountered a different length,  unwind the handle.
468                                // debug!("[read_by_type] different length: {}, old: {}", new_length, old_length);
469                                body.truncate(body.len() - 2);
470                                // And then break to ensure we respond with the previously found entries.
471                                break;
472                            }
473                        }
474                        (Err(error_code), Ok(_old_length)) => {
475                            // New read failed, but we had a previous value, return what we had thus far, truncate to
476                            // remove the previously written handle.
477                            body.truncate(body.len() - 2);
478                            // We do silently drop the error here.
479                            // debug!("[read_by_group] new error: {:?}, returning result thus far", error_code);
480                            break;
481                        }
482                        (Err(_), Err(_)) => {
483                            // Error on the first possible read, return this error.
484                            ret = new_ret;
485                            break;
486                        }
487                    }
488                    // If we get here, we always have had a successful read, and we can check that we still have space
489                    // left in the buffer to write the next entry if it exists.
490                    if let Ok(expected_length) = ret {
491                        if body.available() < expected_length + 2 {
492                            break;
493                        }
494                    }
495                }
496            }
497            ret
498        });
499
500        match err {
501            Ok(len) => {
502                header.write(att::ATT_READ_BY_TYPE_RSP)?;
503                header.write(2 + len as u8)?;
504                Ok(header.len() + body.len())
505            }
506            Err(e) => Ok(Self::error_response(data, att::ATT_READ_BY_TYPE_REQ, handle, e)?),
507        }
508    }
509
510    fn handle_read_by_group_type_req(
511        &self,
512        connection: &Connection<'_, P>,
513        buf: &mut [u8],
514        start: u16,
515        end: u16,
516        group_type: &Uuid,
517    ) -> Result<usize, codec::Error> {
518        let mut handle = start;
519        let mut data = WriteCursor::new(buf);
520        let (mut header, mut body) = data.split(2)?;
521        // Multiple entries can be returned in the response as long as they are of equal length.
522        let err = self.att_table.iterate(|mut it| {
523            // ret either holds the length of the attribute, or the error code encountered.
524            let mut ret: Result<usize, AttErrorCode> = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
525            while let Some(att) = it.next() {
526                // trace!("[read_by_group] Check attribute {:x?} {}", att.uuid, att.handle);
527                if &att.uuid == group_type && att.handle >= start && att.handle <= end {
528                    // debug!("[read_by_group] found! {:x?} handle: {}", att.uuid, att.handle);
529                    handle = att.handle;
530
531                    body.write(att.handle)?;
532                    body.write(att.last_handle_in_group)?;
533                    let new_ret = self.read_attribute_data(connection, 0, att, body.write_buf());
534                    match (new_ret, ret) {
535                        (Ok(first_length), Err(_)) => {
536                            // First successful read, store this length, all subsequent ones must match it.
537                            // debug!("[read_by_group] found first entry {:x?}, handle {}", att.uuid, handle);
538                            ret = new_ret;
539                            body.commit(first_length)?;
540                        }
541                        (Ok(new_length), Ok(old_length)) => {
542                            // Any matching attribute after the first, verify the lengths are identical, if not break.
543                            if new_length == old_length {
544                                // debug!("[read_by_group] found equal length {}, handle {}", new_length, handle);
545                                body.commit(new_length)?;
546                            } else {
547                                // We encountered a different length,  unwind the handle and last_handle written.
548                                // debug!("[read_by_group] different length: {}, old: {}", new_length, old_length);
549                                body.truncate(body.len() - 4);
550                                // And then break to ensure we respond with the previously found entries.
551                                break;
552                            }
553                        }
554                        (Err(error_code), Ok(_old_length)) => {
555                            // New read failed, but we had a previous value, return what we had thus far, truncate to
556                            // remove the previously written handle and last handle.
557                            body.truncate(body.len() - 4);
558                            // We do silently drop the error here.
559                            // debug!("[read_by_group] new error: {:?}, returning result thus far", error_code);
560                            break;
561                        }
562                        (Err(_), Err(_)) => {
563                            // Error on the first possible read, return this error.
564                            ret = new_ret;
565                            break;
566                        }
567                    }
568                    // If we get here, we always have had a successful read, and we can check that we still have space
569                    // left in the buffer to write the next entry if it exists.
570                    if let Ok(expected_length) = ret {
571                        if body.available() < expected_length + 4 {
572                            break;
573                        }
574                    }
575                }
576            }
577            ret
578        });
579
580        match err {
581            Ok(len) => {
582                header.write(att::ATT_READ_BY_GROUP_TYPE_RSP)?;
583                header.write(4 + len as u8)?;
584                Ok(header.len() + body.len())
585            }
586            Err(e) => Ok(Self::error_response(data, att::ATT_READ_BY_GROUP_TYPE_REQ, handle, e)?),
587        }
588    }
589
590    fn handle_read_req(
591        &self,
592        connection: &Connection<'_, P>,
593        buf: &mut [u8],
594        handle: u16,
595    ) -> Result<usize, codec::Error> {
596        let mut data = WriteCursor::new(buf);
597
598        data.write(att::ATT_READ_RSP)?;
599
600        let err = self.att_table.iterate(|mut it| {
601            let mut err = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
602            while let Some(att) = it.next() {
603                if att.handle == handle {
604                    err = self.read_attribute_data(connection, 0, att, data.write_buf());
605                    if let Ok(len) = err {
606                        data.commit(len)?;
607                    }
608                    break;
609                }
610            }
611            err
612        });
613
614        match err {
615            Ok(_) => Ok(data.len()),
616            Err(e) => Ok(Self::error_response(data, att::ATT_READ_REQ, handle, e)?),
617        }
618    }
619
620    fn handle_write_cmd(
621        &self,
622        connection: &Connection<'_, P>,
623        buf: &mut [u8],
624        handle: u16,
625        data: &[u8],
626    ) -> Result<usize, codec::Error> {
627        self.att_table.iterate(|mut it| {
628            while let Some(att) = it.next() {
629                if att.handle == handle {
630                    // Write commands can't respond with an error.
631                    let _ = self.write_attribute_data(connection, 0, att, data);
632                    break;
633                }
634            }
635        });
636        Ok(0)
637    }
638
639    fn handle_write_req(
640        &self,
641        connection: &Connection<'_, P>,
642        buf: &mut [u8],
643        handle: u16,
644        data: &[u8],
645    ) -> Result<usize, codec::Error> {
646        let err = self.att_table.iterate(|mut it| {
647            let mut err = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
648            while let Some(att) = it.next() {
649                if att.handle == handle {
650                    err = self.write_attribute_data(connection, 0, att, data);
651                    break;
652                }
653            }
654            err
655        });
656
657        let mut w = WriteCursor::new(buf);
658        match err {
659            Ok(()) => {
660                w.write(att::ATT_WRITE_RSP)?;
661                Ok(w.len())
662            }
663            Err(e) => Ok(Self::error_response(w, att::ATT_WRITE_REQ, handle, e)?),
664        }
665    }
666
667    fn handle_find_type_value(
668        &self,
669        buf: &mut [u8],
670        start: u16,
671        end: u16,
672        attr_type: u16,
673        attr_value: &[u8],
674    ) -> Result<usize, codec::Error> {
675        let mut w = WriteCursor::new(buf);
676        let attr_type = Uuid::new_short(attr_type);
677
678        w.write(att::ATT_FIND_BY_TYPE_VALUE_RSP)?;
679        self.att_table.iterate(|mut it| {
680            while let Some(att) = it.next() {
681                if att.handle >= start && att.handle <= end && att.uuid == attr_type {
682                    if let AttributeData::Service { uuid } = &att.data {
683                        if uuid.as_raw() == attr_value {
684                            if w.available() < 4 + uuid.as_raw().len() {
685                                break;
686                            }
687                            w.write(att.handle)?;
688                            w.write(att.last_handle_in_group)?;
689                        }
690                    }
691                }
692            }
693            Ok::<(), codec::Error>(())
694        })?;
695        if w.len() > 1 {
696            Ok(w.len())
697        } else {
698            Ok(Self::error_response(
699                w,
700                att::ATT_FIND_BY_TYPE_VALUE_REQ,
701                start,
702                AttErrorCode::ATTRIBUTE_NOT_FOUND,
703            )?)
704        }
705    }
706
707    fn handle_find_information(&self, buf: &mut [u8], start: u16, end: u16) -> Result<usize, codec::Error> {
708        let mut w = WriteCursor::new(buf);
709
710        let (mut header, mut body) = w.split(2)?;
711
712        header.write(att::ATT_FIND_INFORMATION_RSP)?;
713        let mut t = 0;
714
715        self.att_table.iterate(|mut it| {
716            while let Some(att) = it.next() {
717                if att.handle >= start && att.handle <= end {
718                    if t == 0 {
719                        t = att.uuid.get_type();
720                    } else if t != att.uuid.get_type() {
721                        break;
722                    }
723                    body.write(att.handle)?;
724                    body.append(att.uuid.as_raw())?;
725                }
726            }
727            Ok::<(), codec::Error>(())
728        })?;
729        header.write(t)?;
730
731        if body.len() > 2 {
732            Ok(header.len() + body.len())
733        } else {
734            Ok(Self::error_response(
735                w,
736                att::ATT_FIND_INFORMATION_REQ,
737                start,
738                AttErrorCode::ATTRIBUTE_NOT_FOUND,
739            )?)
740        }
741    }
742
743    fn error_response(
744        mut w: WriteCursor<'_>,
745        opcode: u8,
746        handle: u16,
747        code: AttErrorCode,
748    ) -> Result<usize, codec::Error> {
749        w.reset();
750        w.write(att::ATT_ERROR_RSP)?;
751        w.write(opcode)?;
752        w.write(handle)?;
753        w.write(code)?;
754        Ok(w.len())
755    }
756
757    fn handle_prepare_write(
758        &self,
759        connection: &Connection<'_, P>,
760        buf: &mut [u8],
761        handle: u16,
762        offset: u16,
763        value: &[u8],
764    ) -> Result<usize, codec::Error> {
765        let mut w = WriteCursor::new(buf);
766        w.write(att::ATT_PREPARE_WRITE_RSP)?;
767        w.write(handle)?;
768        w.write(offset)?;
769
770        let err = self.att_table.iterate(|mut it| {
771            let mut err = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
772            while let Some(att) = it.next() {
773                if att.handle == handle {
774                    err = self.write_attribute_data(connection, offset as usize, att, value);
775                    w.append(value)?;
776                    break;
777                }
778            }
779            err
780        });
781
782        match err {
783            Ok(()) => Ok(w.len()),
784            Err(e) => Ok(Self::error_response(w, att::ATT_PREPARE_WRITE_REQ, handle, e)?),
785        }
786    }
787
788    fn handle_execute_write(&self, buf: &mut [u8], _flags: u8) -> Result<usize, codec::Error> {
789        let mut w = WriteCursor::new(buf);
790        w.write(att::ATT_EXECUTE_WRITE_RSP)?;
791        Ok(w.len())
792    }
793
794    fn handle_read_blob(
795        &self,
796        connection: &Connection<'_, P>,
797        buf: &mut [u8],
798        handle: u16,
799        offset: u16,
800    ) -> Result<usize, codec::Error> {
801        let mut w = WriteCursor::new(buf);
802        w.write(att::ATT_READ_BLOB_RSP)?;
803
804        let err = self.att_table.iterate(|mut it| {
805            let mut err = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
806            while let Some(att) = it.next() {
807                if att.handle == handle {
808                    err = self.read_attribute_data(connection, offset as usize, att, w.write_buf());
809                    if let Ok(n) = err {
810                        w.commit(n)?;
811                    }
812                    break;
813                }
814            }
815            err
816        });
817
818        match err {
819            Ok(_) => Ok(w.len()),
820            Err(e) => Ok(Self::error_response(w, att::ATT_READ_BLOB_REQ, handle, e)?),
821        }
822    }
823
824    fn handle_read_multiple(&self, buf: &mut [u8], handles: &[u8]) -> Result<usize, codec::Error> {
825        let w = WriteCursor::new(buf);
826        Self::error_response(
827            w,
828            att::ATT_READ_MULTIPLE_REQ,
829            u16::from_le_bytes([handles[0], handles[1]]),
830            AttErrorCode::ATTRIBUTE_NOT_FOUND,
831        )
832    }
833
834    /// Process an event and produce a response if necessary
835    pub fn process(
836        &self,
837        connection: &Connection<'_, P>,
838        packet: &AttClient,
839        rx: &mut [u8],
840    ) -> Result<Option<usize>, codec::Error> {
841        let len = match packet {
842            AttClient::Request(AttReq::ReadByType {
843                start,
844                end,
845                attribute_type,
846            }) => self.handle_read_by_type_req(connection, rx, *start, *end, attribute_type)?,
847
848            AttClient::Request(AttReq::ReadByGroupType { start, end, group_type }) => {
849                self.handle_read_by_group_type_req(connection, rx, *start, *end, group_type)?
850            }
851            AttClient::Request(AttReq::FindInformation {
852                start_handle,
853                end_handle,
854            }) => self.handle_find_information(rx, *start_handle, *end_handle)?,
855
856            AttClient::Request(AttReq::Read { handle }) => self.handle_read_req(connection, rx, *handle)?,
857
858            AttClient::Command(AttCmd::Write { handle, data }) => {
859                self.handle_write_cmd(connection, rx, *handle, data)?;
860                0
861            }
862
863            AttClient::Request(AttReq::Write { handle, data }) => {
864                self.handle_write_req(connection, rx, *handle, data)?
865            }
866
867            AttClient::Request(AttReq::ExchangeMtu { mtu }) => 0, // Done outside,
868
869            AttClient::Request(AttReq::FindByTypeValue {
870                start_handle,
871                end_handle,
872                att_type,
873                att_value,
874            }) => self.handle_find_type_value(rx, *start_handle, *end_handle, *att_type, att_value)?,
875
876            AttClient::Request(AttReq::PrepareWrite { handle, offset, value }) => {
877                self.handle_prepare_write(connection, rx, *handle, *offset, value)?
878            }
879
880            AttClient::Request(AttReq::ExecuteWrite { flags }) => self.handle_execute_write(rx, *flags)?,
881
882            AttClient::Request(AttReq::ReadBlob { handle, offset }) => {
883                self.handle_read_blob(connection, rx, *handle, *offset)?
884            }
885
886            AttClient::Request(AttReq::ReadMultiple { handles }) => self.handle_read_multiple(rx, handles)?,
887
888            AttClient::Confirmation(_) => 0,
889        };
890        if len > 0 {
891            Ok(Some(len))
892        } else {
893            Ok(None)
894        }
895    }
896
897    /// Get a reference to the attribute table
898    pub fn table(&self) -> &AttributeTable<'values, M, ATT_MAX> {
899        &self.att_table
900    }
901
902    /// Get the CCCD table for a connection
903    pub fn get_cccd_table(&self, connection: &Connection<'_, P>) -> Option<CccdTable<CCCD_MAX>> {
904        self.cccd_tables.get_cccd_table(&connection.peer_identity())
905    }
906
907    /// Set the CCCD table for a connection
908    pub fn set_cccd_table(&self, connection: &Connection<'_, P>, table: CccdTable<CCCD_MAX>) {
909        self.cccd_tables.set_cccd_table(&connection.peer_identity(), table);
910    }
911}
912
#[cfg(test)]
mod tests {
    use core::task::Poll;

    use bt_hci::param::{AddrKind, BdAddr, ConnHandle, LeConnRole};
    use embassy_sync::blocking_mutex::raw::NoopRawMutex;

    use super::*;
    use crate::connection_manager::tests::{setup, ADDR_1};
    use crate::prelude::*;

    /// Regression test for primary service discovery across group boundaries.
    ///
    /// Three services are registered and then discovered one at a time with Read By Group
    /// Type requests, each request starting one past the group end handle reported by the
    /// previous response. All three must be found for every tested size of the middle service.
    #[test]
    fn test_attribute_server_last_handle_of_group() {
        // Background: a service had exactly 16 handles, so the last_handle_in_group field of the
        // ReadByGroupType response ended up 16 aligned (96 to be exact). The follow-up request then
        // started at 96 + 1, which was one handle beyond the start of the next service.
        //
        // Snippet from the original failure mode:
        // WARN  trouble_host::attribute_server] Looking for group: Uuid16([0, 28]) between 75 and 65535
        // DEBUG trouble_host::attribute_server] [read_by_group] found! Uuid16([0, 28]) 80
        // DEBUG trouble_host::attribute_server] last_handle_in_group: 96
        // DEBUG trouble_host::attribute_server] read_attribute_data: Ok(16)
        // TRACE trouble_host::host] [host] granted send packets = 1, len = 30
        // TRACE trouble_host::host] [host] sent acl packet len = 26
        // TRACE trouble_host::host] [host] inbound l2cap header channel = 4, fragment len = 7, total = 7
        // INFO  main_ble::ble_bas_peripheral] [gatt-attclient]: ReadByGroupType { start: 97, end: 65535, group_type: Uuid16([0, 40]) }
        // INFO  main_ble::ble_bas_peripheral] [gatt] other event
        // WARN  trouble_host::attribute_server] Looking for group: Uuid16([0, 28]) between 97 and 65535
        // WARN  trouble_host::attribute_server] [read_by_group] Dit not find attribute Uuid16([0, 28]) between 97  65535

        // The failing request was:
        // INFO  main_ble::ble_bas_peripheral] [gatt-attclient]: ReadByGroupType { start: 97, end: 65535, group_type: Uuid16([0, 40]) }
        // The trace prints "group_type: Uuid16([0, 40])" in decimal, i.e. group type 0x2800, which is the
        // primary service group.
        let primary_service_group_type = Uuid::new_short(0x2800);

        let _ = env_logger::try_init();
        const MAX_ATTRIBUTES: usize = 1024;
        const CONNECTIONS_MAX: usize = 3;
        const CCCD_MAX: usize = 1024;
        const L2CAP_CHANNELS_MAX: usize = 5;
        type FacadeDummyType = [u8; 0];

        // Rather than reproducing only the exact failure, sweep the size of the middle service so that it
        // crosses several multiples of 16.
        for interior_len in 0..=64u8 {
            debug!("Testing with interior handle count of {}", interior_len);

            // Fresh table for every iteration.
            let mut attributes: AttributeTable<'_, NoopRawMutex, { MAX_ATTRIBUTES }> = AttributeTable::new();

            // Leading service: its contents are irrelevant, but the issue does not show up without it.
            {
                let _leading_service = attributes.add_service(Service {
                    uuid: Uuid::new_long([10; 16]).into(),
                });
            }

            // Middle service whose number of characteristics varies per iteration.
            {
                let mut interior_service = attributes.add_service(Service {
                    uuid: Uuid::new_long([0; 16]).into(),
                });

                for index in 0..interior_len {
                    let _characteristic = interior_service
                        .add_characteristic_ro::<[u8; 2], _>(Uuid::new_long([index; 16]), &[0, 0])
                        .build();
                }
            }

            // Trailing service: again, the contents are irrelevant.
            {
                attributes.add_service(Service {
                    uuid: Uuid::new_long([8; 16]).into(),
                });
            }

            // Dump the table to the log to ease debugging.
            attributes.iterate(|mut entries| {
                while let Some(entry) = entries.next() {
                    let handle = entry.handle;
                    let uuid = &entry.uuid;
                    trace!(
                        "last_handle_in_group for 0x{:0>4x?}, 0x{:0>2x?}  0x{:0>2x?}",
                        handle,
                        uuid,
                        entry.last_handle_in_group
                    );
                }
            });

            // Server under test.
            let server =
                AttributeServer::<_, DefaultPacketPool, MAX_ATTRIBUTES, CCCD_MAX, CONNECTIONS_MAX>::new(attributes);

            // Connection manager that provides the mocked connection.
            let manager = setup();

            // Nothing is connected yet, so accepting must still be pending.
            assert!(manager.poll_accept(LeConnRole::Peripheral, &[], None).is_pending());
            unwrap!(manager.connect(
                ConnHandle::new(0),
                AddrKind::RANDOM,
                BdAddr::new(ADDR_1),
                LeConnRole::Peripheral
            ));

            match manager.poll_accept(LeConnRole::Peripheral, &[], None) {
                Poll::Ready(connection) => {
                    // With a connection in hand the mocked requests can be fed to the attribute server.
                    let mut response_buf = [0u8; 64];

                    let mut next_start = 0;
                    let search_end = u16::MAX;
                    // Exactly three services were registered, so three discoveries must succeed.
                    for _ in 0..3 {
                        let written = server
                            .handle_read_by_group_type_req(
                                &connection,
                                &mut response_buf,
                                next_start,
                                search_end,
                                &primary_service_group_type,
                            )
                            .unwrap();
                        let pdu = &response_buf[..written];
                        trace!("  0x{:0>2x?}", pdu);
                        // A successful response is expected every time; this assertion fires if the third
                        // service cannot be retrieved.
                        assert_eq!(pdu[0], att::ATT_READ_BY_GROUP_TYPE_RSP);
                        // Bytes 4 and 5 hold the last handle of the first group in the response. Any further
                        // groups in the same response are ignored; the next request starts right after it.
                        let group_end = u16::from_le_bytes([pdu[4], pdu[5]]);
                        next_start = group_end + 1;
                    }
                }
                Poll::Pending => panic!("expected connection to be accepted"),
            };
        }
    }
}