signalrs_client_custom_auth/
builder.rs1use super::{hub::Hub, transport, SignalRClient};
4use crate::{
5 messages::ClientMessage, protocol::NegotiateResponseV0, transport::error::TransportError,
6};
7use thiserror::Error;
8use tokio::net::TcpStream;
9use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
10use tracing::*;
11
12pub struct ClientBuilder {
32 domain: String,
33 hub: Option<Hub>,
34 auth: Auth,
35 secure_connection: bool,
36 port: Option<usize>,
37 query_string: Option<String>,
38 hub_path: Option<String>,
39}
40
/// Credentials the builder attaches to the server's negotiate request.
///
/// `Debug` is intentionally not derived so that passwords and tokens cannot end
/// up in logs by accident.
#[derive(Clone, PartialEq, Eq)]
pub enum Auth {
    /// Send the request without credentials.
    None,
    /// Credentials passed to the HTTP client's `basic_auth`; the password is optional.
    Basic {
        user: String,
        password: Option<String>,
    },
    /// Token passed to the HTTP client's `bearer_auth`.
    Bearer {
        token: String,
    },
    /// User id and token passed to the HTTP client's `res_auth`.
    /// NOTE(review): the wire format is defined by `dusks_reqwest::res_auth` — confirm there.
    Res {
        user_id: String,
        token: String,
    },
}
56
57#[derive(Error, Debug)]
59pub enum BuilderError {
60 #[error("negotiate error")]
61 Negotiate {
62 #[from]
63 source: NegotiateError,
64 },
65 #[error("invalid {0} url")]
66 Url(String),
67 #[error("transport error")]
68 Transport {
69 #[from]
70 source: TransportError,
71 },
72}
73
74#[derive(Error, Debug)]
76pub enum NegotiateError {
77 #[error("request error")]
78 Request {
79 #[from]
80 source: dusks_reqwest::Error,
81 },
82 #[error("deserialization error")]
83 Deserialization {
84 #[from]
85 source: serde_json::Error,
86 },
87 #[error("server does not support requested features")]
88 Unsupported,
89}
90
91impl ClientBuilder {
92 pub fn new(domain: impl ToString) -> Self {
93 ClientBuilder {
94 domain: domain.to_string(),
95 hub: None,
96 auth: Auth::None,
97 secure_connection: true,
98 port: None,
99 query_string: None,
100 hub_path: None,
101 }
102 }
103
104 pub fn use_port(mut self, port: usize) -> Self {
106 self.port = Some(port);
107 self
108 }
109
110 pub fn use_unencrypted_connection(mut self) -> Self {
115 self.secure_connection = false;
116 self
117 }
118
119 pub fn use_authentication(mut self, auth: Auth) -> Self {
121 self.auth = auth;
122 self
123 }
124
125 pub fn use_query_string(mut self, query: String) -> Self {
131 self.query_string = Some(query);
132 self
133 }
134
135 pub fn use_hub(mut self, hub: impl ToString) -> Self {
139 self.hub_path = Some(hub.to_string());
140 self
141 }
142
143 pub fn with_client_hub(mut self, hub: Hub) -> Self {
147 self.hub = Some(hub);
148 self
149 }
150
151 pub async fn build(self) -> Result<SignalRClient, BuilderError> {
155 let negotiate_response = self.get_server_supported_features().await?;
156
157 if !can_connect(negotiate_response) {
158 return Err(BuilderError::Negotiate {
159 source: NegotiateError::Unsupported,
160 });
161 }
162
163 let mut ws_handle = self.connect_websocket().await?;
164
165 let (tx, rx) = flume::bounded::<ClientMessage>(1);
166
167 let (transport_handle, client) = crate::new_client(tx, self.hub);
168
169 transport::websocket::handshake(&mut ws_handle)
170 .await
171 .map_err(|error| BuilderError::Transport { source: error })?;
172
173 let transport_future = transport::websocket::websocket_hub(ws_handle, transport_handle, rx);
174
175 tokio::spawn(transport_future);
176
177 event!(Level::DEBUG, "constructed client");
178
179 Ok(client)
180 }
181
182 async fn connect_websocket(
183 &self,
184 ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, BuilderError> {
185 let scheme = self.get_ws_scheme();
186 let domain_and_path = self.get_domain_with_path();
187 let query = self.get_query_string();
188
189 let url = format!("{}://{}?{}", scheme, domain_and_path, query);
190
191 let (ws_handle, _) = tokio_tungstenite::connect_async(url)
192 .await
193 .map_err(|error| BuilderError::Transport {
194 source: TransportError::Websocket { source: error },
195 })?;
196
197 Ok(ws_handle)
198 }
199
200 async fn get_server_supported_features(&self) -> Result<NegotiateResponseV0, NegotiateError> {
201 let negotiate_endpoint = format!(
202 "{}://{}/negotiate?{}",
203 self.get_http_scheme(),
204 self.get_domain_with_path(),
205 self.get_query_string()
206 );
207
208 let mut request = dusks_reqwest::Client::new().post(negotiate_endpoint);
209
210 request = match &self.auth {
211 Auth::None => request,
212 Auth::Basic { user, password } => request.basic_auth(user, password.clone()),
213 Auth::Bearer { token } => request.bearer_auth(token),
214 Auth::Res { user_id, token } => request.res_auth(user_id, token),
215 };
216
217 let http_response = request.send().await?.error_for_status()?;
218
219 let response: NegotiateResponseV0 = serde_json::from_str(&http_response.text().await?)?;
220
221 Ok(response)
222 }
223
224 fn get_query_string(&self) -> String {
225 if let Some(qs) = &self.query_string {
226 qs.clone()
227 } else {
228 Default::default()
229 }
230 }
231
232 fn get_http_scheme(&self) -> &str {
233 if self.secure_connection {
234 "https"
235 } else {
236 "http"
237 }
238 }
239
240 fn get_ws_scheme(&self) -> &str {
241 if self.secure_connection {
242 "wss"
243 } else {
244 "ws"
245 }
246 }
247
248 fn get_domain_with_path(&self) -> String {
249 match (&self.hub_path, &self.port) {
250 (None, None) => self.domain.clone(),
251 (None, Some(port)) => format!("{}:{}", self.domain, port),
252 (Some(path), None) => format!("{}/{}", self.domain, path),
253 (Some(path), Some(port)) => format!("{}:{}/{}", self.domain, port, path),
254 }
255 }
256}
257
258fn can_connect(negotiate_response: NegotiateResponseV0) -> bool {
259 negotiate_response
260 .available_transports
261 .iter()
262 .find(|i| i.transport == crate::protocol::WEB_SOCKET_TRANSPORT)
263 .and_then(|i| {
264 i.transfer_formats
265 .iter()
266 .find(|j| j.as_str() == crate::protocol::TEXT_TRANSPORT_FORMAT)
267 })
268 .is_some()
269}