// rust_mqtt/client/raw_client.rs

1use embedded_io::ReadReady;
2use embedded_io_async::{Read, Write};
3use heapless::Vec;
4use rand_core::RngCore;
5
6use crate::{
7    encoding::variable_byte_integer::{VariableByteInteger, VariableByteIntegerDecoder},
8    network::NetworkConnection,
9    packet::v5::{
10        connack_packet::ConnackPacket,
11        connect_packet::ConnectPacket,
12        disconnect_packet::DisconnectPacket,
13        mqtt_packet::Packet,
14        packet_type::PacketType,
15        pingreq_packet::PingreqPacket,
16        pingresp_packet::PingrespPacket,
17        puback_packet::PubackPacket,
18        publish_packet::{PublishPacket, QualityOfService},
19        reason_codes::ReasonCode,
20        suback_packet::SubackPacket,
21        subscription_packet::SubscriptionPacket,
22        unsuback_packet::UnsubackPacket,
23        unsubscription_packet::UnsubscriptionPacket,
24    },
25    utils::{buffer_reader::BuffReader, buffer_writer::BuffWriter, types::BufferError},
26};
27
28use super::client_config::{ClientConfig, MqttVersion};
29
30pub enum Event<'a> {
31    Connack,
32    /// event.0 is Packet identifier, event.1 is true if there were matching subscribers, false otherwise
33    Puback(u16, bool),
34    Suback(u16),
35    Unsuback(u16),
36    Pingresp,
37    Message(&'a str, &'a [u8]),
38    Disconnect(ReasonCode),
39}
40
41pub struct RawMqttClient<'a, T, const MAX_PROPERTIES: usize, R: RngCore>
42where
43    T: Read + Write,
44{
45    connection: Option<NetworkConnection<T>>,
46    buffer: &'a mut [u8],
47    buffer_len: usize,
48    recv_buffer: &'a mut [u8],
49    recv_buffer_len: usize,
50    config: ClientConfig<'a, MAX_PROPERTIES, R>,
51}
52
53impl<'a, T, const MAX_PROPERTIES: usize, R> RawMqttClient<'a, T, MAX_PROPERTIES, R>
54where
55    T: Read + Write,
56    R: RngCore,
57{
58    pub fn new(
59        network_driver: T,
60        buffer: &'a mut [u8],
61        buffer_len: usize,
62        recv_buffer: &'a mut [u8],
63        recv_buffer_len: usize,
64        config: ClientConfig<'a, MAX_PROPERTIES, R>,
65    ) -> Self {
66        Self {
67            connection: Some(NetworkConnection::new(network_driver)),
68            buffer,
69            buffer_len,
70            recv_buffer,
71            recv_buffer_len,
72            config,
73        }
74    }
75
76    async fn connect_to_broker_v5<'b>(&'b mut self) -> Result<(), ReasonCode> {
77        if self.connection.is_none() {
78            return Err(ReasonCode::NetworkError);
79        }
80        let len = {
81            let mut connect = ConnectPacket::<'b, MAX_PROPERTIES, 0>::new();
82            connect.keep_alive = self.config.keep_alive;
83            self.config.add_max_packet_size_as_prop();
84            connect.property_len = connect.add_properties(&self.config.properties);
85            if self.config.username_flag {
86                connect.add_username(&self.config.username);
87            }
88            if self.config.password_flag {
89                connect.add_password(&self.config.password)
90            }
91            if self.config.will_flag {
92                connect.add_will(
93                    &self.config.will_topic,
94                    &self.config.will_payload,
95                    self.config.will_retain,
96                )
97            }
98            connect.add_client_id(&self.config.client_id);
99            connect.encode(self.buffer, self.buffer_len)
100        };
101
102        if let Err(err) = len {
103            error!("[DECODE ERR]: {}", err);
104            return Err(ReasonCode::BuffError);
105        }
106        let conn = self.connection.as_mut().unwrap();
107        trace!("Sending connect");
108        conn.send(&self.buffer[0..len.unwrap()]).await?;
109
110        Ok(())
111    }
112
113    /// Method allows client connect to server. Client is connecting to the specified broker
114    /// in the `ClientConfig`. Method selects proper implementation of the MQTT version based on the config.
115    /// If the connection to the broker fails, method returns Err variable that contains
116    /// Reason codes returned from the broker.
117    pub async fn connect_to_broker<'b>(&'b mut self) -> Result<(), ReasonCode> {
118        match self.config.mqtt_version {
119            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
120            MqttVersion::MQTTv5 => self.connect_to_broker_v5().await,
121        }
122    }
123
124    async fn disconnect_v5<'b>(&'b mut self) -> Result<(), ReasonCode> {
125        if self.connection.is_none() {
126            return Err(ReasonCode::NetworkError);
127        }
128        let conn = self.connection.as_mut().unwrap();
129        trace!("Creating disconnect packet!");
130        let mut disconnect = DisconnectPacket::<'b, MAX_PROPERTIES>::new();
131        let len = disconnect.encode(self.buffer, self.buffer_len);
132        if let Err(err) = len {
133            warn!("[DECODE ERR]: {}", err);
134            let _ = self.connection.take();
135            return Err(ReasonCode::BuffError);
136        }
137
138        if let Err(_e) = conn.send(&self.buffer[0..len.unwrap()]).await {
139            warn!("Could not send DISCONNECT packet");
140        }
141
142        // Drop connection
143        let _ = self.connection.take();
144        Ok(())
145    }
146
147    /// Method allows client disconnect from the server. Client disconnects from the specified broker
148    /// in the `ClientConfig`. Method selects proper implementation of the MQTT version based on the config.
149    /// If the disconnect from the broker fails, method returns Err variable that contains
150    /// Reason codes returned from the broker.
151    pub async fn disconnect<'b>(&'b mut self) -> Result<(), ReasonCode> {
152        match self.config.mqtt_version {
153            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
154            MqttVersion::MQTTv5 => self.disconnect_v5().await,
155        }
156    }
157
158    async fn send_message_v5<'b>(
159        &'b mut self,
160        topic_name: &'b str,
161        message: &'b [u8],
162        qos: QualityOfService,
163        retain: bool,
164    ) -> Result<u16, ReasonCode> {
165        if self.connection.is_none() {
166            return Err(ReasonCode::NetworkError);
167        }
168        let conn = self.connection.as_mut().unwrap();
169        let identifier: u16 = self.config.rng.next_u32() as u16;
170        //self.rng.next_u32() as u16;
171        let len = {
172            let mut packet = PublishPacket::<'b, MAX_PROPERTIES>::new();
173            packet.add_topic_name(topic_name);
174            packet.add_qos(qos);
175            packet.add_identifier(identifier);
176            packet.add_message(message);
177            packet.add_retain(retain);
178            packet.encode(self.buffer, self.buffer_len)
179        };
180
181        if let Err(err) = len {
182            error!("[DECODE ERR]: {}", err);
183            return Err(ReasonCode::BuffError);
184        }
185        trace!("Sending message");
186        conn.send(&self.buffer[0..len.unwrap()]).await?;
187
188        Ok(identifier)
189    }
190    /// Method allows sending message to broker specified from the ClientConfig. Client sends the
191    /// message from the parameter `message` to the topic `topic_name` on the broker
192    /// specified in the ClientConfig. If the send fails method returns Err with reason code
193    /// received by broker.
194    pub async fn send_message<'b>(
195        &'b mut self,
196        topic_name: &'b str,
197        message: &'b [u8],
198        qos: QualityOfService,
199        retain: bool,
200    ) -> Result<u16, ReasonCode> {
201        match self.config.mqtt_version {
202            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
203            MqttVersion::MQTTv5 => self.send_message_v5(topic_name, message, qos, retain).await,
204        }
205    }
206
207    async fn subscribe_to_topics_v5<'b, const TOPICS: usize>(
208        &'b mut self,
209        topic_names: &'b Vec<&'b str, TOPICS>,
210    ) -> Result<u16, ReasonCode> {
211        if self.connection.is_none() {
212            return Err(ReasonCode::NetworkError);
213        }
214        let conn = self.connection.as_mut().unwrap();
215        let identifier: u16 = self.config.rng.next_u32() as u16;
216        let len = {
217            let mut subs = SubscriptionPacket::<'b, TOPICS, MAX_PROPERTIES>::new();
218            subs.packet_identifier = identifier;
219            for topic_name in topic_names.iter() {
220                subs.add_new_filter(topic_name, self.config.max_subscribe_qos);
221            }
222            subs.encode(self.buffer, self.buffer_len)
223        };
224
225        if let Err(err) = len {
226            error!("[DECODE ERR]: {}", err);
227            return Err(ReasonCode::BuffError);
228        }
229
230        conn.send(&self.buffer[0..len.unwrap()]).await?;
231
232        Ok(identifier)
233    }
234
235    /// Method allows client subscribe to multiple topics specified in the parameter
236    /// `topic_names` on the broker specified in the `ClientConfig`. Generics `TOPICS`
237    /// sets the value of the `topics_names` vector. MQTT protocol implementation
238    /// is selected automatically.
239    pub async fn subscribe_to_topics<'b, const TOPICS: usize>(
240        &'b mut self,
241        topic_names: &'b Vec<&'b str, TOPICS>,
242    ) -> Result<u16, ReasonCode> {
243        match self.config.mqtt_version {
244            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
245            MqttVersion::MQTTv5 => self.subscribe_to_topics_v5(topic_names).await,
246        }
247    }
248
249    /// Method allows client unsubscribe from the topic specified in the parameter
250    /// `topic_name` on the broker from the `ClientConfig`. MQTT protocol implementation
251    /// is selected automatically.
252    pub async fn unsubscribe_from_topic<'b>(
253        &'b mut self,
254        topic_name: &'b str,
255    ) -> Result<u16, ReasonCode> {
256        match self.config.mqtt_version {
257            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
258            MqttVersion::MQTTv5 => self.unsubscribe_from_topic_v5(topic_name).await,
259        }
260    }
261
262    async fn unsubscribe_from_topic_v5<'b>(
263        &'b mut self,
264        topic_name: &'b str,
265    ) -> Result<u16, ReasonCode> {
266        if self.connection.is_none() {
267            return Err(ReasonCode::NetworkError);
268        }
269        let conn = self.connection.as_mut().unwrap();
270        let identifier = self.config.rng.next_u32() as u16;
271
272        let len = {
273            let mut unsub = UnsubscriptionPacket::<'b, 1, MAX_PROPERTIES>::new();
274            unsub.packet_identifier = identifier;
275            unsub.add_new_filter(topic_name);
276            unsub.encode(self.buffer, self.buffer_len)
277        };
278
279        if let Err(err) = len {
280            error!("[DECODE ERR]: {}", err);
281            return Err(ReasonCode::BuffError);
282        }
283        conn.send(&self.buffer[0..len.unwrap()]).await?;
284
285        Ok(identifier)
286    }
287
288    async fn send_ping_v5<'b>(&'b mut self) -> Result<(), ReasonCode> {
289        if self.connection.is_none() {
290            return Err(ReasonCode::NetworkError);
291        }
292        let conn = self.connection.as_mut().unwrap();
293        let len = {
294            let mut packet = PingreqPacket::new();
295            packet.encode(self.buffer, self.buffer_len)
296        };
297
298        if let Err(err) = len {
299            error!("[DECODE ERR]: {}", err);
300            return Err(ReasonCode::BuffError);
301        }
302
303        conn.send(&self.buffer[0..len.unwrap()]).await?;
304
305        Ok(())
306    }
307
308    /// Method allows client send PING message to the broker specified in the `ClientConfig`.
309    /// If there is expectation for long running connection. Method should be executed
310    /// regularly by the timer that counts down the session expiry interval.
311    pub async fn send_ping<'b>(&'b mut self) -> Result<(), ReasonCode> {
312        match self.config.mqtt_version {
313            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
314            MqttVersion::MQTTv5 => self.send_ping_v5().await,
315        }
316    }
317
318    pub async fn poll<'b, const MAX_TOPICS: usize>(&'b mut self) -> Result<Event<'b>, ReasonCode> {
319        if self.connection.is_none() {
320            return Err(ReasonCode::NetworkError);
321        }
322
323        let conn = self.connection.as_mut().unwrap();
324
325        trace!("Waiting for a packet");
326
327        let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? };
328
329        let buf_reader = BuffReader::new(self.buffer, read);
330
331        match PacketType::from(buf_reader.peek_u8().map_err(|_| ReasonCode::BuffError)?) {
332            PacketType::Reserved
333            | PacketType::Connect
334            | PacketType::Subscribe
335            | PacketType::Unsubscribe
336            | PacketType::Pingreq => Err(ReasonCode::ProtocolError),
337            PacketType::Pubrec | PacketType::Pubrel | PacketType::Pubcomp | PacketType::Auth => {
338                Err(ReasonCode::ImplementationSpecificError)
339            }
340            PacketType::Connack => {
341                let mut packet = ConnackPacket::<'b, MAX_PROPERTIES>::new();
342                if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) {
343                    // if err == BufferError::PacketTypeMismatch {
344                    //     let mut disc = DisconnectPacket::<'b, MAX_PROPERTIES>::new();
345                    //     if disc.decode(&mut BuffReader::new(self.buffer, read)).is_ok() {
346                    //         error!("Client was disconnected with reason: ");
347                    //         return Err(ReasonCode::from(disc.disconnect_reason));
348                    //     }
349                    // }
350                    error!("[DECODE ERR]: {}", err);
351                    Err(ReasonCode::BuffError)
352                } else if packet.connect_reason_code != 0x00 {
353                    Err(ReasonCode::from(packet.connect_reason_code))
354                } else {
355                    Ok(Event::Connack)
356                }
357            }
358            PacketType::Puback => {
359                let result = {
360                    let mut packet = PubackPacket::<'b, MAX_PROPERTIES>::new();
361                    packet
362                        .decode(&mut BuffReader::new(self.buffer, read))
363                        .map(|_| {
364                            (
365                                packet.packet_identifier,
366                                ReasonCode::from(packet.reason_code),
367                            )
368                        })
369                };
370
371                match result {
372                    Ok((id, ReasonCode::Success)) => Ok(Event::Puback(id, true)),
373                    Ok((id, ReasonCode::NoMatchingSubscribers)) => Ok(Event::Puback(id, false)),
374                    Ok((_id, err)) => Err(err),
375                    Err(err) => {
376                        error!("[DECODE ERR]: {}", err);
377                        Err(ReasonCode::BuffError)
378                    }
379                }
380            }
381            PacketType::Suback => {
382                let reason: Result<(u16, Vec<u8, MAX_TOPICS>), BufferError> = {
383                    let mut packet = SubackPacket::<'b, MAX_TOPICS, MAX_PROPERTIES>::new();
384                    packet
385                        .decode(&mut BuffReader::new(self.buffer, read))
386                        .map(|_| (packet.packet_identifier, packet.reason_codes))
387                };
388
389                if let Err(err) = reason {
390                    error!("[DECODE ERR]: {}", err);
391                    return Err(ReasonCode::BuffError);
392                }
393                let (packet_identifier, reasons) = reason.unwrap();
394                for reason_code in &reasons {
395                    if *reason_code
396                        != (<QualityOfService as Into<u8>>::into(self.config.max_subscribe_qos)
397                            >> 1)
398                    {
399                        return Err(ReasonCode::from(*reason_code));
400                    }
401                }
402                Ok(Event::Suback(packet_identifier))
403            }
404            PacketType::Unsuback => {
405                let res: Result<u16, BufferError> = {
406                    let mut packet = UnsubackPacket::<'b, 1, MAX_PROPERTIES>::new();
407                    packet
408                        .decode(&mut BuffReader::new(self.buffer, read))
409                        .map(|_| packet.packet_identifier)
410                };
411
412                if let Err(err) = res {
413                    error!("[DECODE ERR]: {}", err);
414                    Err(ReasonCode::BuffError)
415                } else {
416                    Ok(Event::Unsuback(res.unwrap()))
417                }
418            }
419            PacketType::Pingresp => {
420                let mut packet = PingrespPacket::new();
421                if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) {
422                    error!("[DECODE ERR]: {}", err);
423                    Err(ReasonCode::BuffError)
424                } else {
425                    Ok(Event::Pingresp)
426                }
427            }
428            PacketType::Publish => {
429                let mut packet = PublishPacket::<'b, 5>::new();
430                if let Err(err) = { packet.decode(&mut BuffReader::new(self.buffer, read)) } {
431                    // if err == BufferError::PacketTypeMismatch {
432                    //     let mut disc = DisconnectPacket::<'b, 5>::new();
433                    //     if disc.decode(&mut BuffReader::new(self.buffer, read)).is_ok() {
434                    //         error!("Client was disconnected with reason: ");
435                    //         return Err(ReasonCode::from(disc.disconnect_reason));
436                    //     }
437                    // }
438                    error!("[DECODE ERR]: {}", err);
439                    return Err(ReasonCode::BuffError);
440                }
441
442                if (packet.fixed_header & 0x06)
443                    == <QualityOfService as Into<u8>>::into(QualityOfService::QoS1)
444                {
445                    let mut puback = PubackPacket::<'b, MAX_PROPERTIES>::new();
446                    puback.packet_identifier = packet.packet_identifier;
447                    puback.reason_code = 0x00;
448                    {
449                        let len = { puback.encode(self.recv_buffer, self.recv_buffer_len) };
450                        if let Err(err) = len {
451                            error!("[DECODE ERR]: {}", err);
452                            return Err(ReasonCode::BuffError);
453                        }
454                        conn.send(&self.recv_buffer[0..len.unwrap()]).await?;
455                    }
456                }
457
458                Ok(Event::Message(
459                    packet.topic_name.string,
460                    packet.message.unwrap(),
461                ))
462            }
463            PacketType::Disconnect => {
464                let mut disc = DisconnectPacket::<'b, 5>::new();
465                let res = disc.decode(&mut BuffReader::new(self.buffer, read));
466
467                match res {
468                    Ok(_) => Ok(Event::Disconnect(ReasonCode::from(disc.disconnect_reason))),
469                    Err(err) => {
470                        error!("[DECODE ERR]: {}", err);
471                        Err(ReasonCode::BuffError)
472                    }
473                }
474            }
475        }
476    }
477}
478
479impl<'a, T, const MAX_PROPERTIES: usize, R> RawMqttClient<'a, T, MAX_PROPERTIES, R>
480where
481    T: Read + Write + ReadReady,
482    R: RngCore,
483{
484    pub async fn poll_if_ready<'b, const MAX_TOPICS: usize>(
485        &'b mut self,
486    ) -> Result<Option<Event<'b>>, ReasonCode> {
487        if self.connection.is_none() {
488            return Err(ReasonCode::NetworkError);
489        }
490
491        let conn = self.connection.as_mut().unwrap();
492
493        // If there's no data, just return None
494        if !conn.receive_ready()? {
495            Ok(None)
496        } else {
497            self.poll::<MAX_TOPICS>().await.map(Some)
498        }
499    }
500}
501
502#[cfg(not(feature = "tls"))]
503async fn receive_packet<'c, T: Read + Write>(
504    buffer: &mut [u8],
505    buffer_len: usize,
506    recv_buffer: &mut [u8],
507    conn: &'c mut NetworkConnection<T>,
508) -> Result<usize, ReasonCode> {
509    use crate::utils::buffer_writer::RemLenError;
510
511    let target_len: usize;
512    let mut rem_len: Result<VariableByteInteger, RemLenError>;
513    let mut writer = BuffWriter::new(buffer, buffer_len);
514    let mut i = 0;
515
516    // Get len of packet
517    trace!("Reading lenght of packet");
518    loop {
519        trace!("    Reading in loop!");
520        let len: usize = conn
521            .receive(&mut recv_buffer[writer.position..(writer.position + 1)])
522            .await?;
523        trace!("    Received data!");
524        if len == 0 {
525            trace!("Zero byte len packet received, dropping connection.");
526            return Err(ReasonCode::NetworkError);
527        }
528        i += len;
529        if let Err(_e) = writer.insert_ref(len, &recv_buffer[writer.position..i]) {
530            error!("Error occurred during write to buffer!");
531            return Err(ReasonCode::BuffError);
532        }
533        if i > 1 {
534            rem_len = writer.get_rem_len();
535            if rem_len.is_ok() {
536                break;
537            }
538            if i >= 5 {
539                error!("Could not read len of packet!");
540                return Err(ReasonCode::NetworkError);
541            }
542        }
543    }
544    trace!("Lenght done!");
545    let rem_len_len = i;
546    i = 0;
547    if let Ok(l) = VariableByteIntegerDecoder::decode(rem_len.unwrap()) {
548        trace!("Reading packet with target len {}", l);
549        target_len = l as usize;
550    } else {
551        error!("Could not decode len of packet!");
552        return Err(ReasonCode::BuffError);
553    }
554
555    loop {
556        if writer.position == target_len + rem_len_len {
557            trace!("Received packet with len: {}", (target_len + rem_len_len));
558            return Ok(target_len + rem_len_len);
559        }
560        let len: usize = conn
561            .receive(&mut recv_buffer[writer.position..writer.position + (target_len - i)])
562            .await?;
563        i += len;
564        if let Err(_e) =
565            writer.insert_ref(len, &recv_buffer[writer.position..(writer.position + len)])
566        {
567            error!("Error occurred during write to buffer!");
568            return Err(ReasonCode::BuffError);
569        }
570    }
571}
572
/// Receives data for one MQTT packet from `conn` (TLS build) and copies it into `buffer`.
///
/// A single `receive` call into `recv_buffer` is performed and whatever it returned is
/// copied into `buffer` through a [`BuffWriter`]. Returns the number of bytes received.
///
/// NOTE(review): this assumes one `receive` yields exactly one complete packet; partial or
/// coalesced packets are not reassembled here — confirm the TLS layer guarantees that.
///
/// # Errors
/// * `ReasonCode::NetworkError` if the read returns zero bytes (connection closed),
///   consistent with the non-TLS variant.
/// * `ReasonCode::BuffError` if the data does not fit into `buffer`.
/// * Errors propagated from `NetworkConnection::receive`.
#[cfg(feature = "tls")]
async fn receive_packet<'c, T: Read + Write>(
    buffer: &mut [u8],
    buffer_len: usize,
    recv_buffer: &mut [u8],
    conn: &'c mut NetworkConnection<T>,
) -> Result<usize, ReasonCode> {
    trace!("Reading packet");
    let mut writer = BuffWriter::new(buffer, buffer_len);
    let len = conn.receive(recv_buffer).await?;
    if len == 0 {
        trace!("Zero byte len packet received, dropping connection.");
        return Err(ReasonCode::NetworkError);
    }
    if writer
        .insert_ref(len, &recv_buffer[writer.position..(writer.position + len)])
        .is_err()
    {
        error!("Error occurred during write to buffer!");
        return Err(ReasonCode::BuffError);
    }
    Ok(len)
}