1use embedded_io_async::{Read, Write};
2use heapless::Vec;
3use rand_core::RngCore;
4
5use crate::{
6 encoding::variable_byte_integer::{VariableByteInteger, VariableByteIntegerDecoder},
7 network::NetworkConnection,
8 packet::v5::{
9 connack_packet::ConnackPacket,
10 connect_packet::ConnectPacket,
11 disconnect_packet::DisconnectPacket,
12 mqtt_packet::Packet,
13 packet_type::PacketType,
14 pingreq_packet::PingreqPacket,
15 pingresp_packet::PingrespPacket,
16 puback_packet::PubackPacket,
17 publish_packet::{PublishPacket, QualityOfService},
18 reason_codes::ReasonCode,
19 suback_packet::SubackPacket,
20 subscription_packet::SubscriptionPacket,
21 unsuback_packet::UnsubackPacket,
22 unsubscription_packet::UnsubscriptionPacket,
23 },
24 utils::{buffer_reader::BuffReader, buffer_writer::BuffWriter, types::BufferError},
25};
26
27use super::client_config::{ClientConfig, MqttVersion};
28
29pub enum Event<'a> {
30 Connack,
31 Puback(u16),
32 Suback(u16),
33 Unsuback(u16),
34 Pingresp,
35 Message(&'a str, &'a [u8]),
36 Disconnect(ReasonCode),
37}
38
39pub struct RawMqttClient<'a, T, const MAX_PROPERTIES: usize, R: RngCore>
40where
41 T: Read + Write,
42{
43 connection: Option<NetworkConnection<T>>,
44 buffer: &'a mut [u8],
45 buffer_len: usize,
46 recv_buffer: &'a mut [u8],
47 recv_buffer_len: usize,
48 config: ClientConfig<'a, MAX_PROPERTIES, R>,
49}
50
51impl<'a, T, const MAX_PROPERTIES: usize, R> RawMqttClient<'a, T, MAX_PROPERTIES, R>
52where
53 T: Read + Write,
54 R: RngCore,
55{
56 pub fn new(
57 network_driver: T,
58 buffer: &'a mut [u8],
59 buffer_len: usize,
60 recv_buffer: &'a mut [u8],
61 recv_buffer_len: usize,
62 config: ClientConfig<'a, MAX_PROPERTIES, R>,
63 ) -> Self {
64 Self {
65 connection: Some(NetworkConnection::new(network_driver)),
66 buffer,
67 buffer_len,
68 recv_buffer,
69 recv_buffer_len,
70 config,
71 }
72 }
73
74 async fn connect_to_broker_v5<'b>(&'b mut self) -> Result<(), ReasonCode> {
75 if self.connection.is_none() {
76 return Err(ReasonCode::NetworkError);
77 }
78 let len = {
79 let mut connect = ConnectPacket::<'b, MAX_PROPERTIES, 0>::new();
80 connect.keep_alive = self.config.keep_alive;
81 self.config.add_max_packet_size_as_prop();
82 connect.property_len = connect.add_properties(&self.config.properties);
83 if self.config.username_flag {
84 connect.add_username(&self.config.username);
85 }
86 if self.config.password_flag {
87 connect.add_password(&self.config.password)
88 }
89 if self.config.will_flag {
90 connect.add_will(
91 &self.config.will_topic,
92 &self.config.will_payload,
93 self.config.will_retain,
94 )
95 }
96 connect.add_client_id(&self.config.client_id);
97 connect.encode(self.buffer, self.buffer_len)
98 };
99
100 if let Err(err) = len {
101 error!("[DECODE ERR]: {}", err);
102 return Err(ReasonCode::BuffError);
103 }
104 let conn = self.connection.as_mut().unwrap();
105 trace!("Sending connect");
106 conn.send(&self.buffer[0..len.unwrap()]).await?;
107
108 Ok(())
109 }
110
111 pub async fn connect_to_broker<'b>(&'b mut self) -> Result<(), ReasonCode> {
116 match self.config.mqtt_version {
117 MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
118 MqttVersion::MQTTv5 => self.connect_to_broker_v5().await,
119 }
120 }
121
122 async fn disconnect_v5<'b>(&'b mut self) -> Result<(), ReasonCode> {
123 if self.connection.is_none() {
124 return Err(ReasonCode::NetworkError);
125 }
126 let conn = self.connection.as_mut().unwrap();
127 trace!("Creating disconnect packet!");
128 let mut disconnect = DisconnectPacket::<'b, MAX_PROPERTIES>::new();
129 let len = disconnect.encode(self.buffer, self.buffer_len);
130 if let Err(err) = len {
131 warn!("[DECODE ERR]: {}", err);
132 let _ = self.connection.take();
133 return Err(ReasonCode::BuffError);
134 }
135
136 if let Err(_e) = conn.send(&self.buffer[0..len.unwrap()]).await {
137 warn!("Could not send DISCONNECT packet");
138 }
139
140 let _ = self.connection.take();
142 Ok(())
143 }
144
145 pub async fn disconnect<'b>(&'b mut self) -> Result<(), ReasonCode> {
150 match self.config.mqtt_version {
151 MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
152 MqttVersion::MQTTv5 => self.disconnect_v5().await,
153 }
154 }
155
156 async fn send_message_v5<'b>(
157 &'b mut self,
158 topic_name: &'b str,
159 message: &'b [u8],
160 qos: QualityOfService,
161 retain: bool,
162 ) -> Result<u16, ReasonCode> {
163 if self.connection.is_none() {
164 return Err(ReasonCode::NetworkError);
165 }
166 let conn = self.connection.as_mut().unwrap();
167 let identifier: u16 = self.config.rng.next_u32() as u16;
168 let len = {
170 let mut packet = PublishPacket::<'b, MAX_PROPERTIES>::new();
171 packet.add_topic_name(topic_name);
172 packet.add_qos(qos);
173 packet.add_identifier(identifier);
174 packet.add_message(message);
175 packet.add_retain(retain);
176 packet.encode(self.buffer, self.buffer_len)
177 };
178
179 if let Err(err) = len {
180 error!("[DECODE ERR]: {}", err);
181 return Err(ReasonCode::BuffError);
182 }
183 trace!("Sending message");
184 conn.send(&self.buffer[0..len.unwrap()]).await?;
185
186 Ok(identifier)
187 }
188 pub async fn send_message<'b>(
193 &'b mut self,
194 topic_name: &'b str,
195 message: &'b [u8],
196 qos: QualityOfService,
197 retain: bool,
198 ) -> Result<u16, ReasonCode> {
199 match self.config.mqtt_version {
200 MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
201 MqttVersion::MQTTv5 => self.send_message_v5(topic_name, message, qos, retain).await,
202 }
203 }
204
205 async fn subscribe_to_topics_v5<'b, const TOPICS: usize>(
206 &'b mut self,
207 topic_names: &'b Vec<&'b str, TOPICS>,
208 ) -> Result<u16, ReasonCode> {
209 if self.connection.is_none() {
210 return Err(ReasonCode::NetworkError);
211 }
212 let conn = self.connection.as_mut().unwrap();
213 let identifier: u16 = self.config.rng.next_u32() as u16;
214 let len = {
215 let mut subs = SubscriptionPacket::<'b, TOPICS, MAX_PROPERTIES>::new();
216 subs.packet_identifier = identifier;
217 for topic_name in topic_names.iter() {
218 subs.add_new_filter(topic_name, self.config.max_subscribe_qos);
219 }
220 subs.encode(self.buffer, self.buffer_len)
221 };
222
223 if let Err(err) = len {
224 error!("[DECODE ERR]: {}", err);
225 return Err(ReasonCode::BuffError);
226 }
227
228 conn.send(&self.buffer[0..len.unwrap()]).await?;
229
230 Ok(identifier)
231 }
232
233 pub async fn subscribe_to_topics<'b, const TOPICS: usize>(
238 &'b mut self,
239 topic_names: &'b Vec<&'b str, TOPICS>,
240 ) -> Result<u16, ReasonCode> {
241 match self.config.mqtt_version {
242 MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
243 MqttVersion::MQTTv5 => self.subscribe_to_topics_v5(topic_names).await,
244 }
245 }
246
247 pub async fn unsubscribe_from_topic<'b>(
251 &'b mut self,
252 topic_name: &'b str,
253 ) -> Result<u16, ReasonCode> {
254 match self.config.mqtt_version {
255 MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
256 MqttVersion::MQTTv5 => self.unsubscribe_from_topic_v5(topic_name).await,
257 }
258 }
259
260 async fn unsubscribe_from_topic_v5<'b>(
261 &'b mut self,
262 topic_name: &'b str,
263 ) -> Result<u16, ReasonCode> {
264 if self.connection.is_none() {
265 return Err(ReasonCode::NetworkError);
266 }
267 let conn = self.connection.as_mut().unwrap();
268 let identifier = self.config.rng.next_u32() as u16;
269
270 let len = {
271 let mut unsub = UnsubscriptionPacket::<'b, 1, MAX_PROPERTIES>::new();
272 unsub.packet_identifier = identifier;
273 unsub.add_new_filter(topic_name);
274 unsub.encode(self.buffer, self.buffer_len)
275 };
276
277 if let Err(err) = len {
278 error!("[DECODE ERR]: {}", err);
279 return Err(ReasonCode::BuffError);
280 }
281 conn.send(&self.buffer[0..len.unwrap()]).await?;
282
283 Ok(identifier)
284 }
285
286 async fn send_ping_v5<'b>(&'b mut self) -> Result<(), ReasonCode> {
287 if self.connection.is_none() {
288 return Err(ReasonCode::NetworkError);
289 }
290 let conn = self.connection.as_mut().unwrap();
291 let len = {
292 let mut packet = PingreqPacket::new();
293 packet.encode(self.buffer, self.buffer_len)
294 };
295
296 if let Err(err) = len {
297 error!("[DECODE ERR]: {}", err);
298 return Err(ReasonCode::BuffError);
299 }
300
301 conn.send(&self.buffer[0..len.unwrap()]).await?;
302
303 Ok(())
304 }
305
306 pub async fn send_ping<'b>(&'b mut self) -> Result<(), ReasonCode> {
310 match self.config.mqtt_version {
311 MqttVersion::MQTTv3 => Err(ReasonCode::UnsupportedProtocolVersion),
312 MqttVersion::MQTTv5 => self.send_ping_v5().await,
313 }
314 }
315
316 pub async fn poll<'b, const MAX_TOPICS: usize>(&'b mut self) -> Result<Event<'b>, ReasonCode> {
317 if self.connection.is_none() {
318 return Err(ReasonCode::NetworkError);
319 }
320
321 let conn = self.connection.as_mut().unwrap();
322
323 trace!("Waiting for a packet");
324
325 let read = { receive_packet(self.buffer, self.buffer_len, self.recv_buffer, conn).await? };
326
327 let buf_reader = BuffReader::new(self.buffer, read);
328
329 match PacketType::from(buf_reader.peek_u8().map_err(|_| ReasonCode::BuffError)?) {
330 PacketType::Reserved
331 | PacketType::Connect
332 | PacketType::Subscribe
333 | PacketType::Unsubscribe
334 | PacketType::Pingreq => Err(ReasonCode::ProtocolError),
335 PacketType::Pubrec | PacketType::Pubrel | PacketType::Pubcomp | PacketType::Auth => {
336 Err(ReasonCode::ImplementationSpecificError)
337 }
338 PacketType::Connack => {
339 let mut packet = ConnackPacket::<'b, MAX_PROPERTIES>::new();
340 if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) {
341 error!("[DECODE ERR]: {}", err);
349 Err(ReasonCode::BuffError)
350 } else if packet.connect_reason_code != 0x00 {
351 Err(ReasonCode::from(packet.connect_reason_code))
352 } else {
353 Ok(Event::Connack)
354 }
355 }
356 PacketType::Puback => {
357 let reason: Result<[u16; 2], BufferError> = {
358 let mut packet = PubackPacket::<'b, MAX_PROPERTIES>::new();
359 packet
360 .decode(&mut BuffReader::new(self.buffer, read))
361 .map(|_| [packet.packet_identifier, packet.reason_code as u16])
362 };
363
364 if let Err(err) = reason {
365 error!("[DECODE ERR]: {}", err);
366 return Err(ReasonCode::BuffError);
367 }
368
369 let res = reason.unwrap();
370
371 if res[1] != 0 {
372 return Err(ReasonCode::from(res[1] as u8));
373 }
374
375 Ok(Event::Puback(res[0]))
376 }
377 PacketType::Suback => {
378 let reason: Result<(u16, Vec<u8, MAX_TOPICS>), BufferError> = {
379 let mut packet = SubackPacket::<'b, MAX_TOPICS, MAX_PROPERTIES>::new();
380 packet
381 .decode(&mut BuffReader::new(self.buffer, read))
382 .map(|_| (packet.packet_identifier, packet.reason_codes))
383 };
384
385 if let Err(err) = reason {
386 error!("[DECODE ERR]: {}", err);
387 return Err(ReasonCode::BuffError);
388 }
389 let (packet_identifier, reasons) = reason.unwrap();
390 for reason_code in &reasons {
391 if *reason_code
392 != (<QualityOfService as Into<u8>>::into(self.config.max_subscribe_qos)
393 >> 1)
394 {
395 return Err(ReasonCode::from(*reason_code));
396 }
397 }
398 Ok(Event::Suback(packet_identifier))
399 }
400 PacketType::Unsuback => {
401 let res: Result<u16, BufferError> = {
402 let mut packet = UnsubackPacket::<'b, 1, MAX_PROPERTIES>::new();
403 packet
404 .decode(&mut BuffReader::new(self.buffer, read))
405 .map(|_| packet.packet_identifier)
406 };
407
408 if let Err(err) = res {
409 error!("[DECODE ERR]: {}", err);
410 Err(ReasonCode::BuffError)
411 } else {
412 Ok(Event::Unsuback(res.unwrap()))
413 }
414 }
415 PacketType::Pingresp => {
416 let mut packet = PingrespPacket::new();
417 if let Err(err) = packet.decode(&mut BuffReader::new(self.buffer, read)) {
418 error!("[DECODE ERR]: {}", err);
419 Err(ReasonCode::BuffError)
420 } else {
421 Ok(Event::Pingresp)
422 }
423 }
424 PacketType::Publish => {
425 let mut packet = PublishPacket::<'b, 5>::new();
426 if let Err(err) = { packet.decode(&mut BuffReader::new(self.buffer, read)) } {
427 error!("[DECODE ERR]: {}", err);
435 return Err(ReasonCode::BuffError);
436 }
437
438 if (packet.fixed_header & 0x06)
439 == <QualityOfService as Into<u8>>::into(QualityOfService::QoS1)
440 {
441 let mut puback = PubackPacket::<'b, MAX_PROPERTIES>::new();
442 puback.packet_identifier = packet.packet_identifier;
443 puback.reason_code = 0x00;
444 {
445 let len = { puback.encode(self.recv_buffer, self.recv_buffer_len) };
446 if let Err(err) = len {
447 error!("[DECODE ERR]: {}", err);
448 return Err(ReasonCode::BuffError);
449 }
450 conn.send(&self.recv_buffer[0..len.unwrap()]).await?;
451 }
452 }
453
454 Ok(Event::Message(
455 packet.topic_name.string,
456 packet.message.unwrap(),
457 ))
458 }
459 PacketType::Disconnect => {
460 let mut disc = DisconnectPacket::<'b, 5>::new();
461 let res = disc.decode(&mut BuffReader::new(self.buffer, read));
462
463 match res {
464 Ok(_) => Ok(Event::Disconnect(ReasonCode::from(disc.disconnect_reason))),
465 Err(err) => {
466 error!("[DECODE ERR]: {}", err);
467 Err(ReasonCode::BuffError)
468 }
469 }
470 }
471 }
472 }
473}
474
475#[cfg(not(feature = "tls"))]
476async fn receive_packet<'c, T: Read + Write>(
477 buffer: &mut [u8],
478 buffer_len: usize,
479 recv_buffer: &mut [u8],
480 conn: &'c mut NetworkConnection<T>,
481) -> Result<usize, ReasonCode> {
482 use crate::utils::buffer_writer::RemLenError;
483
484 let target_len: usize;
485 let mut rem_len: Result<VariableByteInteger, RemLenError>;
486 let mut writer = BuffWriter::new(buffer, buffer_len);
487 let mut i = 0;
488
489 trace!("Reading lenght of packet");
491 loop {
492 trace!(" Reading in loop!");
493 let len: usize = conn
494 .receive(&mut recv_buffer[writer.position..(writer.position + 1)])
495 .await?;
496 trace!(" Received data!");
497 if len == 0 {
498 trace!("Zero byte len packet received, dropping connection.");
499 return Err(ReasonCode::NetworkError);
500 }
501 i += len;
502 if let Err(_e) = writer.insert_ref(len, &recv_buffer[writer.position..i]) {
503 error!("Error occurred during write to buffer!");
504 return Err(ReasonCode::BuffError);
505 }
506 if i > 1 {
507 rem_len = writer.get_rem_len();
508 if rem_len.is_ok() {
509 break;
510 }
511 if i >= 5 {
512 error!("Could not read len of packet!");
513 return Err(ReasonCode::NetworkError);
514 }
515 }
516 }
517 trace!("Lenght done!");
518 let rem_len_len = i;
519 i = 0;
520 if let Ok(l) = VariableByteIntegerDecoder::decode(rem_len.unwrap()) {
521 trace!("Reading packet with target len {}", l);
522 target_len = l as usize;
523 } else {
524 error!("Could not decode len of packet!");
525 return Err(ReasonCode::BuffError);
526 }
527
528 loop {
529 if writer.position == target_len + rem_len_len {
530 trace!("Received packet with len: {}", (target_len + rem_len_len));
531 return Ok(target_len + rem_len_len);
532 }
533 let len: usize = conn
534 .receive(&mut recv_buffer[writer.position..writer.position + (target_len - i)])
535 .await?;
536 i += len;
537 if let Err(_e) =
538 writer.insert_ref(len, &recv_buffer[writer.position..(writer.position + i)])
539 {
540 error!("Error occurred during write to buffer!");
541 return Err(ReasonCode::BuffError);
542 }
543 }
544}
545
/// Reads one chunk of data from `conn` into `recv_buffer` and copies it into `buffer`.
///
/// A single receive call is issued; its byte count is returned as the
/// packet length.
///
/// # Errors
/// * `ReasonCode::NetworkError` when the read returns zero bytes.
/// * `ReasonCode::BuffError` when copying into `buffer` fails.
/// * Whatever error the connection reports while receiving.
#[cfg(feature = "tls")]
async fn receive_packet<'c, T: Read + Write>(
    buffer: &mut [u8],
    buffer_len: usize,
    recv_buffer: &mut [u8],
    conn: &'c mut NetworkConnection<T>,
) -> Result<usize, ReasonCode> {
    trace!("Reading packet");
    let mut writer = BuffWriter::new(buffer, buffer_len);
    let len = conn.receive(recv_buffer).await?;
    // A zero-length read carries no packet; report it the same way the
    // non-TLS receive path does instead of returning an empty packet.
    if len == 0 {
        trace!("Zero byte len packet received, dropping connection.");
        return Err(ReasonCode::NetworkError);
    }
    let start = writer.position;
    if let Err(_e) = writer.insert_ref(len, &recv_buffer[start..(start + len)]) {
        error!("Error occurred during write to buffer!");
        return Err(ReasonCode::BuffError);
    }
    Ok(len)
}