// simple_pub_sub/server/mod.rs

1mod client_handler;
2use crate::topics;
3use anyhow::Result;
4use log::info;
5use std::fs::File;
6use std::io::Read;
7use tokio::net::TcpListener;
8use tokio::net::UnixListener;
9use tokio_native_tls::native_tls::{Identity, TlsAcceptor};
10
/// Common interface implemented by the server configurations in this module
/// (`Tcp`, `Unix` and `ServerType`).
pub trait ServerTrait {
    /// Starts the server.
    ///
    /// The returned future is required to be `Send` and resolves to
    /// `Ok(())` or to the error that stopped the server.
    fn start(&self) -> impl std::future::Future<Output = Result<()>> + Send;
}
14pub struct Tcp {
15    pub host: String,
16    pub port: u16,
17    pub cert: Option<String>,
18    pub cert_password: Option<String>,
19    pub capacity: usize,
20}
21
22impl ServerTrait for Tcp {
23    /// Starts the simple pub sub server for the given server type
24    /// ```
25    /// use simple_pub_sub::server::ServerTrait as _;
26    /// // for tcp
27    /// async fn run_server_tcp(){
28    ///   let server = simple_pub_sub::server::ServerType::Tcp(simple_pub_sub::server::Tcp {
29    ///     host: "localhost".to_string(),
30    ///     port: 6480,
31    ///     cert: None,
32    ///     cert_password: None,
33    ///     capacity: 1024,
34    ///   });
35    ///   let _ = server.start().await;
36    /// }
37    /// // for tls
38    /// async fn run_server_tls(){
39    ///   let server = simple_pub_sub::server::ServerType::Tcp(simple_pub_sub::server::Tcp {
40    ///     host: "localhost".to_string(),
41    ///     port: 6480,
42    ///     cert: Some("certs/cert.pem".to_string()),
43    ///     cert_password: Some("password".to_string()),
44    ///     capacity: 1024,
45    ///   });
46    ///   let _ = server.start().await;
47    /// }
48    /// ```
49    async fn start(&self) -> Result<()> {
50        if let Some(cert) = &self.cert {
51            start_tls_server(
52                self.host.clone(),
53                self.port,
54                cert.clone(),
55                self.cert_password.clone(),
56                self.capacity,
57            )
58            .await
59        } else {
60            start_tcp_server(format!("{}:{}", self.host, self.port), self.capacity).await
61        }
62    }
63}
64pub struct Unix {
65    pub path: String,
66    pub capacity: usize,
67}
68
69impl ServerTrait for Unix {
70    /// start the pub-sub server over a unix socket
71    ///```
72    /// use crate::simple_pub_sub::server::ServerTrait as _;
73    /// let server = simple_pub_sub::server::ServerType::Unix(simple_pub_sub::server::Unix {
74    ///   path: "/tmp/sample.sock".to_string(),
75    ///   capacity: 1024,
76    /// });
77    /// let result = server.start();
78    ///```
79    async fn start(&self) -> Result<()> {
80        start_unix_server(self.path.clone(), self.capacity).await
81    }
82}
83impl Drop for Unix {
84    fn drop(&mut self) {
85        if std::path::Path::new(&self.path).exists() {
86            std::fs::remove_file(&self.path).unwrap();
87        }
88    }
89}
90
91pub enum ServerType {
92    Tcp(Tcp),
93    Unix(Unix),
94}
95impl ServerTrait for ServerType {
96    /// starts the simple-pub-sub server on the given server type
97    ///```
98    /// use simple_pub_sub::server::ServerTrait as _;
99    ///
100    /// // for tcp
101    ///   let server = simple_pub_sub::server::ServerType::Tcp(simple_pub_sub::server::Tcp {
102    ///     host: "localhost".to_string(),
103    ///     port: 6480,
104    ///     cert: None,
105    ///     cert_password: None,
106    ///     capacity: 1024,
107    ///   });
108    ///   server.start();
109    ///
110    /// // for tls
111    ///
112    ///   let server = simple_pub_sub::server::ServerType::Tcp(simple_pub_sub::server::Tcp {
113    ///     host: "localhost".to_string(),
114    ///     port: 6480,
115    ///     cert: Some("certs/cert.pem".to_string()),
116    ///     cert_password: Some("password".to_string()),
117    ///     capacity: 1024,
118    ///   });
119    ///   server.start();
120    ///
121    /// // for unix socket
122    /// use crate::simple_pub_sub::server::ServerTrait as _;
123    /// let server = simple_pub_sub::server::ServerType::Unix(simple_pub_sub::server::Unix {
124    ///   path: "/tmp/sample.sock".to_string(),
125    ///   capacity: 1024,
126    /// });
127    /// let result = server.start();
128    ///```
129    async fn start(&self) -> Result<()> {
130        match self {
131            ServerType::Tcp(tcp) => tcp.start().await,
132            ServerType::Unix(unix) => unix.start().await,
133        }
134    }
135}
136
137pub struct Server {
138    pub server_type: ServerType,
139}
140
141impl Server {
142    pub async fn start(&self) -> Result<()> {
143        self.server_type.start().await
144    }
145}
146
147/// Started a tls server on the given address with the given certificate (.pfx file)
148async fn start_tls_server(
149    host: String,
150    port: u16,
151    cert: String,
152    cert_password: Option<String>,
153    capacity: usize,
154) -> Result<()> {
155    // Load TLS identity (certificate and private key)
156    let mut file = File::open(&cert)?;
157    let mut identity_vec = vec![];
158    file.read_to_end(&mut identity_vec)?;
159
160    let identity: Identity;
161    if let Some(cert_password) = cert_password {
162        identity = Identity::from_pkcs12(&identity_vec, cert_password.as_str())?;
163    } else {
164        identity = Identity::from_pkcs12(&identity_vec, "")?;
165    }
166
167    let acceptor = TlsAcceptor::builder(identity).build()?;
168    let acceptor = tokio_native_tls::TlsAcceptor::from(acceptor);
169
170    // Bind TCP listener
171    let listener = TcpListener::bind(format!("{host}:{port}")).await?;
172
173    info!("Server listening on port {}:{}", host, port);
174    let tx = topics::get_global_broadcaster(capacity);
175    let _topic_handler = tokio::spawn(topics::topic_manager(tx.clone()));
176    loop {
177        let (stream, addr) = listener.accept().await?;
178        info!("Accepted connection from {:?}", addr);
179        let acceptor = acceptor.clone();
180        let tls_stream = acceptor.accept(stream).await?;
181        client_handler::handle_client(tls_stream, tx.clone()).await;
182    }
183}
184
185/// Starts a tcp server on the given address
186async fn start_tcp_server(addr: String, capacity: usize) -> Result<()> {
187    let listener = TcpListener::bind(&addr).await?;
188    info!("Listening on: {}", addr);
189    info!("Getting global broadcaster");
190
191    let tx = topics::get_global_broadcaster(capacity);
192    let _topic_handler = tokio::spawn(topics::topic_manager(tx.clone()));
193    loop {
194        let (socket, addr) = listener.accept().await?;
195        info!("Addr is: {addr}");
196        client_handler::handle_client(socket, tx.clone()).await;
197    }
198}
199
200/// Starts a unix server on the given path
201async fn start_unix_server(path: String, capacity: usize) -> Result<()> {
202    if std::path::Path::new(&path).exists() {
203        std::fs::remove_file(path.clone())?;
204    }
205
206    let listener = UnixListener::bind(&path)?;
207    info!("Listening on: {}", path);
208    info!("Getting global broadcaster");
209    let tx = topics::get_global_broadcaster(capacity);
210    let _topic_handler = tokio::spawn(topics::topic_manager(tx.clone()));
211    loop {
212        let (socket, addr) = listener.accept().await?;
213        info!("Addr is: {:?}", addr.as_pathname());
214        client_handler::handle_client(socket, tx.clone()).await;
215    }
216}