// rust_mqtt/client/raw_client.rs

1use embedded_io_async::{Read, Write};
2use heapless::Vec;
3use rand_core::RngCore;
4
5use crate::{
6    encoding::variable_byte_integer::{VariableByteInteger, VariableByteIntegerDecoder},
7    network::NetworkConnection,
8    packet::v5::{
9        connack_packet::ConnackPacket,
10        connect_packet::ConnectPacket,
11        disconnect_packet::DisconnectPacket,
12        mqtt_packet::Packet,
13        packet_type::PacketType,
14        pingreq_packet::PingreqPacket,
15        pingresp_packet::PingrespPacket,
16        puback_packet::PubackPacket,
17        publish_packet::{PublishPacket, QualityOfService},
18        reason_codes::ReasonCode,
19        suback_packet::SubackPacket,
20        subscription_packet::SubscriptionPacket,
21        unsuback_packet::UnsubackPacket,
22        unsubscription_packet::UnsubscriptionPacket,
23    },
24    utils::{buffer_reader::BuffReader, buffer_writer::BuffWriter, types::BufferError},
25};
26
27use super::client_config::{ClientConfig, MqttVersion};
28
29pub enum Event<'a> {
30    Connack,
31    Puback(u16),
32    Suback(u16),
33    Unsuback(u16),
34    Pingresp,
35    Message(&'a str, &'a [u8]),
36    Disconnect(ReasonCode),
37}
38
39pub struct RawMqttClient<'a, T, const MAX_PROPERTIES: usize, R: RngCore>
40where
41    T: Read + Write,
42{
43    connection: Option<NetworkConnection<T>>,
44    buffer: &'a mut [u8],
45    buffer_len: usize,
46    recv_buffer: &'a mut [u8],
47    recv_buffer_len: usize,
48    config: ClientConfig<'a, MAX_PROPERTIES, R>,
49}
50
51impl<'a, T, const MAX_PROPERTIES: usize, R> RawMqttClient<'a, T, MAX_PROPERTIES, R>
52where
53    T: Read + Write,
54    R: RngCore,
55{
56    pub fn new(
57        network_driver: T,
58        buffer: &'a mut [u8],
59        buffer_len: usize,
60        recv_buffer: &'a mut [u8],
61        recv_buffer_len: usize,
62        config: ClientConfig<'a, MAX_PROPERTIES, R>,
63    ) -> Self {
64        Self {
65            connection: Some(NetworkConnection::new(network_driver)),
66            buffer,
67            buffer_len,
68            recv_buffer,
69            recv_buffer_len,
70            config,
71        }
72    }
73
74    async fn connect_to_broker_v5<'b>(&'b mut self) -> Result<(), ReasonCode> {
75        if self.connection.is_none() {
76            return Err(ReasonCode::NetworkError);
77        }
78        let len = {
79            let mut connect = ConnectPacket::<'b, MAX_PROPERTIES, 0>::new();
80            connect.keep_alive = self.config.keep_alive;
81            self.config.add_max_packet_size_as_prop();
82            connect.property_len = connect.add_properties(&self.config.properties);
83            if self.config.username_flag {
84                connect.add_username(&self.config.username);
85            }
86            if self.config.password_flag {
87                connect.add_password(&self.config.password)
88            }
89            if self.config.will_flag {
90                connect.add_will(
91                    &self.config.will_topic,
92                    &self.config.will_payload,
93                    self.config.will_retain,
94                )
95            }
96            connect.add_client_id(&self.config.client_id);
97            connect.encode(self.buffer, self.buffer_len)
98        };
99
100        if let Err(err) = len {
101            error!("[DECODE ERR]: {}", err);
102            return Err(ReasonCode::BuffError);
103        }
104        let conn = self.connection.as_mut().unwrap();
105        trace!("Sending connect");
106        conn.send(&self.buffer[0..len.unwrap()]).await?;
107
108        Ok(())
109    }
110
111    /// Method allows client connect to server. Client is connecting to the specified broker
112    /// in the `ClientConfig`. Method selects proper implementation of the MQTT version based on the config.
113    /// If the connection to the broker fails, method returns Err variable that contains
114    /// Reason codes returned from the broker.
115    pub async fn connect_to_broker<'b>(&'b mut self) -> Result<(), ReasonCode> {
116        match self.config.mqtt_version {
117            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
118            MqttVersion::MQTTv5 => self.connect_to_broker_v5().await,
119        }
120    }
121
122    async fn disconnect_v5<'b>(&'b mut self) -> Result<(), ReasonCode> {
123        if self.connection.is_none() {
124            return Err(ReasonCode::NetworkError);
125        }
126        let conn = self.connection.as_mut().unwrap();
127        trace!("Creating disconnect packet!");
128        let mut disconnect = DisconnectPacket::<'b, MAX_PROPERTIES>::new();
129        let len = disconnect.encode(self.buffer, self.buffer_len);
130        if let Err(err) = len {
131            warn!("[DECODE ERR]: {}", err);
132            let _ = self.connection.take();
133            return Err(ReasonCode::BuffError);
134        }
135
136        if let Err(_e) = conn.send(&self.buffer[0..len.unwrap()]).await {
137            warn!("Could not send DISCONNECT packet");
138        }
139
140        // Drop connection
141        let _ = self.connection.take();
142        Ok(())
143    }
144
145    /// Method allows client disconnect from the server. Client disconnects from the specified broker
146    /// in the `ClientConfig`. Method selects proper implementation of the MQTT version based on the config.
147    /// If the disconnect from the broker fails, method returns Err variable that contains
148    /// Reason codes returned from the broker.
149    pub async fn disconnect<'b>(&'b mut self) -> Result<(), ReasonCode> {
150        match self.config.mqtt_version {
151            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
152            MqttVersion::MQTTv5 => self.disconnect_v5().await,
153        }
154    }
155
156    async fn send_message_v5<'b>(
157        &'b mut self,
158        topic_name: &'b str,
159        message: &'b [u8],
160        qos: QualityOfService,
161        retain: bool,
162    ) -> Result<u16, ReasonCode> {
163        if self.connection.is_none() {
164            return Err(ReasonCode::NetworkError);
165        }
166        let conn = self.connection.as_mut().unwrap();
167        let identifier: u16 = self.config.rng.next_u32() as u16;
168        //self.rng.next_u32() as u16;
169        let len = {
170            let mut packet = PublishPacket::<'b, MAX_PROPERTIES>::new();
171            packet.add_topic_name(topic_name);
172            packet.add_qos(qos);
173            packet.add_identifier(identifier);
174            packet.add_message(message);
175            packet.add_retain(retain);
176            packet.encode(self.buffer, self.buffer_len)
177        };
178
179        if let Err(err) = len {
180            error!("[DECODE ERR]: {}", err);
181            return Err(ReasonCode::BuffError);
182        }
183        trace!("Sending message");
184        conn.send(&self.buffer[0..len.unwrap()]).await?;
185
186        Ok(identifier)
187    }
188    /// Method allows sending message to broker specified from the ClientConfig. Client sends the
189    /// message from the parameter `message` to the topic `topic_name` on the broker
190    /// specified in the ClientConfig. If the send fails method returns Err with reason code
191    /// received by broker.
192    pub async fn send_message<'b>(
193        &'b mut self,
194        topic_name: &'b str,
195        message: &'b [u8],
196        qos: QualityOfService,
197        retain: bool,
198    ) -> Result<u16, ReasonCode> {
199        match self.config.mqtt_version {
200            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
201            MqttVersion::MQTTv5 => self.send_message_v5(topic_name, message, qos, retain).await,
202        }
203    }
204
205    async fn subscribe_to_topics_v5<'b, const TOPICS: usize>(
206        &'b mut self,
207        topic_names: &'b Vec<&'b str, TOPICS>,
208    ) -> Result<u16, ReasonCode> {
209        if self.connection.is_none() {
210            return Err(ReasonCode::NetworkError);
211        }
212        let conn = self.connection.as_mut().unwrap();
213        let identifier: u16 = self.config.rng.next_u32() as u16;
214        let len = {
215            let mut subs = SubscriptionPacket::<'b, TOPICS, MAX_PROPERTIES>::new();
216            subs.packet_identifier = identifier;
217            for topic_name in topic_names.iter() {
218                subs.add_new_filter(topic_name, self.config.max_subscribe_qos);
219            }
220            subs.encode(self.buffer, self.buffer_len)
221        };
222
223        if let Err(err) = len {
224            error!("[DECODE ERR]: {}", err);
225            return Err(ReasonCode::BuffError);
226        }
227
228        conn.send(&self.buffer[0..len.unwrap()]).await?;
229
230        Ok(identifier)
231    }
232
233    /// Method allows client subscribe to multiple topics specified in the parameter
234    /// `topic_names` on the broker specified in the `ClientConfig`. Generics `TOPICS`
235    /// sets the value of the `topics_names` vector. MQTT protocol implementation
236    /// is selected automatically.
237    pub async fn subscribe_to_topics<'b, const TOPICS: usize>(
238        &'b mut self,
239        topic_names: &'b Vec<&'b str, TOPICS>,
240    ) -> Result<u16, ReasonCode> {
241        match self.config.mqtt_version {
242            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
243            MqttVersion::MQTTv5 => self.subscribe_to_topics_v5(topic_names).await,
244        }
245    }
246
247    /// Method allows client unsubscribe from the topic specified in the parameter
248    /// `topic_name` on the broker from the `ClientConfig`. MQTT protocol implementation
249    /// is selected automatically.
250    pub async fn unsubscribe_from_topic<'b>(
251        &'b mut self,
252        topic_name: &'b str,
253    ) -> Result<u16, ReasonCode> {
254        match self.config.mqtt_version {
255            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
256            MqttVersion::MQTTv5 => self.unsubscribe_from_topic_v5(topic_name).await,
257        }
258    }
259
260    async fn unsubscribe_from_topic_v5<'b>(
261        &'b mut self,
262        topic_name: &'b str,
263    ) -> Result<u16, ReasonCode> {
264        if self.connection.is_none() {
265            return Err(ReasonCode::NetworkError);
266        }
267        let conn = self.connection.as_mut().unwrap();
268        let identifier = self.config.rng.next_u32() as u16;
269
270        let len = {
271            let mut unsub = UnsubscriptionPacket::<'b, 1, MAX_PROPERTIES>::new();
272            unsub.packet_identifier = identifier;
273            unsub.add_new_filter(topic_name);
274            unsub.encode(self.buffer, self.buffer_len)
275        };
276
277        if let Err(err) = len {
278            error!("[DECODE ERR]: {}", err);
279            return Err(ReasonCode::BuffError);
280        }
281        conn.send(&self.buffer[0..len.unwrap()]).await?;
282
283        Ok(identifier)
284    }
285
286    async fn send_ping_v5<'b>(&'b mut self) -> Result<(), ReasonCode> {
287        if self.connection.is_none() {
288            return Err(ReasonCode::NetworkError);
289        }
290        let conn = self.connection.as_mut().unwrap();
291        let len = {
292            let mut packet = PingreqPacket::new();
293            packet.encode(self.buffer, self.buffer_len)
294        };
295
296        if let Err(err) = len {
297            error!("[DECODE ERR]: {}", err);
298            return Err(ReasonCode::BuffError);
299        }
300
301        conn.send(&self.buffer[0..len.unwrap()]).await?;
302
303        Ok(())
304    }
305
306    /// Method allows client send PING message to the broker specified in the `ClientConfig`.
307    /// If there is expectation for long running connection. Method should be executed
308    /// regularly by the timer that counts down the session expiry interval.
309    pub async fn send_ping<'b>(&'b mut self) -> Result<(), ReasonCode> {
310        match self.config.mqtt_version {
311            MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
312            MqttVersion::MQTTv5 => self.send_ping_v5().await,
313        }
314    }
315
316    pub async fn poll<'b, const MAX_TOPICS: usize>(&'b mut self) -> Result<Event<'b>, ReasonCode> {
317        if self.connection.is_none() {
318            return Err(ReasonCode::NetworkError);
319        }
320
321        let conn = self.connection.as_mut().unwrap();
322
323        trace!("Waiting for a packet");
324
325        let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? };
326
327        let buf_reader = BuffReader::new(self.buffer, read);
328
329        match PacketType::from(buf_reader.peek_u8().map_err(|_| ReasonCode::BuffError)?) {
330            PacketType::Reserved
331            | PacketType::Connect
332            | PacketType::Subscribe
333            | PacketType::Unsubscribe
334            | PacketType::Pingreq => Err(ReasonCode::ProtocolError),
335            PacketType::Pubrec | PacketType::Pubrel | PacketType::Pubcomp | PacketType::Auth => {
336                Err(ReasonCode::ImplementationSpecificError)
337            }
338            PacketType::Connack => {
339                let mut packet = ConnackPacket::<'b, MAX_PROPERTIES>::new();
340                if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) {
341                    // if err == BufferError::PacketTypeMismatch {
342                    //     let mut disc = DisconnectPacket::<'b, MAX_PROPERTIES>::new();
343                    //     if disc.decode(&mut BuffReader::new(self.buffer, read)).is_ok() {
344                    //         error!("Client was disconnected with reason: ");
345                    //         return Err(ReasonCode::from(disc.disconnect_reason));
346                    //     }
347                    // }
348                    error!("[DECODE ERR]: {}", err);
349                    Err(ReasonCode::BuffError)
350                } else if packet.connect_reason_code != 0x00 {
351                    Err(ReasonCode::from(packet.connect_reason_code))
352                } else {
353                    Ok(Event::Connack)
354                }
355            }
356            PacketType::Puback => {
357                let reason: Result<[u16; 2], BufferError> = {
358                    let mut packet = PubackPacket::<'b, MAX_PROPERTIES>::new();
359                    packet
360                        .decode(&mut BuffReader::new(self.buffer, read))
361                        .map(|_| [packet.packet_identifier, packet.reason_code as u16])
362                };
363
364                if let Err(err) = reason {
365                    error!("[DECODE ERR]: {}", err);
366                    return Err(ReasonCode::BuffError);
367                }
368
369                let res = reason.unwrap();
370
371                if res[1] != 0 {
372                    return Err(ReasonCode::from(res[1] as u8));
373                }
374
375                Ok(Event::Puback(res[0]))
376            }
377            PacketType::Suback => {
378                let reason: Result<(u16, Vec<u8, MAX_TOPICS>), BufferError> = {
379                    let mut packet = SubackPacket::<'b, MAX_TOPICS, MAX_PROPERTIES>::new();
380                    packet
381                        .decode(&mut BuffReader::new(self.buffer, read))
382                        .map(|_| (packet.packet_identifier, packet.reason_codes))
383                };
384
385                if let Err(err) = reason {
386                    error!("[DECODE ERR]: {}", err);
387                    return Err(ReasonCode::BuffError);
388                }
389                let (packet_identifier, reasons) = reason.unwrap();
390                for reason_code in &reasons {
391                    if *reason_code
392                        != (<QualityOfService as Into<u8>>::into(self.config.max_subscribe_qos)
393                            >> 1)
394                    {
395                        return Err(ReasonCode::from(*reason_code));
396                    }
397                }
398                Ok(Event::Suback(packet_identifier))
399            }
400            PacketType::Unsuback => {
401                let res: Result<u16, BufferError> = {
402                    let mut packet = UnsubackPacket::<'b, 1, MAX_PROPERTIES>::new();
403                    packet
404                        .decode(&mut BuffReader::new(self.buffer, read))
405                        .map(|_| packet.packet_identifier)
406                };
407
408                if let Err(err) = res {
409                    error!("[DECODE ERR]: {}", err);
410                    Err(ReasonCode::BuffError)
411                } else {
412                    Ok(Event::Unsuback(res.unwrap()))
413                }
414            }
415            PacketType::Pingresp => {
416                let mut packet = PingrespPacket::new();
417                if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) {
418                    error!("[DECODE ERR]: {}", err);
419                    Err(ReasonCode::BuffError)
420                } else {
421                    Ok(Event::Pingresp)
422                }
423            }
424            PacketType::Publish => {
425                let mut packet = PublishPacket::<'b, 5>::new();
426                if let Err(err) = { packet.decode(&mut BuffReader::new(self.buffer, read)) } {
427                    // if err == BufferError::PacketTypeMismatch {
428                    //     let mut disc = DisconnectPacket::<'b, 5>::new();
429                    //     if disc.decode(&mut BuffReader::new(self.buffer, read)).is_ok() {
430                    //         error!("Client was disconnected with reason: ");
431                    //         return Err(ReasonCode::from(disc.disconnect_reason));
432                    //     }
433                    // }
434                    error!("[DECODE ERR]: {}", err);
435                    return Err(ReasonCode::BuffError);
436                }
437
438                if (packet.fixed_header & 0x06)
439                    == <QualityOfService as Into<u8>>::into(QualityOfService::QoS1)
440                {
441                    let mut puback = PubackPacket::<'b, MAX_PROPERTIES>::new();
442                    puback.packet_identifier = packet.packet_identifier;
443                    puback.reason_code = 0x00;
444                    {
445                        let len = { puback.encode(self.recv_buffer, self.recv_buffer_len) };
446                        if let Err(err) = len {
447                            error!("[DECODE ERR]: {}", err);
448                            return Err(ReasonCode::BuffError);
449                        }
450                        conn.send(&self.recv_buffer[0..len.unwrap()]).await?;
451                    }
452                }
453
454                Ok(Event::Message(
455                    packet.topic_name.string,
456                    packet.message.unwrap(),
457                ))
458            }
459            PacketType::Disconnect => {
460                let mut disc = DisconnectPacket::<'b, 5>::new();
461                let res = disc.decode(&mut BuffReader::new(self.buffer, read));
462
463                match res {
464                    Ok(_) => Ok(Event::Disconnect(ReasonCode::from(disc.disconnect_reason))),
465                    Err(err) => {
466                        error!("[DECODE ERR]: {}", err);
467                        Err(ReasonCode::BuffError)
468                    }
469                }
470            }
471        }
472    }
473}
474
475#[cfg(not(feature = "tls"))]
476async fn receive_packet<'c, T: Read + Write>(
477    buffer: &mut [u8],
478    buffer_len: usize,
479    recv_buffer: &mut [u8],
480    conn: &'c mut NetworkConnection<T>,
481) -> Result<usize, ReasonCode> {
482    use crate::utils::buffer_writer::RemLenError;
483
484    let target_len: usize;
485    let mut rem_len: Result<VariableByteInteger, RemLenError>;
486    let mut writer = BuffWriter::new(buffer, buffer_len);
487    let mut i = 0;
488
489    // Get len of packet
490    trace!("Reading lenght of packet");
491    loop {
492        trace!("    Reading in loop!");
493        let len: usize = conn
494            .receive(&mut recv_buffer[writer.position..(writer.position + 1)])
495            .await?;
496        trace!("    Received data!");
497        if len == 0 {
498            trace!("Zero byte len packet received, dropping connection.");
499            return Err(ReasonCode::NetworkError);
500        }
501        i += len;
502        if let Err(_e) = writer.insert_ref(len, &recv_buffer[writer.position..i]) {
503            error!("Error occurred during write to buffer!");
504            return Err(ReasonCode::BuffError);
505        }
506        if i > 1 {
507            rem_len = writer.get_rem_len();
508            if rem_len.is_ok() {
509                break;
510            }
511            if i >= 5 {
512                error!("Could not read len of packet!");
513                return Err(ReasonCode::NetworkError);
514            }
515        }
516    }
517    trace!("Lenght done!");
518    let rem_len_len = i;
519    i = 0;
520    if let Ok(l) = VariableByteIntegerDecoder::decode(rem_len.unwrap()) {
521        trace!("Reading packet with target len {}", l);
522        target_len = l as usize;
523    } else {
524        error!("Could not decode len of packet!");
525        return Err(ReasonCode::BuffError);
526    }
527
528    loop {
529        if writer.position == target_len + rem_len_len {
530            trace!("Received packet with len: {}", (target_len + rem_len_len));
531            return Ok(target_len + rem_len_len);
532        }
533        let len: usize = conn
534            .receive(&mut recv_buffer[writer.position..writer.position + (target_len - i)])
535            .await?;
536        i += len;
537        if let Err(_e) =
538            writer.insert_ref(len, &recv_buffer[writer.position..(writer.position + i)])
539        {
540            error!("Error occurred during write to buffer!");
541            return Err(ReasonCode::BuffError);
542        }
543    }
544}
545
/// Reads one chunk of data from `conn` (TLS build) and copies it into `buffer`.
///
/// A single `receive` call fills `recv_buffer`; the bytes obtained are then
/// copied into `buffer` through a `BuffWriter` limited to `buffer_len`.
/// NOTE(review): this assumes one read yields exactly one complete packet —
/// confirm against the TLS transport in use.
///
/// Returns the number of bytes received.
///
/// # Errors
/// * `ReasonCode::NetworkError` if the read returns 0 bytes.
/// * `ReasonCode::BuffError` if the data does not fit into `buffer`.
/// * Any error reported by the connection while receiving.
#[cfg(feature = "tls")]
async fn receive_packet<'c, T: Read + Write>(
    buffer: &mut [u8],
    buffer_len: usize,
    recv_buffer: &mut [u8],
    conn: &'c mut NetworkConnection<T>,
) -> Result<usize, ReasonCode> {
    trace!("Reading packet");
    let mut writer = BuffWriter::new(buffer, buffer_len);
    let len = conn.receive(recv_buffer).await?;
    if len == 0 {
        // Same handling as the non-TLS variant: an empty read is a network error.
        trace!("Zero byte len packet received, dropping connection.");
        return Err(ReasonCode::NetworkError);
    }
    let start = writer.position;
    if let Err(_e) = writer.insert_ref(len, &recv_buffer[start..start + len]) {
        error!("Error occurred during write to buffer!");
        return Err(ReasonCode::BuffError);
    }
    Ok(len)
}