// simple_pub_sub/client/mod.rs

1use crate::error::PubSubError::ClientNotConnected;
2use crate::message;
3use crate::message::Msg;
4use crate::stream;
5use crate::Header;
6use crate::PktType;
7use anyhow::Result;
8use log::{info, trace};
9use std::fs::File;
10use std::io::Read;
11use tokio::{
12    io::{AsyncReadExt, AsyncWriteExt},
13    net::TcpStream,
14    net::UnixStream,
15};
16use tokio_native_tls::native_tls::{Certificate, TlsConnector};
17use tokio_native_tls::TlsStream;
18
/// Connection settings for a simple pub sub server reached over TCP,
/// optionally wrapped in TLS.
#[derive(Debug, Clone)]
pub struct PubSubTcpClient {
    /// Domain or host of the server, for example `127.0.0.1`.
    ///
    /// When TLS is enabled this value is also the domain passed to the TLS
    /// handshake, so it should match a name the server certificate covers.
    pub server: String,
    /// Port the simple_pub_sub server listens on.
    pub port: u16,
    /// Path to a PEM-encoded certificate file.
    ///
    /// When `Some`, the client connects over TLS and adds this certificate as
    /// a trusted root certificate. When `None`, a plain TCP connection is used.
    pub cert: Option<String>,
    /// Password for the TLS certificate.
    ///
    /// NOTE(review): `Client::connect` does not currently read this field.
    /// The certificate is loaded with `Certificate::from_pem`, which takes no
    /// password. The field is kept for API compatibility.
    pub cert_password: Option<String>,
}
32
/// Connection settings for a simple pub sub server reached through a Unix
/// domain socket.
#[derive(Debug, Clone)]
pub struct PubSubUnixClient {
    /// Filesystem path of the server's socket file,
    /// e.g. `/tmp/simple-pub-sub.sock`.
    pub path: String,
}
40
41/// Simple pub sub Client
42#[derive(Debug, Clone)]
43pub enum PubSubClient {
44    /// tcp client for the simple pub sub
45    Tcp(PubSubTcpClient),
46    /// unix socket client for the simple pub sub
47    Unix(PubSubUnixClient),
48}
49
50impl PubSubClient {
51    fn server(&self) -> &str {
52        match self {
53            PubSubClient::Tcp(pub_sub_tcp_client) => &pub_sub_tcp_client.server,
54            PubSubClient::Unix(pub_sub_unix_client) => &pub_sub_unix_client.path,
55        }
56    }
57}
58
59/// Stream for Tcp and Unix connection
60#[derive(Debug)]
61pub enum StreamType {
62    /// tcp stream
63    Tcp(TcpStream),
64    /// tls stream
65    Tls(Box<TlsStream<TcpStream>>),
66    /// unix socket stream
67    Unix(UnixStream),
68}
69
70impl StreamType {
71    async fn read_message(&mut self) -> Result<Msg> {
72        match self {
73            StreamType::Tcp(stream) => Ok(stream::read_message(stream).await?),
74            StreamType::Tls(stream) => Ok(stream::read_message(stream).await?),
75            StreamType::Unix(stream) => Ok(stream::read_message(stream).await?),
76        }
77    }
78
79    async fn read_buf(&mut self, message: &mut Vec<u8>) -> Result<usize> {
80        let size = match self {
81            StreamType::Tls(ref mut tls_stream) => tls_stream.read_buf(message).await?,
82            StreamType::Tcp(ref mut tcp_stream) => tcp_stream.read_buf(message).await?,
83            StreamType::Unix(ref mut unix_stream) => unix_stream.read_buf(message).await?,
84        };
85        Ok(size)
86    }
87
88    async fn read(&mut self, message: &mut [u8]) -> Result<usize> {
89        let size = match self {
90            StreamType::Tls(ref mut tls_stream) => tls_stream.read(message).await?,
91            StreamType::Tcp(ref mut tcp_stream) => tcp_stream.read(message).await?,
92            StreamType::Unix(ref mut unix_stream) => unix_stream.read(message).await?,
93        };
94        Ok(size)
95    }
96
97    async fn write_all(&mut self, message: Vec<u8>) -> Result<()> {
98        match self {
99            StreamType::Tls(tls_stream) => tls_stream.write_all(&message).await?,
100            StreamType::Tcp(ref mut tcp_stream) => tcp_stream.write_all(&message).await?,
101            StreamType::Unix(ref mut unix_stream) => unix_stream.write_all(&message).await?,
102        };
103        Ok(())
104    }
105}
106
107/// Simple pub sub Client
108#[derive(Debug)]
109pub struct Client {
110    pub client_type: PubSubClient,
111    stream: Option<StreamType>,
112}
113
114/// default implementation for callback function
115pub fn on_message(topic: String, message: Vec<u8>) {
116    match String::from_utf8(message.clone()) {
117        Ok(msg_str) => {
118            info!("Topic: {} message: {}", topic, msg_str);
119        }
120        Err(_) => {
121            info!("Topic: {} message: {:?}", topic, message);
122        }
123    };
124}
125
126impl Client {
127    /// Creates a new instance of `Client`
128    /// ```
129    /// use simple_pub_sub::client::{self, PubSubClient, Client};
130    /// let client_type = simple_pub_sub::client::PubSubTcpClient {
131    ///        server: "localhost".to_string(),
132    ///        port: 6480,
133    ///        cert: None,
134    ///        cert_password: None,
135    /// };
136    ///
137    /// // initialize the client.
138    /// let mut pub_sub_client = simple_pub_sub::client::Client::new(
139    ///     simple_pub_sub::client::PubSubClient::Tcp(client_type)
140    /// );
141    /// ```
142    pub fn new(client_type: PubSubClient) -> Self {
143        Client {
144            client_type,
145            stream: None,
146        }
147    }
148
149    async fn connect_tls(&mut self, url: String, cert: String) -> Result<()> {
150        // Load CA certificate
151        let mut file = File::open(cert)?;
152        let mut ca_cert = vec![];
153        file.read_to_end(&mut ca_cert)?;
154        let ca_cert = Certificate::from_pem(&ca_cert)?;
155
156        // Configure TLS
157        let connector = TlsConnector::builder()
158            .add_root_certificate(ca_cert)
159            .build()?;
160
161        let connector = tokio_native_tls::TlsConnector::from(connector);
162
163        // Connect to the server
164        let stream = TcpStream::connect(&url).await?;
165
166        // create the StreamType::Tls
167        let connector = connector.connect(self.client_type.server(), stream).await?;
168        self.stream = Some(StreamType::Tls(Box::new(connector)));
169        Ok(())
170    }
171
172    /// Connects to the server
173    ///```
174    /// use simple_pub_sub::client::{self, PubSubClient, Client};
175    /// let client_type = simple_pub_sub::client::PubSubTcpClient {
176    ///        server: "localhost".to_string(),
177    ///        port: 6480,
178    ///        cert: None,
179    ///        cert_password: None,
180    /// };
181    ///
182    /// // initialize the client.
183    /// let mut pub_sub_client = simple_pub_sub::client::Client::new(
184    ///     simple_pub_sub::client::PubSubClient::Tcp(client_type),
185    /// );
186    /// pub_sub_client.connect();
187    /// ```
188    pub async fn connect(&mut self) -> Result<()> {
189        match self.client_type.clone() {
190            PubSubClient::Tcp(tcp_client) => {
191                let server_url: String = format!("{}:{}", tcp_client.server, tcp_client.port);
192                if let Some(cert) = tcp_client.cert {
193                    self.connect_tls(server_url, cert).await?;
194                } else {
195                    let stream = TcpStream::connect(server_url).await?;
196                    self.stream = Some(StreamType::Tcp(stream));
197                }
198            }
199            PubSubClient::Unix(unix_stream) => {
200                let path = unix_stream.path;
201                let stream = UnixStream::connect(path).await?;
202                self.stream = Some(StreamType::Unix(stream));
203            }
204        }
205        Ok(())
206    }
207
208    /// Sends the message to the given server and returns the ack
209    /// the server could be either a tcp or unix server
210    ///```
211    /// use simple_pub_sub::client::{PubSubClient, Client};
212    /// use simple_pub_sub::message::Msg;
213    /// use simple_pub_sub::PktType;
214    /// async fn publish_msg(){
215    ///   let client_type = simple_pub_sub::client::PubSubTcpClient {
216    ///          server: "localhost".to_string(),
217    ///          port: 6480,
218    ///          cert: None,
219    ///          cert_password: None,
220    ///   };
221    ///
222    /// // initialize the client.
223    /// let mut pub_sub_client = simple_pub_sub::client::Client::new(
224    ///     simple_pub_sub::client::PubSubClient::Tcp(client_type),
225    /// );
226    /// pub_sub_client.connect().await.unwrap();
227    /// let msg = Msg::new(PktType::PUBLISH, "Test".to_string(), Some(b"The message".to_vec()));
228    ///   pub_sub_client.post(msg).await.unwrap();
229    /// }
230    /// ```
231    pub async fn post(&mut self, msg: Msg) -> Result<Vec<u8>> {
232        self.write(msg.bytes()).await?;
233        let mut response_message_buffer: Vec<u8>;
234        response_message_buffer = vec![0; 8];
235        self.read(&mut response_message_buffer).await?;
236        trace!("Buf: {:?}", response_message_buffer);
237        let response_header = Header::try_from(response_message_buffer.clone())?;
238        trace!("Resp: {:?}", response_header);
239        let mut response_body = Vec::with_capacity(response_header.message_length as usize);
240        trace!("Reading remaining bytes");
241        self.read_buf(&mut response_body).await?;
242        response_message_buffer.extend(response_body);
243        Ok(response_message_buffer)
244    }
245
246    /// Publishes the message to the given topic
247    /// ```
248    /// use simple_pub_sub::client::{PubSubClient, Client};
249    /// async fn publish_msg(){
250    ///   let client_type = simple_pub_sub::client::PubSubTcpClient {
251    ///          server: "localhost".to_string(),
252    ///          port: 6480,
253    ///          cert: None,
254    ///          cert_password: None,
255    ///   };
256    ///
257    /// // initialize the client.
258    /// let mut pub_sub_client = simple_pub_sub::client::Client::new(
259    ///     simple_pub_sub::client::PubSubClient::Tcp(client_type),
260    /// );
261    /// pub_sub_client.connect().await.unwrap();
262    /// // subscribe to the given topic.
263    /// pub_sub_client
264    ///   .publish(
265    ///     "Abc".to_string(),
266    ///     "Test message".to_string().into_bytes().to_vec(),
267    ///   ).await.unwrap();
268    /// }
269    /// ```
270    pub async fn publish(&mut self, topic: String, message: Vec<u8>) -> Result<()> {
271        let msg: Msg = Msg::new(PktType::PUBLISH, topic, Some(message));
272        trace!("Msg: {:?}", msg);
273        let buf = self.post(msg).await?;
274        trace!("The raw buffer is: {:?}", buf);
275        let resp_: Header = Header::try_from(buf)?;
276        trace!("{:?}", resp_);
277        Ok(())
278    }
279
280    /// Sends the query message to the server
281    /// ```
282    /// use simple_pub_sub::client::{self, PubSubClient, Client};
283    /// async fn query(){
284    ///   let client_type = simple_pub_sub::client::PubSubTcpClient {
285    ///          server: "localhost".to_string(),
286    ///          port: 6480,
287    ///          cert: None,
288    ///          cert_password: None,
289    ///   };
290    ///
291    /// // initialize the client.
292    /// let mut pub_sub_client = simple_pub_sub::client::Client::new(
293    ///     simple_pub_sub::client::PubSubClient::Tcp(client_type),
294    /// );
295    /// pub_sub_client.connect().await.unwrap();
296    /// pub_sub_client.query("Test".to_string());
297    /// }
298    /// ```
299    pub async fn query(&mut self, topic: String) -> Result<String> {
300        let msg: Msg = Msg::new(
301            PktType::QUERY,
302            topic,
303            Some(" ".to_string().as_bytes().to_vec()),
304        );
305        trace!("Msg: {:?}", msg);
306
307        self.write(msg.bytes()).await?;
308        let msg = self.read_message().await?;
309        Ok(String::from_utf8(msg.message)?)
310    }
311
312    /// subscribes to the given topic
313    ///```
314    /// use simple_pub_sub::client::{self, PubSubClient, Client};
315    /// let client_type = simple_pub_sub::client::PubSubTcpClient {
316    ///        server: "localhost".to_string(),
317    ///        port: 6480,
318    ///        cert: None,
319    ///        cert_password: None,
320    /// };
321    /// // initialize the client.
322    /// let mut pub_sub_client = simple_pub_sub::client::Client::new(
323    ///     simple_pub_sub::client::PubSubClient::Tcp(client_type));
324    /// pub_sub_client.subscribe("Test".to_string());
325    /// ```
326    pub async fn subscribe(&mut self, topic: String) -> Result<()> {
327        let msg: message::Msg = message::Msg::new(PktType::SUBSCRIBE, topic, None);
328        trace!("Msg: {:?}", msg);
329        self.write(msg.bytes()).await
330    }
331
332    async fn write(&mut self, message: Vec<u8>) -> Result<()> {
333        if let Some(stream) = &mut self.stream {
334            stream.write_all(message).await?;
335            Ok(())
336        } else {
337            Err(anyhow::anyhow!(ClientNotConnected))
338        }
339    }
340
341    async fn read(&mut self, message: &mut [u8]) -> Result<()> {
342        if let Some(stream) = &mut self.stream {
343            let size = stream.read(message).await?;
344            trace!("Read: {} bytes", size);
345            Ok(())
346        } else {
347            Err(anyhow::anyhow!(ClientNotConnected))
348        }
349    }
350    async fn read_buf(&mut self, message: &mut Vec<u8>) -> Result<()> {
351        if let Some(stream) = &mut self.stream {
352            let size = stream.read_buf(message).await?;
353            trace!("Read: {} bytes", size);
354            Ok(())
355        } else {
356            Err(anyhow::anyhow!(ClientNotConnected))
357        }
358    }
359
360    /// reads the incoming message from the server
361    /// useful when you need to read the messages in loop
362    /// ```
363    /// use simple_pub_sub::client::{self, PubSubClient, Client};
364    ///
365    /// async fn read_messages(){
366    ///   let client_type = simple_pub_sub::client::PubSubTcpClient {
367    ///          server: "localhost".to_string(),
368    ///          port: 6480,
369    ///          cert: None,
370    ///          cert_password: None,
371    ///   };
372    ///   // initialize the client.
373    ///   let mut pub_sub_client = simple_pub_sub::client::Client::new(
374    ///       simple_pub_sub::client::PubSubClient::Tcp(client_type));
375    ///   pub_sub_client.connect().await.unwrap();
376    ///   pub_sub_client.subscribe("Test".to_string()).await.unwrap();
377    ///
378    ///   loop {
379    ///       match pub_sub_client.read_message().await{
380    ///           Ok(msg)=>{
381    ///               println!("{}: {:?}", msg.topic, msg.message);
382    ///           }
383    ///           Err(e)=>{
384    ///               println!("error: {:?}", e);
385    ///               break
386    ///           }
387    ///       }
388    ///   }
389    /// }
390    /// ```
391    pub async fn read_message(&mut self) -> Result<Msg> {
392        if let Some(stream) = &mut self.stream {
393            stream.read_message().await
394        } else {
395            Err(anyhow::anyhow!(ClientNotConnected))
396        }
397    }
398}