1use crate::{framed::Network, Transport};
2#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
3use crate::tls;
4use crate::{Incoming, MqttState, Packet, Request, StateError};
5use crate::{MqttOptions, Outgoing};
6
7use crate::mqttbytes::v4::*;
8use async_channel::{bounded, Receiver, Sender};
9#[cfg(feature = "websocket")]
10use async_tungstenite::tokio::connect_async;
11#[cfg(all(feature = "use-rustls", feature = "websocket"))]
12use async_tungstenite::tokio::connect_async_with_tls_connector;
13use tokio::net::TcpStream;
14#[cfg(unix)]
15use tokio::net::UnixStream;
16use tokio::select;
17use tokio::time::{self, error::Elapsed, Instant, Sleep};
18#[cfg(feature = "websocket")]
19use ws_stream_tungstenite::WsStream;
20
21use std::io;
22#[cfg(unix)]
23use std::path::Path;
24use std::pin::Pin;
25use std::time::Duration;
26use std::vec::IntoIter;
27
/// Errors that can be returned while connecting to the broker or polling the event loop.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// The MQTT state machine reported an error.
    #[error("Mqtt state: {0}")]
    MqttState(#[from] StateError),
    /// A `tokio::time::timeout` deadline elapsed (used around the MQTT handshake).
    #[error("Timeout")]
    Timeout(#[from] Elapsed),
    /// The websocket layer (`async_tungstenite`) reported an error.
    #[cfg(feature = "websocket")]
    #[error("Websocket: {0}")]
    Websocket(#[from] async_tungstenite::tungstenite::error::Error),
    /// The HTTP request used to open the websocket could not be built.
    #[cfg(feature = "websocket")]
    #[error("Websocket Connect: {0}")]
    WsConnect(#[from] http::Error),
    /// The TLS layer reported an error.
    #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
    #[error("TLS: {0}")]
    Tls(#[from] tls::Error),
    /// An I/O error occurred.
    #[error("I/O: {0}")]
    Io(#[from] io::Error),
    /// The broker answered the connect request with a `ConnAck` whose return code is not
    /// `ConnectReturnCode::Success`; the code is carried in the variant.
    #[error("Connection refused, return code: {0:?}")]
    ConnectionRefused(ConnectReturnCode),
    /// The packet read in reply to the connect request was not a `ConnAck`; the packet
    /// that was actually received is carried in the variant.
    #[error("Expected ConnAck packet, received: {0:?}")]
    NotConnAck(Packet),
    /// Receiving from the request channel failed.
    #[error("Requests done")]
    RequestsDone,
    /// The cancellation channel's `recv` completed while connecting or polling.
    #[error("Cancel request by the user")]
    Cancel,
}
55
/// State owned by the MQTT event loop; progress is made by calling `EventLoop::poll`.
pub struct EventLoop {
    /// Options used to connect to the broker and to configure the loop.
    pub options: MqttOptions,
    /// MQTT state machine; it holds the queued events and the outgoing write buffer.
    pub state: MqttState,
    /// Receiving half of the request channel, read by the event loop.
    pub requests_rx: Receiver<Request>,
    /// Sending half of the request channel; clones are handed out by `handle`.
    pub requests_tx: Sender<Request>,
    /// Requests waiting to be handed (back) to `state`. While this is non-empty the loop
    /// drains it, throttled, instead of reading new requests from `requests_rx`.
    pub pending: IntoIter<Request>,
    /// Current connection; `None` until `poll` connects and again after an error.
    pub(crate) network: Option<Network>,
    /// Keep-alive timer; when it fires a `PingReq` is queued and the timer is re-armed.
    pub(crate) keepalive_timeout: Option<Pin<Box<Sleep>>>,
    /// Receiving half of the cancellation channel.
    pub(crate) cancel_rx: Receiver<()>,
    /// Sending half of the cancellation channel; clones are handed out by `cancel_handle`.
    pub(crate) cancel_tx: Sender<()>,
}
77
/// Item yielded by a successful call to `EventLoop::poll`.
#[derive(Debug, PartialEq, Clone)]
pub enum Event {
    /// Wraps an `Incoming` value, e.g. the `ConnAck` that `poll` returns right after
    /// connecting.
    Incoming(Incoming),
    /// Wraps an `Outgoing` value.
    // NOTE(review): presumably a record of something queued for sending by `MqttState` —
    // confirm against the state machine.
    Outgoing(Outgoing),
}
84
85impl EventLoop {
86 pub fn new(options: MqttOptions, cap: usize) -> EventLoop {
91 let (cancel_tx, cancel_rx) = bounded(5);
92 let (requests_tx, requests_rx) = bounded(cap);
93 let pending = Vec::new();
94 let pending = pending.into_iter();
95 let max_inflight = options.inflight;
96 let manual_acks = options.manual_acks;
97
98 EventLoop {
99 options,
100 state: MqttState::new(max_inflight, manual_acks),
101 requests_tx,
102 requests_rx,
103 pending,
104 network: None,
105 keepalive_timeout: None,
106 cancel_rx,
107 cancel_tx,
108 }
109 }
110
111 pub fn handle(&self) -> Sender<Request> {
113 self.requests_tx.clone()
114 }
115
116 pub(crate) fn cancel_handle(&mut self) -> Sender<()> {
121 self.cancel_tx.clone()
122 }
123
124 fn clean(&mut self) {
125 self.network = None;
126 self.keepalive_timeout = None;
127 let pending = self.state.clean();
128 self.pending = pending.into_iter();
129 }
130
131 pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
136 if self.network.is_none() {
137 let (network, connack) = connect_or_cancel(&self.options, &self.cancel_rx).await?;
138 self.network = Some(network);
139
140 if self.keepalive_timeout.is_none() {
141 self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive)));
142 }
143
144 return Ok(Event::Incoming(connack));
145 }
146
147 match self.select().await {
148 Ok(v) => Ok(v),
149 Err(e) => {
150 self.clean();
151 Err(e)
152 }
153 }
154 }
155
156 async fn select(&mut self) -> Result<Event, ConnectionError> {
158 let network = self.network.as_mut().unwrap();
159 let inflight_full = self.state.inflight >= self.options.inflight;
161 let throttle = self.options.pending_throttle;
162 let pending = self.pending.len() > 0;
163 let collision = self.state.collision.is_some();
164
165 if let Some(event) = self.state.events.pop_front() {
167 return Ok(event);
168 }
169
170 select! {
173 o = network.readb(&mut self.state) => {
175 o?;
176 network.flush(&mut self.state.write).await?;
178 Ok(self.state.events.pop_front().unwrap())
179 },
180 o = self.requests_rx.recv(), if !inflight_full && !pending && !collision => match o {
205 Ok(request) => {
206 self.state.handle_outgoing_packet(request)?;
207 network.flush(&mut self.state.write).await?;
208 Ok(self.state.events.pop_front().unwrap())
209 }
210 Err(_) => Err(ConnectionError::RequestsDone),
211 },
212 Some(request) = next_pending(throttle, &mut self.pending), if pending => {
215 self.state.handle_outgoing_packet(request)?;
216 network.flush(&mut self.state.write).await?;
217 Ok(self.state.events.pop_front().unwrap())
218 },
219 _ = self.keepalive_timeout.as_mut().unwrap() => {
222 let timeout = self.keepalive_timeout.as_mut().unwrap();
223 timeout.as_mut().reset(Instant::now() + self.options.keep_alive);
224
225 self.state.handle_outgoing_packet(Request::PingReq)?;
226 network.flush(&mut self.state.write).await?;
227 Ok(self.state.events.pop_front().unwrap())
228 }
229 _ = self.cancel_rx.recv() => {
231 Err(ConnectionError::Cancel)
232 }
233 }
234 }
235}
236
237async fn connect_or_cancel(
238 options: &MqttOptions,
239 cancel_rx: &Receiver<()>,
240) -> Result<(Network, Incoming), ConnectionError> {
241 select! {
244 o = connect(options) => o,
245 _ = cancel_rx.recv() => {
246 Err(ConnectionError::Cancel)
247 }
248 }
249}
250
251async fn connect(options: &MqttOptions) -> Result<(Network, Incoming), ConnectionError> {
257 let mut network = match network_connect(options).await {
259 Ok(network) => network,
260 Err(e) => {
261 return Err(e);
262 }
263 };
264
265 let packet = match mqtt_connect(options, &mut network).await {
267 Ok(p) => p,
268 Err(e) => return Err(e),
269 };
270
271 Ok((network, packet))
277}
278
279async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionError> {
280 let network = match options.transport() {
281 Transport::Tcp => {
282 let addr = options.broker_addr.as_str();
283 let port = options.port;
284 let socket = TcpStream::connect((addr, port)).await?;
285 Network::new(socket, options.max_incoming_packet_size)
286 }
287 #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
288 Transport::Tls(tls_config) => {
289 let socket = tls::tls_connect(options, &tls_config).await?;
290 Network::new(socket, options.max_incoming_packet_size)
291 }
292 #[cfg(unix)]
293 Transport::Unix => {
294 let file = options.broker_addr.as_str();
295 let socket = UnixStream::connect(Path::new(file)).await?;
296 Network::new(socket, options.max_incoming_packet_size)
297 }
298 #[cfg(feature = "websocket")]
299 Transport::Ws => {
300 let request = http::Request::builder()
301 .method(http::Method::GET)
302 .uri(options.broker_addr.as_str())
303 .header("Sec-WebSocket-Protocol", "mqttv3.1")
304 .body(())?;
305
306 let (socket, _) = connect_async(request).await?;
307
308 Network::new(WsStream::new(socket), options.max_incoming_packet_size)
309 }
310 #[cfg(all(feature = "use-rustls", feature = "websocket"))]
311 Transport::Wss(tls_config) => {
312 let request = http::Request::builder()
313 .method(http::Method::GET)
314 .uri(options.broker_addr.as_str())
315 .header("Sec-WebSocket-Protocol", "mqttv3.1")
316 .body(())?;
317
318 let connector = tls::rustls_connector(&tls_config).await?;
319
320 let (socket, _) = connect_async_with_tls_connector(request, Some(connector)).await?;
321
322 Network::new(WsStream::new(socket), options.max_incoming_packet_size)
323 }
324 };
325
326 Ok(network)
327}
328
329async fn mqtt_connect(
330 options: &MqttOptions,
331 network: &mut Network,
332) -> Result<Incoming, ConnectionError> {
333 let keep_alive = options.keep_alive().as_secs() as u16;
334 let clean_session = options.clean_session();
335 let last_will = options.last_will();
336
337 let mut connect = Connect::new(options.client_id());
338 connect.keep_alive = keep_alive;
339 connect.clean_session = clean_session;
340 connect.last_will = last_will;
341
342 if let Some((username, password)) = options.credentials() {
343 let login = Login::new(username, password);
344 connect.login = Some(login);
345 }
346
347 time::timeout(Duration::from_secs(options.connection_timeout()), async {
349 network.connect(connect).await?;
350 Ok::<_, ConnectionError>(())
351 })
352 .await??;
353
354 let packet = time::timeout(Duration::from_secs(options.connection_timeout()), async {
356 match network.read().await? {
357 Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => {
358 Ok(Packet::ConnAck(connack))
359 }
360 Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)),
361 packet => Err(ConnectionError::NotConnAck(packet)),
362 }
363 })
364 .await??;
365
366 Ok(packet)
367}
368
369pub(crate) async fn next_pending(
372 delay: Duration,
373 pending: &mut IntoIter<Request>,
374) -> Option<Request> {
375 time::sleep(delay).await;
377 pending.next()
378}