trouble_host/
attribute_server.rs

1use core::cell::RefCell;
2use core::marker::PhantomData;
3
4use embassy_sync::blocking_mutex::raw::RawMutex;
5use embassy_sync::blocking_mutex::Mutex;
6
7use crate::att::{self, AttClient, AttCmd, AttErrorCode, AttReq};
8use crate::attribute::{Attribute, AttributeData, AttributeTable, CCCD};
9use crate::cursor::WriteCursor;
10use crate::prelude::Connection;
11use crate::types::uuid::Uuid;
12use crate::{codec, Error, Identity, PacketPool};
13
14#[derive(Default)]
15struct Client {
16    identity: Identity,
17    is_connected: bool,
18}
19
20impl Client {
21    fn set_identity(&mut self, identity: Identity) {
22        self.identity = identity;
23    }
24}
25
26/// A table of CCCD values.
27#[cfg_attr(feature = "defmt", derive(defmt::Format))]
28#[derive(Clone, Debug)]
29pub struct CccdTable<const ENTRIES: usize> {
30    inner: [(u16, CCCD); ENTRIES],
31}
32
33impl<const ENTRIES: usize> Default for CccdTable<ENTRIES> {
34    fn default() -> Self {
35        Self {
36            inner: [(0, CCCD(0)); ENTRIES],
37        }
38    }
39}
40
41impl<const ENTRIES: usize> CccdTable<ENTRIES> {
42    /// Create a new CCCD table from an array of (handle, cccd) pairs.
43    pub fn new(cccd_values: [(u16, CCCD); ENTRIES]) -> Self {
44        Self { inner: cccd_values }
45    }
46
47    /// Get the inner array of (handle, cccd) pairs.
48    pub fn inner(&self) -> &[(u16, CCCD); ENTRIES] {
49        &self.inner
50    }
51
52    fn add_handle(&mut self, cccd_handle: u16) {
53        for (handle, _) in self.inner.iter_mut() {
54            if *handle == 0 {
55                *handle = cccd_handle;
56                break;
57            }
58        }
59    }
60
61    fn disable_all(&mut self) {
62        for (_, value) in self.inner.iter_mut() {
63            value.disable();
64        }
65    }
66
67    fn get_raw(&self, cccd_handle: u16) -> Option<[u8; 2]> {
68        for (handle, value) in self.inner.iter() {
69            if *handle == cccd_handle {
70                return Some(value.raw().to_le_bytes());
71            }
72        }
73        None
74    }
75
76    fn set_notify(&mut self, cccd_handle: u16, is_enabled: bool) {
77        for (handle, value) in self.inner.iter_mut() {
78            if *handle == cccd_handle {
79                trace!("[cccd] set_notify({}) = {}", cccd_handle, is_enabled);
80                value.set_notify(is_enabled);
81                break;
82            }
83        }
84    }
85
86    fn should_notify(&self, cccd_handle: u16) -> bool {
87        for (handle, value) in self.inner.iter() {
88            if *handle == cccd_handle {
89                return value.should_notify();
90            }
91        }
92        false
93    }
94}
95
96/// A table of CCCD values for each connected client.
97struct CccdTables<M: RawMutex, const CCCD_MAX: usize, const CONN_MAX: usize> {
98    state: Mutex<M, RefCell<[(Client, CccdTable<CCCD_MAX>); CONN_MAX]>>,
99}
100
101impl<M: RawMutex, const CCCD_MAX: usize, const CONN_MAX: usize> CccdTables<M, CCCD_MAX, CONN_MAX> {
102    fn new<const ATT_MAX: usize>(att_table: &AttributeTable<'_, M, ATT_MAX>) -> Self {
103        let mut values: [(Client, CccdTable<CCCD_MAX>); CONN_MAX] =
104            core::array::from_fn(|_| (Client::default(), CccdTable::default()));
105        let mut base_cccd_table = CccdTable::default();
106        att_table.iterate(|mut at| {
107            while let Some(att) = at.next() {
108                if let AttributeData::Cccd { .. } = att.data {
109                    base_cccd_table.add_handle(att.handle);
110                }
111            }
112        });
113        // add the base CCCD table for each potential connected client
114        for (_, table) in values.iter_mut() {
115            *table = base_cccd_table.clone();
116        }
117        Self {
118            state: Mutex::new(RefCell::new(values)),
119        }
120    }
121
122    fn connect(&self, peer_identity: &Identity) -> Result<(), Error> {
123        self.state.lock(|n| {
124            trace!("[server] searching for peer {:?}", peer_identity);
125            let mut n = n.borrow_mut();
126            let empty_slot = Identity::default();
127            for (client, table) in n.iter_mut() {
128                if client.identity.match_identity(peer_identity) {
129                    // trace!("[server] found! table = {:?}", *table);
130                    client.is_connected = true;
131                    return Ok(());
132                } else if client.identity == empty_slot {
133                    //  trace!("[server] empty slot: connecting");
134                    client.is_connected = true;
135                    client.set_identity(*peer_identity);
136                    return Ok(());
137                }
138            }
139            trace!("[server] all slots full...");
140            // if we got here all slots are full; replace the first disconnected client
141            for (client, table) in n.iter_mut() {
142                if !client.is_connected {
143                    trace!("[server] booting disconnected peer {:?}", client.identity);
144                    client.is_connected = true;
145                    client.set_identity(*peer_identity);
146                    // erase the previous client's config
147                    table.disable_all();
148                    return Ok(());
149                }
150            }
151            // Should be unreachable if the max connections (CONN_MAX) matches that defined
152            // in HostResources...
153            warn!("[server] unable to obtain CCCD slot");
154            Err(Error::ConnectionLimitReached)
155        })
156    }
157
158    fn disconnect(&self, peer_identity: &Identity) {
159        self.state.lock(|n| {
160            let mut n = n.borrow_mut();
161            for (client, _) in n.iter_mut() {
162                if client.identity.match_identity(peer_identity) {
163                    client.is_connected = false;
164                    break;
165                }
166            }
167        })
168    }
169
170    fn get_value(&self, peer_identity: &Identity, cccd_handle: u16) -> Option<[u8; 2]> {
171        self.state.lock(|n| {
172            let n = n.borrow();
173            for (client, table) in n.iter() {
174                if client.identity.match_identity(peer_identity) {
175                    return table.get_raw(cccd_handle);
176                }
177            }
178            None
179        })
180    }
181
182    fn set_notify(&self, peer_identity: &Identity, cccd_handle: u16, is_enabled: bool) {
183        self.state.lock(|n| {
184            let mut n = n.borrow_mut();
185            for (client, table) in n.iter_mut() {
186                if client.identity.match_identity(peer_identity) {
187                    table.set_notify(cccd_handle, is_enabled);
188                    break;
189                }
190            }
191        })
192    }
193
194    fn should_notify(&self, peer_identity: &Identity, cccd_handle: u16) -> bool {
195        self.state.lock(|n| {
196            let n = n.borrow();
197            for (client, table) in n.iter() {
198                if client.identity.match_identity(peer_identity) {
199                    return table.should_notify(cccd_handle);
200                }
201            }
202            false
203        })
204    }
205
206    fn get_cccd_table(&self, peer_identity: &Identity) -> Option<CccdTable<CCCD_MAX>> {
207        self.state.lock(|n| {
208            let n = n.borrow();
209            for (client, table) in n.iter() {
210                if client.identity.match_identity(peer_identity) {
211                    return Some(table.clone());
212                }
213            }
214            None
215        })
216    }
217
218    fn set_cccd_table(&self, peer_identity: &Identity, table: CccdTable<CCCD_MAX>) {
219        self.state.lock(|n| {
220            let mut n = n.borrow_mut();
221            for (client, t) in n.iter_mut() {
222                if client.identity.match_identity(peer_identity) {
223                    trace!("Setting cccd table {:?} for {:?}", table, peer_identity);
224                    *t = table;
225                    break;
226                }
227            }
228        })
229    }
230
231    fn update_identity(&self, identity: Identity) -> Result<(), Error> {
232        self.state.lock(|n| {
233            let mut n = n.borrow_mut();
234            for (client, _) in n.iter_mut() {
235                if identity.match_identity(&client.identity) {
236                    client.set_identity(identity);
237                    return Ok(());
238                }
239            }
240            Err(Error::NotFound)
241        })
242    }
243}
244
245/// A GATT server capable of processing the GATT protocol using the provided table of attributes.
246pub struct AttributeServer<
247    'values,
248    M: RawMutex,
249    P: PacketPool,
250    const ATT_MAX: usize,
251    const CCCD_MAX: usize,
252    const CONN_MAX: usize,
253> {
254    att_table: AttributeTable<'values, M, ATT_MAX>,
255    cccd_tables: CccdTables<M, CCCD_MAX, CONN_MAX>,
256    _p: PhantomData<P>,
257}
258
259pub(crate) mod sealed {
260    use super::*;
261
262    pub trait DynamicAttributeServer<P: PacketPool> {
263        fn connect(&self, connection: &Connection<'_, P>) -> Result<(), Error>;
264        fn disconnect(&self, connection: &Connection<'_, P>);
265        fn process(
266            &self,
267            connection: &Connection<'_, P>,
268            packet: &AttClient,
269            rx: &mut [u8],
270        ) -> Result<Option<usize>, Error>;
271        fn should_notify(&self, connection: &Connection<'_, P>, cccd_handle: u16) -> bool;
272        fn set(&self, characteristic: u16, input: &[u8]) -> Result<(), Error>;
273        fn update_identity(&self, identity: Identity) -> Result<(), Error>;
274    }
275}
276
277/// Type erased attribute server
278pub trait DynamicAttributeServer<P: PacketPool>: sealed::DynamicAttributeServer<P> {}
279
280impl<M: RawMutex, P: PacketPool, const ATT_MAX: usize, const CCCD_MAX: usize, const CONN_MAX: usize>
281    DynamicAttributeServer<P> for AttributeServer<'_, M, P, ATT_MAX, CCCD_MAX, CONN_MAX>
282{
283}
284impl<M: RawMutex, P: PacketPool, const ATT_MAX: usize, const CCCD_MAX: usize, const CONN_MAX: usize>
285    sealed::DynamicAttributeServer<P> for AttributeServer<'_, M, P, ATT_MAX, CCCD_MAX, CONN_MAX>
286{
287    fn connect(&self, connection: &Connection<'_, P>) -> Result<(), Error> {
288        AttributeServer::connect(self, connection)
289    }
290
291    fn disconnect(&self, connection: &Connection<'_, P>) {
292        self.cccd_tables.disconnect(&connection.peer_identity());
293    }
294
295    fn process(
296        &self,
297        connection: &Connection<'_, P>,
298        packet: &AttClient,
299        rx: &mut [u8],
300    ) -> Result<Option<usize>, Error> {
301        let res = AttributeServer::process(self, connection, packet, rx)?;
302        Ok(res)
303    }
304
305    fn should_notify(&self, connection: &Connection<'_, P>, cccd_handle: u16) -> bool {
306        AttributeServer::should_notify(self, connection, cccd_handle)
307    }
308
309    fn set(&self, characteristic: u16, input: &[u8]) -> Result<(), Error> {
310        self.att_table.set_raw(characteristic, input)
311    }
312
313    fn update_identity(&self, identity: Identity) -> Result<(), Error> {
314        self.cccd_tables.update_identity(identity)
315    }
316}
317
318impl<'values, M: RawMutex, P: PacketPool, const ATT_MAX: usize, const CCCD_MAX: usize, const CONN_MAX: usize>
319    AttributeServer<'values, M, P, ATT_MAX, CCCD_MAX, CONN_MAX>
320{
321    /// Create a new instance of the AttributeServer
322    pub fn new(
323        att_table: AttributeTable<'values, M, ATT_MAX>,
324    ) -> AttributeServer<'values, M, P, ATT_MAX, CCCD_MAX, CONN_MAX> {
325        let cccd_tables = CccdTables::new(&att_table);
326        AttributeServer {
327            att_table,
328            cccd_tables,
329            _p: PhantomData,
330        }
331    }
332
333    pub(crate) fn connect(&self, connection: &Connection<'_, P>) -> Result<(), Error> {
334        self.cccd_tables.connect(&connection.peer_identity())
335    }
336
337    pub(crate) fn should_notify(&self, connection: &Connection<'_, P>, cccd_handle: u16) -> bool {
338        self.cccd_tables.should_notify(&connection.peer_identity(), cccd_handle)
339    }
340
341    fn read_attribute_data(
342        &self,
343        connection: &Connection<'_, P>,
344        offset: usize,
345        att: &mut Attribute<'values>,
346        data: &mut [u8],
347    ) -> Result<usize, AttErrorCode> {
348        if let AttributeData::Cccd { .. } = att.data {
349            // CCCD values for each connected client are held in the CCCD tables:
350            // the value is written back into att.data so att.read() has the final
351            // say when parsing at the requested offset.
352            if let Some(value) = self.cccd_tables.get_value(&connection.peer_identity(), att.handle) {
353                let _ = att.write(0, value.as_slice());
354            }
355        }
356        att.read(offset, data)
357    }
358
359    fn write_attribute_data(
360        &self,
361        connection: &Connection<'_, P>,
362        offset: usize,
363        att: &mut Attribute<'values>,
364        data: &[u8],
365    ) -> Result<(), AttErrorCode> {
366        let err = att.write(offset, data);
367        if err.is_ok() {
368            if let AttributeData::Cccd {
369                notifications,
370                indications,
371            } = att.data
372            {
373                self.cccd_tables
374                    .set_notify(&connection.peer_identity(), att.handle, notifications);
375            }
376        }
377        err
378    }
379
380    fn handle_read_by_type_req(
381        &self,
382        connection: &Connection<'_, P>,
383        buf: &mut [u8],
384        start: u16,
385        end: u16,
386        attribute_type: &Uuid,
387    ) -> Result<usize, codec::Error> {
388        let mut handle = start;
389        let mut data = WriteCursor::new(buf);
390
391        let (mut header, mut body) = data.split(2)?;
392        let err = self.att_table.iterate(|mut it| {
393            let mut ret = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
394            while let Some(att) = it.next() {
395                // trace!("[read_by_type] Check attribute {:?} {}", att.uuid, att.handle);
396                if &att.uuid == attribute_type && att.handle >= start && att.handle <= end {
397                    body.write(att.handle)?;
398                    handle = att.handle;
399
400                    let new_ret = self.read_attribute_data(connection, 0, att, body.write_buf());
401                    match (new_ret, ret) {
402                        (Ok(first_length), Err(_)) => {
403                            // First successful read, store this length, all subsequent ones must match it.
404                            // debug!("[read_by_type] found first entry {:x?}, handle {}", att.uuid, handle);
405                            ret = new_ret;
406                            body.commit(first_length)?;
407                        }
408                        (Ok(new_length), Ok(old_length)) => {
409                            // Any matching attribute after the first, verify the lengths are identical, if not break.
410                            if new_length == old_length {
411                                // debug!("[read_by_type] found equal length {}, handle {}", new_length, handle);
412                                body.commit(new_length)?;
413                            } else {
414                                // We encountered a different length,  unwind the handle.
415                                // debug!("[read_by_type] different length: {}, old: {}", new_length, old_length);
416                                body.truncate(body.len() - 2);
417                                // And then break to ensure we respond with the previously found entries.
418                                break;
419                            }
420                        }
421                        (Err(error_code), Ok(_old_length)) => {
422                            // New read failed, but we had a previous value, return what we had thus far, truncate to
423                            // remove the previously written handle.
424                            body.truncate(body.len() - 2);
425                            // We do silently drop the error here.
426                            // debug!("[read_by_group] new error: {:?}, returning result thus far", error_code);
427                            break;
428                        }
429                        (Err(_), Err(_)) => {
430                            // Error on the first possible read, return this error.
431                            ret = new_ret;
432                            break;
433                        }
434                    }
435                    // If we get here, we always have had a successful read, and we can check that we still have space
436                    // left in the buffer to write the next entry if it exists.
437                    if let Ok(expected_length) = ret {
438                        if body.available() < expected_length + 2 {
439                            break;
440                        }
441                    }
442                }
443            }
444            ret
445        });
446
447        match err {
448            Ok(len) => {
449                header.write(att::ATT_READ_BY_TYPE_RSP)?;
450                header.write(2 + len as u8)?;
451                Ok(header.len() + body.len())
452            }
453            Err(e) => Ok(Self::error_response(data, att::ATT_READ_BY_TYPE_REQ, handle, e)?),
454        }
455    }
456
457    fn handle_read_by_group_type_req(
458        &self,
459        connection: &Connection<'_, P>,
460        buf: &mut [u8],
461        start: u16,
462        end: u16,
463        group_type: &Uuid,
464    ) -> Result<usize, codec::Error> {
465        let mut handle = start;
466        let mut data = WriteCursor::new(buf);
467        let (mut header, mut body) = data.split(2)?;
468        // Multiple entries can be returned in the response as long as they are of equal length.
469        let err = self.att_table.iterate(|mut it| {
470            // ret either holds the length of the attribute, or the error code encountered.
471            let mut ret: Result<usize, AttErrorCode> = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
472            while let Some(att) = it.next() {
473                // trace!("[read_by_group] Check attribute {:x?} {}", att.uuid, att.handle);
474                if &att.uuid == group_type && att.handle >= start && att.handle <= end {
475                    // debug!("[read_by_group] found! {:x?} handle: {}", att.uuid, att.handle);
476                    handle = att.handle;
477
478                    body.write(att.handle)?;
479                    body.write(att.last_handle_in_group)?;
480                    let new_ret = self.read_attribute_data(connection, 0, att, body.write_buf());
481                    match (new_ret, ret) {
482                        (Ok(first_length), Err(_)) => {
483                            // First successful read, store this length, all subsequent ones must match it.
484                            // debug!("[read_by_group] found first entry {:x?}, handle {}", att.uuid, handle);
485                            ret = new_ret;
486                            body.commit(first_length)?;
487                        }
488                        (Ok(new_length), Ok(old_length)) => {
489                            // Any matching attribute after the first, verify the lengths are identical, if not break.
490                            if new_length == old_length {
491                                // debug!("[read_by_group] found equal length {}, handle {}", new_length, handle);
492                                body.commit(new_length)?;
493                            } else {
494                                // We encountered a different length,  unwind the handle and last_handle written.
495                                // debug!("[read_by_group] different length: {}, old: {}", new_length, old_length);
496                                body.truncate(body.len() - 4);
497                                // And then break to ensure we respond with the previously found entries.
498                                break;
499                            }
500                        }
501                        (Err(error_code), Ok(_old_length)) => {
502                            // New read failed, but we had a previous value, return what we had thus far, truncate to
503                            // remove the previously written handle and last handle.
504                            body.truncate(body.len() - 4);
505                            // We do silently drop the error here.
506                            // debug!("[read_by_group] new error: {:?}, returning result thus far", error_code);
507                            break;
508                        }
509                        (Err(_), Err(_)) => {
510                            // Error on the first possible read, return this error.
511                            ret = new_ret;
512                            break;
513                        }
514                    }
515                    // If we get here, we always have had a successful read, and we can check that we still have space
516                    // left in the buffer to write the next entry if it exists.
517                    if let Ok(expected_length) = ret {
518                        if body.available() < expected_length + 4 {
519                            break;
520                        }
521                    }
522                }
523            }
524            ret
525        });
526
527        match err {
528            Ok(len) => {
529                header.write(att::ATT_READ_BY_GROUP_TYPE_RSP)?;
530                header.write(4 + len as u8)?;
531                Ok(header.len() + body.len())
532            }
533            Err(e) => Ok(Self::error_response(data, att::ATT_READ_BY_GROUP_TYPE_REQ, handle, e)?),
534        }
535    }
536
537    fn handle_read_req(
538        &self,
539        connection: &Connection<'_, P>,
540        buf: &mut [u8],
541        handle: u16,
542    ) -> Result<usize, codec::Error> {
543        let mut data = WriteCursor::new(buf);
544
545        data.write(att::ATT_READ_RSP)?;
546
547        let err = self.att_table.iterate(|mut it| {
548            let mut err = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
549            while let Some(att) = it.next() {
550                if att.handle == handle {
551                    err = self.read_attribute_data(connection, 0, att, data.write_buf());
552                    if let Ok(len) = err {
553                        data.commit(len)?;
554                    }
555                    break;
556                }
557            }
558            err
559        });
560
561        match err {
562            Ok(_) => Ok(data.len()),
563            Err(e) => Ok(Self::error_response(data, att::ATT_READ_REQ, handle, e)?),
564        }
565    }
566
567    fn handle_write_cmd(
568        &self,
569        connection: &Connection<'_, P>,
570        buf: &mut [u8],
571        handle: u16,
572        data: &[u8],
573    ) -> Result<usize, codec::Error> {
574        self.att_table.iterate(|mut it| {
575            while let Some(att) = it.next() {
576                if att.handle == handle {
577                    // Write commands can't respond with an error.
578                    let _ = self.write_attribute_data(connection, 0, att, data);
579                    break;
580                }
581            }
582        });
583        Ok(0)
584    }
585
586    fn handle_write_req(
587        &self,
588        connection: &Connection<'_, P>,
589        buf: &mut [u8],
590        handle: u16,
591        data: &[u8],
592    ) -> Result<usize, codec::Error> {
593        let err = self.att_table.iterate(|mut it| {
594            let mut err = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
595            while let Some(att) = it.next() {
596                if att.handle == handle {
597                    err = self.write_attribute_data(connection, 0, att, data);
598                    break;
599                }
600            }
601            err
602        });
603
604        let mut w = WriteCursor::new(buf);
605        match err {
606            Ok(()) => {
607                w.write(att::ATT_WRITE_RSP)?;
608                Ok(w.len())
609            }
610            Err(e) => Ok(Self::error_response(w, att::ATT_WRITE_REQ, handle, e)?),
611        }
612    }
613
614    fn handle_find_type_value(
615        &self,
616        buf: &mut [u8],
617        start: u16,
618        end: u16,
619        attr_type: u16,
620        attr_value: &[u8],
621    ) -> Result<usize, codec::Error> {
622        let mut w = WriteCursor::new(buf);
623        let attr_type = Uuid::new_short(attr_type);
624
625        w.write(att::ATT_FIND_BY_TYPE_VALUE_RSP)?;
626        self.att_table.iterate(|mut it| {
627            while let Some(att) = it.next() {
628                if att.handle >= start && att.handle <= end && att.uuid == attr_type {
629                    if let AttributeData::Service { uuid } = &att.data {
630                        if uuid.as_raw() == attr_value {
631                            if w.available() < 4 + uuid.as_raw().len() {
632                                break;
633                            }
634                            w.write(att.handle)?;
635                            w.write(att.last_handle_in_group)?;
636                        }
637                    }
638                }
639            }
640            Ok::<(), codec::Error>(())
641        })?;
642        if w.len() > 1 {
643            Ok(w.len())
644        } else {
645            Ok(Self::error_response(
646                w,
647                att::ATT_FIND_BY_TYPE_VALUE_REQ,
648                start,
649                AttErrorCode::ATTRIBUTE_NOT_FOUND,
650            )?)
651        }
652    }
653
654    fn handle_find_information(&self, buf: &mut [u8], start: u16, end: u16) -> Result<usize, codec::Error> {
655        let mut w = WriteCursor::new(buf);
656
657        let (mut header, mut body) = w.split(2)?;
658
659        header.write(att::ATT_FIND_INFORMATION_RSP)?;
660        let mut t = 0;
661
662        self.att_table.iterate(|mut it| {
663            while let Some(att) = it.next() {
664                if att.handle >= start && att.handle <= end {
665                    if t == 0 {
666                        t = att.uuid.get_type();
667                    } else if t != att.uuid.get_type() {
668                        break;
669                    }
670                    body.write(att.handle)?;
671                    body.append(att.uuid.as_raw())?;
672                }
673            }
674            Ok::<(), codec::Error>(())
675        })?;
676        header.write(t)?;
677
678        if body.len() > 2 {
679            Ok(header.len() + body.len())
680        } else {
681            Ok(Self::error_response(
682                w,
683                att::ATT_FIND_INFORMATION_REQ,
684                start,
685                AttErrorCode::ATTRIBUTE_NOT_FOUND,
686            )?)
687        }
688    }
689
690    fn error_response(
691        mut w: WriteCursor<'_>,
692        opcode: u8,
693        handle: u16,
694        code: AttErrorCode,
695    ) -> Result<usize, codec::Error> {
696        w.reset();
697        w.write(att::ATT_ERROR_RSP)?;
698        w.write(opcode)?;
699        w.write(handle)?;
700        w.write(code)?;
701        Ok(w.len())
702    }
703
704    fn handle_prepare_write(
705        &self,
706        connection: &Connection<'_, P>,
707        buf: &mut [u8],
708        handle: u16,
709        offset: u16,
710        value: &[u8],
711    ) -> Result<usize, codec::Error> {
712        let mut w = WriteCursor::new(buf);
713        w.write(att::ATT_PREPARE_WRITE_RSP)?;
714        w.write(handle)?;
715        w.write(offset)?;
716
717        let err = self.att_table.iterate(|mut it| {
718            let mut err = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
719            while let Some(att) = it.next() {
720                if att.handle == handle {
721                    err = self.write_attribute_data(connection, offset as usize, att, value);
722                    w.append(value)?;
723                    break;
724                }
725            }
726            err
727        });
728
729        match err {
730            Ok(()) => Ok(w.len()),
731            Err(e) => Ok(Self::error_response(w, att::ATT_PREPARE_WRITE_REQ, handle, e)?),
732        }
733    }
734
735    fn handle_execute_write(&self, buf: &mut [u8], _flags: u8) -> Result<usize, codec::Error> {
736        let mut w = WriteCursor::new(buf);
737        w.write(att::ATT_EXECUTE_WRITE_RSP)?;
738        Ok(w.len())
739    }
740
741    fn handle_read_blob(
742        &self,
743        connection: &Connection<'_, P>,
744        buf: &mut [u8],
745        handle: u16,
746        offset: u16,
747    ) -> Result<usize, codec::Error> {
748        let mut w = WriteCursor::new(buf);
749        w.write(att::ATT_READ_BLOB_RSP)?;
750
751        let err = self.att_table.iterate(|mut it| {
752            let mut err = Err(AttErrorCode::ATTRIBUTE_NOT_FOUND);
753            while let Some(att) = it.next() {
754                if att.handle == handle {
755                    err = self.read_attribute_data(connection, offset as usize, att, w.write_buf());
756                    if let Ok(n) = err {
757                        w.commit(n)?;
758                    }
759                    break;
760                }
761            }
762            err
763        });
764
765        match err {
766            Ok(_) => Ok(w.len()),
767            Err(e) => Ok(Self::error_response(w, att::ATT_READ_BLOB_REQ, handle, e)?),
768        }
769    }
770
771    fn handle_read_multiple(&self, buf: &mut [u8], handles: &[u8]) -> Result<usize, codec::Error> {
772        let w = WriteCursor::new(buf);
773        Self::error_response(
774            w,
775            att::ATT_READ_MULTIPLE_REQ,
776            u16::from_le_bytes([handles[0], handles[1]]),
777            AttErrorCode::ATTRIBUTE_NOT_FOUND,
778        )
779    }
780
781    /// Process an event and produce a response if necessary
782    pub fn process(
783        &self,
784        connection: &Connection<'_, P>,
785        packet: &AttClient,
786        rx: &mut [u8],
787    ) -> Result<Option<usize>, codec::Error> {
788        let len = match packet {
789            AttClient::Request(AttReq::ReadByType {
790                start,
791                end,
792                attribute_type,
793            }) => self.handle_read_by_type_req(connection, rx, *start, *end, attribute_type)?,
794
795            AttClient::Request(AttReq::ReadByGroupType { start, end, group_type }) => {
796                self.handle_read_by_group_type_req(connection, rx, *start, *end, group_type)?
797            }
798            AttClient::Request(AttReq::FindInformation {
799                start_handle,
800                end_handle,
801            }) => self.handle_find_information(rx, *start_handle, *end_handle)?,
802
803            AttClient::Request(AttReq::Read { handle }) => self.handle_read_req(connection, rx, *handle)?,
804
805            AttClient::Command(AttCmd::Write { handle, data }) => {
806                self.handle_write_cmd(connection, rx, *handle, data)?;
807                0
808            }
809
810            AttClient::Request(AttReq::Write { handle, data }) => {
811                self.handle_write_req(connection, rx, *handle, data)?
812            }
813
814            AttClient::Request(AttReq::ExchangeMtu { mtu }) => 0, // Done outside,
815
816            AttClient::Request(AttReq::FindByTypeValue {
817                start_handle,
818                end_handle,
819                att_type,
820                att_value,
821            }) => self.handle_find_type_value(rx, *start_handle, *end_handle, *att_type, att_value)?,
822
823            AttClient::Request(AttReq::PrepareWrite { handle, offset, value }) => {
824                self.handle_prepare_write(connection, rx, *handle, *offset, value)?
825            }
826
827            AttClient::Request(AttReq::ExecuteWrite { flags }) => self.handle_execute_write(rx, *flags)?,
828
829            AttClient::Request(AttReq::ReadBlob { handle, offset }) => {
830                self.handle_read_blob(connection, rx, *handle, *offset)?
831            }
832
833            AttClient::Request(AttReq::ReadMultiple { handles }) => self.handle_read_multiple(rx, handles)?,
834
835            AttClient::Confirmation(_) => 0,
836        };
837        if len > 0 {
838            Ok(Some(len))
839        } else {
840            Ok(None)
841        }
842    }
843
844    /// Get a reference to the attribute table
845    pub fn table(&self) -> &AttributeTable<'values, M, ATT_MAX> {
846        &self.att_table
847    }
848
849    /// Get the CCCD table for a connection
850    pub fn get_cccd_table(&self, connection: &Connection<'_, P>) -> Option<CccdTable<CCCD_MAX>> {
851        self.cccd_tables.get_cccd_table(&connection.peer_identity())
852    }
853
854    /// Set the CCCD table for a connection
855    pub fn set_cccd_table(&self, connection: &Connection<'_, P>, table: CccdTable<CCCD_MAX>) {
856        self.cccd_tables.set_cccd_table(&connection.peer_identity(), table);
857    }
858}
859
#[cfg(test)]
mod tests {
    use core::task::Poll;

    use bt_hci::param::{AddrKind, BdAddr, ConnHandle, LeConnRole};
    use embassy_sync::blocking_mutex::raw::NoopRawMutex;

    use super::*;
    use crate::connection_manager::tests::{setup, ADDR_1};
    use crate::prelude::*;

    /// Regression test: ReadByGroupType discovery must reach every primary service, whatever
    /// the number of handles the interior service occupies.
    #[test]
    fn test_attribute_server_last_handle_of_group() {
        // This test comes from a situation where a service had exactly 16 handles. That caused the
        // last_handle_in_group field of the ReadByGroupType response to be 16-aligned (96 to be exact).
        // The next request then started at 96 + 1, one handle beyond the start of the next service.
        //
        // Snippet from the original failure mode:
        // WARN  trouble_host::attribute_server] Looking for group: Uuid16([0, 28]) between 75 and 65535
        // DEBUG trouble_host::attribute_server] [read_by_group] found! Uuid16([0, 28]) 80
        // DEBUG trouble_host::attribute_server] last_handle_in_group: 96
        // DEBUG trouble_host::attribute_server] read_attribute_data: Ok(16)
        // TRACE trouble_host::host] [host] granted send packets = 1, len = 30
        // TRACE trouble_host::host] [host] sent acl packet len = 26
        // TRACE trouble_host::host] [host] inbound l2cap header channel = 4, fragment len = 7, total = 7
        // INFO  main_ble::ble_bas_peripheral] [gatt-attclient]: ReadByGroupType { start: 97, end: 65535, group_type: Uuid16([0, 40]) }
        // INFO  main_ble::ble_bas_peripheral] [gatt] other event
        // WARN  trouble_host::attribute_server] Looking for group: Uuid16([0, 28]) between 97 and 65535
        // WARN  trouble_host::attribute_server] [read_by_group] Dit not find attribute Uuid16([0, 28]) between 97  65535
        //
        // The request:
        // INFO  main_ble::ble_bas_peripheral] [gatt-attclient]: ReadByGroupType { start: 97, end: 65535, group_type: Uuid16([0, 40]) }
        // In the trace, the bytes in "group_type: Uuid16([0, 40]) }" are printed in decimal, little-endian, so this
        // is group type 0x2800: the primary service group.
        let primary_service_group_type = Uuid::new_short(0x2800);

        let _ = env_logger::try_init();
        const MAX_ATTRIBUTES: usize = 1024;
        const CONNECTIONS_MAX: usize = 3;
        const CCCD_MAX: usize = 1024;

        // Instead of only checking the original failure mode, fuzz the length of the interior service so it
        // crosses several multiples of 16.
        for interior_handle_count in 0..=64u8 {
            debug!("Testing with interior handle count of {}", interior_handle_count);

            // Create a new table.
            let mut table: AttributeTable<'_, NoopRawMutex, MAX_ATTRIBUTES> = AttributeTable::new();

            // Add a first service. Its contents don't really matter, but the issue doesn't manifest without it.
            {
                // Bound but unused so the builder lives until the end of this scope.
                let _svc = table.add_service(Service {
                    uuid: Uuid::new_long([10; 16]).into(),
                });
            }

            // Add an interior service that has a varying length.
            {
                let mut svc = table.add_service(Service {
                    uuid: Uuid::new_long([0; 16]).into(),
                });

                for c in 0..interior_handle_count {
                    let _service_instance = svc
                        .add_characteristic_ro::<[u8; 2], _>(Uuid::new_long([c; 16]), &[0, 0])
                        .build();
                }
            }
            // Now add the service at the end. Its contents don't really matter.
            {
                table.add_service(Service {
                    uuid: Uuid::new_long([8; 16]).into(),
                });
            }

            // Print the table for debugging.
            table.iterate(|mut it| {
                while let Some(att) = it.next() {
                    let handle = att.handle;
                    let uuid = &att.uuid;
                    trace!(
                        "last_handle_in_group for 0x{:0>4x?}, 0x{:0>2x?}  0x{:0>2x?}",
                        handle,
                        uuid,
                        att.last_handle_in_group
                    );
                }
            });

            // Create a server.
            let server = AttributeServer::<_, DefaultPacketPool, MAX_ATTRIBUTES, CCCD_MAX, CONNECTIONS_MAX>::new(table);

            // Create the connection manager.
            let mgr = setup();

            // Try to connect.
            assert!(mgr.poll_accept(LeConnRole::Peripheral, &[], None).is_pending());
            unwrap!(mgr.connect(
                ConnHandle::new(0),
                AddrKind::RANDOM,
                BdAddr::new(ADDR_1),
                LeConnRole::Peripheral
            ));

            if let Poll::Ready(conn_handle) = mgr.poll_accept(LeConnRole::Peripheral, &[], None) {
                // We now have a connection, so we can send mocked requests to the attribute server.
                let mut buffer = [0u8; 64];

                let mut start = 0;
                let end = u16::MAX;
                // There are always three services that we should be able to discover.
                for _ in 0..3 {
                    let length = server
                        .handle_read_by_group_type_req(
                            &conn_handle,
                            &mut buffer,
                            start,
                            end,
                            &primary_service_group_type,
                        )
                        .unwrap();
                    let response = &buffer[0..length];
                    trace!("  0x{:0>2x?}", response);
                    // It should be a successful response because the service should be found. This asserts if
                    // we failed to retrieve the third service.
                    assert_eq!(response[0], att::ATT_READ_BY_GROUP_TYPE_RSP);
                    // The last handle of the first group in the response is at bytes 4 and 5. Use it as the start
                    // for the next cycle; any further entries in the response are ignored.
                    let last_handle = u16::from_le_bytes([response[4], response[5]]);
                    start = last_handle + 1;
                }
            } else {
                panic!("expected connection to be accepted");
            };
        }
    }
}