signalrs_client_custom_auth/
builder.rs

1//! SignalR client builder
2
3use super::{hub::Hub, transport, SignalRClient};
4use crate::{
5    messages::ClientMessage, protocol::NegotiateResponseV0, transport::error::TransportError,
6};
7use thiserror::Error;
8use tokio::net::TcpStream;
9use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
10use tracing::*;
11
12/// [`SignalRClient`] builder.
13///
14/// Allows configuring connection and behavior details.
15///  
16/// # Example
17/// ```rust, no_run
18/// use signalrs_client::SignalRClient;
19///
20/// #[tokio::main]
21/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
22///     let client = SignalRClient::builder("example.com")
23///         .use_port(8080)
24///         .use_hub("echo")
25///         .build()
26///         .await?;
27///
28/// # Ok(())
29/// }
30/// ```
31pub struct ClientBuilder {
32    domain: String,
33    hub: Option<Hub>,
34    auth: Auth,
35    secure_connection: bool,
36    port: Option<usize>,
37    query_string: Option<String>,
38    hub_path: Option<String>,
39}
40
/// Authentication credentials used by [`ClientBuilder`].
///
/// NOTE(review): the builder attaches these credentials only to the HTTP negotiate request.
/// The websocket connection is opened without them, so servers that authenticate the
/// websocket upgrade itself will reject the connection — confirm whether that is intended.
pub enum Auth {
    /// No credentials are sent.
    None,
    /// HTTP Basic authentication, passed to the HTTP client's `basic_auth`.
    /// `password` may be omitted.
    Basic {
        user: String,
        password: Option<String>,
    },
    /// Token authentication, passed to the HTTP client's `bearer_auth`.
    Bearer {
        token: String,
    },
    /// Custom authentication, passed to the HTTP client's `res_auth(user_id, token)`.
    ///
    /// The exact headers produced are defined by `dusks_reqwest` — TODO confirm and document.
    Res {
        user_id: String,
        token: String,
    },
}
56
57/// Errors that can occur during building of the client
58#[derive(Error, Debug)]
59pub enum BuilderError {
60    #[error("negotiate error")]
61    Negotiate {
62        #[from]
63        source: NegotiateError,
64    },
65    #[error("invalid {0} url")]
66    Url(String),
67    #[error("transport error")]
68    Transport {
69        #[from]
70        source: TransportError,
71    },
72}
73
74/// Errors that can occur during negotiating protocol version
75#[derive(Error, Debug)]
76pub enum NegotiateError {
77    #[error("request error")]
78    Request {
79        #[from]
80        source: dusks_reqwest::Error,
81    },
82    #[error("deserialization error")]
83    Deserialization {
84        #[from]
85        source: serde_json::Error,
86    },
87    #[error("server does not support requested features")]
88    Unsupported,
89}
90
91impl ClientBuilder {
92    pub fn new(domain: impl ToString) -> Self {
93        ClientBuilder {
94            domain: domain.to_string(),
95            hub: None,
96            auth: Auth::None,
97            secure_connection: true,
98            port: None,
99            query_string: None,
100            hub_path: None,
101        }
102    }
103
104    /// Specifies port on the server to connect to.
105    pub fn use_port(mut self, port: usize) -> Self {
106        self.port = Some(port);
107        self
108    }
109
110    /// If used, client will use unencrypted connection
111    ///
112    /// For example HTTP will be used instead of HTTPS and WS instead of WSS.
113    /// **Use only when necessary and for local development.**
114    pub fn use_unencrypted_connection(mut self) -> Self {
115        self.secure_connection = false;
116        self
117    }
118
119    /// Specifies authentication to use
120    pub fn use_authentication(mut self, auth: Auth) -> Self {
121        self.auth = auth;
122        self
123    }
124
125    /// Specifies query string to attch to handshake on the server.
126    ///
127    /// Since life of a SignalR connection begins with HTTP request it is possible to attach some data in query string.
128    /// Some servers would use this data to have initial information about new connection.
129    /// There are no standard obligatory parameters, what is obligatory or nice-to-have is dependent of particual hub.
130    pub fn use_query_string(mut self, query: String) -> Self {
131        self.query_string = Some(query);
132        self
133    }
134
135    /// Specifies path to a hub on the server to use
136    ///
137    /// It should be a full path without first `/` e.g `echo` or `full/path/to/echo`.
138    pub fn use_hub(mut self, hub: impl ToString) -> Self {
139        self.hub_path = Some(hub.to_string());
140        self
141    }
142
143    /// Specifies a [`Hub`] to use on the client side.
144    ///
145    /// SignalR allows servers to invoke methods on a client. Pass a hub here to allow server invoking its methods.
146    pub fn with_client_hub(mut self, hub: Hub) -> Self {
147        self.hub = Some(hub);
148        self
149    }
150
151    /// Builds an actual clients
152    ///
153    /// Performs protocol negotiation and server handshake.
154    pub async fn build(self) -> Result<SignalRClient, BuilderError> {
155        let negotiate_response = self.get_server_supported_features().await?;
156
157        if !can_connect(negotiate_response) {
158            return Err(BuilderError::Negotiate {
159                source: NegotiateError::Unsupported,
160            });
161        }
162
163        let mut ws_handle = self.connect_websocket().await?;
164
165        let (tx, rx) = flume::bounded::<ClientMessage>(1);
166
167        let (transport_handle, client) = crate::new_client(tx, self.hub);
168
169        transport::websocket::handshake(&mut ws_handle)
170            .await
171            .map_err(|error| BuilderError::Transport { source: error })?;
172
173        let transport_future = transport::websocket::websocket_hub(ws_handle, transport_handle, rx);
174
175        tokio::spawn(transport_future);
176
177        event!(Level::DEBUG, "constructed client");
178
179        Ok(client)
180    }
181
182    async fn connect_websocket(
183        &self,
184    ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, BuilderError> {
185        let scheme = self.get_ws_scheme();
186        let domain_and_path = self.get_domain_with_path();
187        let query = self.get_query_string();
188
189        let url = format!("{}://{}?{}", scheme, domain_and_path, query);
190
191        let (ws_handle, _) = tokio_tungstenite::connect_async(url)
192            .await
193            .map_err(|error| BuilderError::Transport {
194                source: TransportError::Websocket { source: error },
195            })?;
196
197        Ok(ws_handle)
198    }
199
200    async fn get_server_supported_features(&self) -> Result<NegotiateResponseV0, NegotiateError> {
201        let negotiate_endpoint = format!(
202            "{}://{}/negotiate?{}",
203            self.get_http_scheme(),
204            self.get_domain_with_path(),
205            self.get_query_string()
206        );
207
208        let mut request = dusks_reqwest::Client::new().post(negotiate_endpoint);
209
210        request = match &self.auth {
211            Auth::None => request,
212            Auth::Basic { user, password } => request.basic_auth(user, password.clone()),
213            Auth::Bearer { token } => request.bearer_auth(token),
214            Auth::Res { user_id, token } => request.res_auth(user_id, token),
215        };
216
217        let http_response = request.send().await?.error_for_status()?;
218
219        let response: NegotiateResponseV0 = serde_json::from_str(&http_response.text().await?)?;
220
221        Ok(response)
222    }
223
224    fn get_query_string(&self) -> String {
225        if let Some(qs) = &self.query_string {
226            qs.clone()
227        } else {
228            Default::default()
229        }
230    }
231
232    fn get_http_scheme(&self) -> &str {
233        if self.secure_connection {
234            "https"
235        } else {
236            "http"
237        }
238    }
239
240    fn get_ws_scheme(&self) -> &str {
241        if self.secure_connection {
242            "wss"
243        } else {
244            "ws"
245        }
246    }
247
248    fn get_domain_with_path(&self) -> String {
249        match (&self.hub_path, &self.port) {
250            (None, None) => self.domain.clone(),
251            (None, Some(port)) => format!("{}:{}", self.domain, port),
252            (Some(path), None) => format!("{}/{}", self.domain, path),
253            (Some(path), Some(port)) => format!("{}:{}/{}", self.domain, port, path),
254        }
255    }
256}
257
258fn can_connect(negotiate_response: NegotiateResponseV0) -> bool {
259    negotiate_response
260        .available_transports
261        .iter()
262        .find(|i| i.transport == crate::protocol::WEB_SOCKET_TRANSPORT)
263        .and_then(|i| {
264            i.transfer_formats
265                .iter()
266                .find(|j| j.as_str() == crate::protocol::TEXT_TRANSPORT_FORMAT)
267        })
268        .is_some()
269}