spotflow_rumqttc_fork/eventloop.rs

use crate::{framed::Network, Transport};
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
use crate::tls;
use crate::{Incoming, MqttState, Packet, Request, StateError};
use crate::{MqttOptions, Outgoing};

use crate::mqttbytes::v4::*;
use async_channel::{bounded, Receiver, Sender};
#[cfg(feature = "websocket")]
use async_tungstenite::tokio::connect_async;
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
use async_tungstenite::tokio::connect_async_with_tls_connector;
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::select;
use tokio::time::{self, error::Elapsed, Instant, Sleep};
#[cfg(feature = "websocket")]
use ws_stream_tungstenite::WsStream;

use std::io;
#[cfg(unix)]
use std::path::Path;
use std::pin::Pin;
use std::time::Duration;
use std::vec::IntoIter;

/// Critical errors during eventloop polling
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    #[error("Mqtt state: {0}")]
    MqttState(#[from] StateError),
    #[error("Timeout")]
    Timeout(#[from] Elapsed),
    #[cfg(feature = "websocket")]
    #[error("Websocket: {0}")]
    Websocket(#[from] async_tungstenite::tungstenite::error::Error),
    #[cfg(feature = "websocket")]
    #[error("Websocket Connect: {0}")]
    WsConnect(#[from] http::Error),
    #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
    #[error("TLS: {0}")]
    Tls(#[from] tls::Error),
    #[error("I/O: {0}")]
    Io(#[from] io::Error),
    #[error("Connection refused, return code: {0:?}")]
    ConnectionRefused(ConnectReturnCode),
    #[error("Expected ConnAck packet, received: {0:?}")]
    NotConnAck(Packet),
    #[error("Requests done")]
    RequestsDone,
    #[error("Cancel request by the user")]
    Cancel,
}

/// Eventloop with all the state of a connection
pub struct EventLoop {
    /// Options of the current mqtt connection
    pub options: MqttOptions,
    /// Current state of the connection
    pub state: MqttState,
    /// Request stream
    pub requests_rx: Receiver<Request>,
    /// Requests handle to send requests
    pub requests_tx: Sender<Request>,
    /// Pending packets from last session
    pub pending: IntoIter<Request>,
    /// Network connection to the broker
    pub(crate) network: Option<Network>,
    /// Keep alive time
    pub(crate) keepalive_timeout: Option<Pin<Box<Sleep>>>,
    /// Handle to read cancellation requests
    pub(crate) cancel_rx: Receiver<()>,
    /// Handle to send cancellation requests (and drops)
    pub(crate) cancel_tx: Sender<()>,
}

/// Events which can be yielded by the event loop
#[derive(Debug, PartialEq, Clone)]
pub enum Event {
    Incoming(Incoming),
    Outgoing(Outgoing),
}

impl EventLoop {
    /// Creates a new MQTT `EventLoop`.
    ///
    /// When the connection encounters critical errors (like an auth failure), the user can
    /// access and update `options`, `state`, and `requests`.
    pub fn new(options: MqttOptions, cap: usize) -> EventLoop {
        let (cancel_tx, cancel_rx) = bounded(5);
        let (requests_tx, requests_rx) = bounded(cap);
        let pending = Vec::new();
        let pending = pending.into_iter();
        let max_inflight = options.inflight;
        let manual_acks = options.manual_acks;

        EventLoop {
            options,
            state: MqttState::new(max_inflight, manual_acks),
            requests_tx,
            requests_rx,
            pending,
            network: None,
            keepalive_timeout: None,
            cancel_rx,
            cancel_tx,
        }
    }

    /// Returns a handle to communicate with this eventloop
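    ///
    /// A minimal sketch of pushing a request through the handle (marked `ignore`
    /// since it is not a compilable doctest; `eventloop` is assumed to be an
    /// existing `EventLoop`, and `Request::PingReq` is only used to illustrate
    /// the channel):
    ///
    /// ```ignore
    /// let requests_tx = eventloop.handle();
    /// // Queue an outgoing packet; `poll` picks it up from the requests channel.
    /// requests_tx.send(Request::PingReq).await.unwrap();
    /// ```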
    pub fn handle(&self) -> Sender<Request> {
        self.requests_tx.clone()
    }

    /// Handle for cancelling the eventloop.
    ///
    /// Useful when the connection should be halted immediately, e.g. during
    /// half-open connection detection or (re)connection timeouts.
    pub(crate) fn cancel_handle(&mut self) -> Sender<()> {
        self.cancel_tx.clone()
    }

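    /// Drops the network connection and keep-alive timer and moves unacked packets
    /// from the previous session into `pending` so they can be republished after
    /// reconnecting.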
    fn clean(&mut self) {
        self.network = None;
        self.keepalive_timeout = None;
        let pending = self.state.clean();
        self.pending = pending.into_iter();
    }

    /// Yields the next incoming notification or outgoing request and periodically pings
    /// the broker. Continuing to poll will reconnect to the broker after a
    /// disconnection.
    ///
    /// **NOTE**: Don't block between calls while iterating.
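    ///
    /// A minimal polling loop might look like the sketch below (marked `ignore`
    /// since it is not a compilable doctest; `options` is assumed to be an
    /// already-configured `MqttOptions`):
    ///
    /// ```ignore
    /// let mut eventloop = EventLoop::new(options, 10);
    /// loop {
    ///     match eventloop.poll().await {
    ///         Ok(event) => println!("{:?}", event),
    ///         Err(err) => {
    ///             // Polling again after an error reconnects; break to stop instead.
    ///             eprintln!("connection error: {:?}", err);
    ///             break;
    ///         }
    ///     }
    /// }
    /// ```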
    pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
        if self.network.is_none() {
            let (network, connack) = connect_or_cancel(&self.options, &self.cancel_rx).await?;
            self.network = Some(network);

            if self.keepalive_timeout.is_none() {
                self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive)));
            }

            return Ok(Event::Incoming(connack));
        }

        match self.select().await {
            Ok(v) => Ok(v),
            Err(e) => {
                self.clean();
                Err(e)
            }
        }
    }

    /// Select on network and requests and generate keepalive pings when necessary
    async fn select(&mut self) -> Result<Event, ConnectionError> {
        let network = self.network.as_mut().unwrap();
        // let await_acks = self.state.await_acks;
        let inflight_full = self.state.inflight >= self.options.inflight;
        let throttle = self.options.pending_throttle;
        let pending = self.pending.len() > 0;
        let collision = self.state.collision.is_some();

        // Read buffered events from previous polls before calling a new poll
        if let Some(event) = self.state.events.pop_front() {
            return Ok(event);
        }

        // Every branch of the select! below pushes at least one event into
        // `self.state.events` before returning, so the `pop_front().unwrap()`
        // calls are expected to succeed.
        select! {
            // Pull a batch of packets from the network, reply to them in a batch,
            // and yield the first resulting event
            o = network.readb(&mut self.state) => {
                o?;
                // flush all the acks and return the first incoming packet
                network.flush(&mut self.state.write).await?;
                Ok(self.state.events.pop_front().unwrap())
            },
            // Pull the next request from the user requests channel.
            // The `if` conditions on the branch below are for flow control: we read the next
            // user request only when the inflight message count is below the configured
            // maximum, there are no pending packets left from the previous session, and there
            // was no collision while handling a previous outgoing request.
            //
            // Flow control is based on ack count. If the inflight packet count in the buffer is
            // less than the max_inflight setting, the next outgoing request will progress. For
            // this to work correctly, the broker should ack in sequence (a lot of brokers won't).
            //
            // E.g. if max inflight = 5, user requests will be blocked when the inflight queue
            // looks like this                   -> [1, 2, 3, 4, 5].
            // If the broker acks 2 instead of 1 -> [1, x, 3, 4, 5].
            // This pulls the next user request. But because max packet id = max_inflight, the
            // next user request's packet id rolls over to 1, which would replace the existing
            // packet id 1, resulting in a collision.
            //
            // The event loop stops receiving outgoing user requests when a previous outgoing
            // request collided, i.e. the collision state. The collision state is cleared only
            // when the correct ack is received.
            // A full inflight queue will look like -> [1a, 2, 3, 4, 5].
            // If 3 is acked instead of 1 first     -> [1a, 2, x, 4, 5].
            // After a collision with pkid 1        -> [1b, 2, x, 4, 5].
            // 1a is saved to state and the event loop is set to collision mode, stopping new
            // outgoing requests (along with 1b).
            o = self.requests_rx.recv(), if !inflight_full && !pending && !collision => match o {
                Ok(request) => {
                    self.state.handle_outgoing_packet(request)?;
                    network.flush(&mut self.state.write).await?;
                    Ok(self.state.events.pop_front().unwrap())
                }
                Err(_) => Err(ConnectionError::RequestsDone),
            },
            // Handle the next pending packet from the previous session. This branch
            // is disabled once all the pending packets have been handled.
            Some(request) = next_pending(throttle, &mut self.pending), if pending => {
                self.state.handle_outgoing_packet(request)?;
                network.flush(&mut self.state.write).await?;
                Ok(self.state.events.pop_front().unwrap())
            },
            // We generate pings irrespective of network activity. This keeps the ping logic
            // simple. We can change this behavior in the future if necessary (to prevent
            // extra pings).
            _ = self.keepalive_timeout.as_mut().unwrap() => {
                let timeout = self.keepalive_timeout.as_mut().unwrap();
                timeout.as_mut().reset(Instant::now() + self.options.keep_alive);

                self.state.handle_outgoing_packet(Request::PingReq)?;
                network.flush(&mut self.state.write).await?;
                Ok(self.state.events.pop_front().unwrap())
            }
            // cancellation requests to stop the polling
            _ = self.cancel_rx.recv() => {
                Err(ConnectionError::Cancel)
            }
        }
    }
}

async fn connect_or_cancel(
    options: &MqttOptions,
    cancel_rx: &Receiver<()>,
) -> Result<(Network, Incoming), ConnectionError> {
    // The select here prevents a cancel request from being blocked until the connection
    // request resolves. Returns with an error if connections fail continuously.
    select! {
        o = connect(options) => o,
        _ = cancel_rx.recv() => {
            Err(ConnectionError::Cancel)
        }
    }
}

/// Opens the network connection to the broker and performs the MQTT handshake,
/// returning the connected network handle together with the incoming `ConnAck` packet.
async fn connect(options: &MqttOptions) -> Result<(Network, Incoming), ConnectionError> {
    // connect to the broker
    let mut network = network_connect(options).await?;

    // make the MQTT connection request (which internally awaits the ack)
    let packet = mqtt_connect(options, &mut network).await?;

    // The last session might contain packets which weren't acked. MQTT says these packets
    // should be republished in the next session; `EventLoop::clean` moves them into
    // `pending` when the connection is torn down.
    Ok((network, packet))
}

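/// Opens the transport-level connection (TCP, TLS, Unix socket, or WebSocket)
/// selected by `options.transport()`.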
async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionError> {
    let network = match options.transport() {
        Transport::Tcp => {
            let addr = options.broker_addr.as_str();
            let port = options.port;
            let socket = TcpStream::connect((addr, port)).await?;
            Network::new(socket, options.max_incoming_packet_size)
        }
        #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
        Transport::Tls(tls_config) => {
            let socket = tls::tls_connect(options, &tls_config).await?;
            Network::new(socket, options.max_incoming_packet_size)
        }
        #[cfg(unix)]
        Transport::Unix => {
            let file = options.broker_addr.as_str();
            let socket = UnixStream::connect(Path::new(file)).await?;
            Network::new(socket, options.max_incoming_packet_size)
        }
        #[cfg(feature = "websocket")]
        Transport::Ws => {
            let request = http::Request::builder()
                .method(http::Method::GET)
                .uri(options.broker_addr.as_str())
                .header("Sec-WebSocket-Protocol", "mqttv3.1")
                .body(())?;

            let (socket, _) = connect_async(request).await?;

            Network::new(WsStream::new(socket), options.max_incoming_packet_size)
        }
        #[cfg(all(feature = "use-rustls", feature = "websocket"))]
        Transport::Wss(tls_config) => {
            let request = http::Request::builder()
                .method(http::Method::GET)
                .uri(options.broker_addr.as_str())
                .header("Sec-WebSocket-Protocol", "mqttv3.1")
                .body(())?;

            let connector = tls::rustls_connector(&tls_config).await?;

            let (socket, _) = connect_async_with_tls_connector(request, Some(connector)).await?;

            Network::new(WsStream::new(socket), options.max_incoming_packet_size)
        }
    };

    Ok(network)
}

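/// Sends the MQTT `Connect` packet built from `options` and waits, bounded by the
/// connection timeout, for the broker's `ConnAck`, rejecting any other packet or a
/// non-success return code.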
async fn mqtt_connect(
    options: &MqttOptions,
    network: &mut Network,
) -> Result<Incoming, ConnectionError> {
    let keep_alive = options.keep_alive().as_secs() as u16;
    let clean_session = options.clean_session();
    let last_will = options.last_will();

    let mut connect = Connect::new(options.client_id());
    connect.keep_alive = keep_alive;
    connect.clean_session = clean_session;
    connect.last_will = last_will;

    if let Some((username, password)) = options.credentials() {
        let login = Login::new(username, password);
        connect.login = Some(login);
    }

    // send the mqtt connection request with a timeout
    time::timeout(Duration::from_secs(options.connection_timeout()), async {
        network.connect(connect).await?;
        Ok::<_, ConnectionError>(())
    })
    .await??;

    // wait up to the connection timeout for a valid connack
    let packet = time::timeout(Duration::from_secs(options.connection_timeout()), async {
        match network.read().await? {
            Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => {
                Ok(Packet::ConnAck(connack))
            }
            Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)),
            packet => Err(ConnectionError::NotConnAck(packet)),
        }
    })
    .await??;

    Ok(packet)
}

/// Returns the next pending packet asynchronously to be used in select!
/// This is a synchronous function but made async to make it fit in select!
pub(crate) async fn next_pending(
    delay: Duration,
    pending: &mut IntoIter<Request>,
) -> Option<Request> {
    // return next packet with a delay
    time::sleep(delay).await;
    pending.next()
}