signalrs_client_custom_auth/
lib.rs1#![deny(unsafe_code)]
92
93pub mod arguments;
94pub mod builder;
95pub mod error;
96pub mod hub;
97pub mod invocation;
98
99mod messages;
100mod protocol;
101mod stream_ext;
102mod transport;
103
104use self::{error::ClientError, invocation::InvocationBuilder};
105use crate::{
106 builder::ClientBuilder,
107 hub::Hub,
108 messages::ClientMessage,
109 protocol::{Completion, MessageType, StreamItem},
110};
111use flume::{r#async::RecvStream, Sender};
112use futures::{stream::FuturesUnordered, Stream, StreamExt};
113use serde::{de::DeserializeOwned, Deserialize};
114use std::{
115 collections::HashMap,
116 marker::PhantomData,
117 pin::Pin,
118 sync::Arc,
119 task::{Context, Poll},
120};
121use tokio::task::JoinHandle;
122use tracing::*;
123
/// User-facing client for invoking hub methods.
///
/// A client is configured through [`SignalRClient::builder`]; individual calls are
/// started with [`SignalRClient::method`].
pub struct SignalRClient {
    /// Registry of pending invocations; the same registry (a shared `Arc`) is held
    /// by the transport-side handle that routes replies into it.
    invocations: Invocations,
    /// Channel through which outgoing messages are handed to the transport.
    transport_handle: Sender<ClientMessage>,
}
128
/// Transport-side counterpart of [`SignalRClient`]: routes messages received from
/// the server either to the pending invocation they belong to or to the hub.
pub(crate) struct TransportClientHandle {
    /// Registry of pending invocations, shared with the [`SignalRClient`] created
    /// alongside this handle.
    invocations: Invocations,
    /// Handler for server-to-client invocations; when `None`, such messages are
    /// dropped without error.
    hub: Option<Hub>,
}
133
/// Registry of in-flight invocations keyed by invocation id.
///
/// Each entry is the sending half of the channel on which the waiting caller
/// receives stream items and the final completion. Cloning is cheap and every
/// clone refers to the same map (it is behind an `Arc`).
#[derive(Default, Clone)]
pub(crate) struct Invocations {
    invocations: Arc<std::sync::Mutex<HashMap<String, Sender<ClientMessageWrapper>>>>,
}
138
/// Minimal view of an incoming message: just the fields needed to route it.
#[derive(Deserialize)]
struct RoutingData {
    /// The message's `invocationId`; `None` when the message carries none.
    #[serde(rename = "invocationId")]
    pub invocation_id: Option<String>,
    /// The message's `type` discriminator.
    #[serde(rename = "type")]
    pub message_type: MessageType,
}
146
/// Instruction returned to the transport after processing incoming messages.
pub(crate) enum Command {
    /// No action required.
    None,
    /// A close message was received from the server.
    // NOTE(review): presumably the transport shuts the connection down on this —
    // confirm in the transport module.
    Close,
}
151
/// Stream of results produced by a streaming hub invocation.
///
/// Yields one `Result<T, ClientError>` per stream item until the invocation
/// completes. Dropping it deregisters the invocation and aborts the task that
/// uploads client-to-server streams.
pub struct ResponseStream<'a, T> {
    /// Receiver for the stream items and the completion routed by the transport.
    items: RecvStream<'a, ClientMessageWrapper>,
    /// Id under which this invocation is registered.
    invocation_id: String,
    /// Client whose invocation registry holds this stream's entry.
    client: &'a SignalRClient,
    /// Task uploading the client-to-server streams of this invocation.
    upload: JoinHandle<Result<(), ClientError>>,
    /// Records the item type without storing a value of it.
    _phantom: PhantomData<T>,
}
159
/// A routed message tagged with its type, so the receiving side knows which
/// shape to deserialize it into.
pub(crate) struct ClientMessageWrapper {
    /// Type of `message`; in this file wrappers are only built with
    /// `StreamItem` or `Completion`.
    message_type: MessageType,
    /// The raw message as received from the transport.
    message: ClientMessage,
}
164
165pub(crate) fn new_client(
166 transport_handle: Sender<ClientMessage>,
167 hub: Option<Hub>,
168) -> (TransportClientHandle, SignalRClient) {
169 let invocations = Invocations::default();
170 let transport_client_handle = TransportClientHandle::new(&invocations, hub);
171 let client = SignalRClient::new(&invocations, transport_handle);
172
173 (transport_client_handle, client)
174}
175
176impl TransportClientHandle {
177 pub(crate) fn new(invocations: &Invocations, hub: Option<Hub>) -> Self {
178 TransportClientHandle {
179 invocations: invocations.to_owned(),
180 hub,
181 }
182 }
183
184 pub(crate) fn receive_messages(&self, messages: ClientMessage) -> Result<Command, ClientError> {
185 for message in messages.split() {
186 self.receive_message(message)?;
189 }
190
191 Ok(Command::None)
192 }
193
194 pub(crate) fn receive_message(&self, message: ClientMessage) -> Result<Command, ClientError> {
195 let RoutingData {
196 invocation_id,
197 message_type,
198 } = message
199 .deserialize()
200 .map_err(ClientError::malformed_response)?;
201
202 return match message_type {
203 MessageType::Invocation => self.receive_invocation(message),
204 MessageType::Completion => self.receive_completion(invocation_id, message),
205 MessageType::StreamItem => self.receive_stream_item(invocation_id, message),
206 MessageType::Ping => self.receive_ping(),
207 MessageType::Close => self.receive_close(),
208 x => log_unsupported(x),
209 };
210
211 fn log_unsupported(message_type: MessageType) -> Result<Command, ClientError> {
212 warn!("received unsupported message type: {message_type}");
213 Ok(Command::None)
214 }
215 }
216
217 fn receive_invocation(&self, message: ClientMessage) -> Result<Command, ClientError> {
218 if let Some(hub) = &self.hub {
219 hub.call(message)?;
220 }
221
222 Ok(Command::None)
223 }
224
225 fn receive_completion(
226 &self,
227 invocation_id: Option<String>,
228 message: ClientMessage,
229 ) -> Result<Command, ClientError> {
230 let invocation_id = invocation_id.ok_or_else(|| {
231 ClientError::protocol_violation("received completion without invocation id")
232 })?;
233
234 let sender = self.invocations.remove_invocation(&invocation_id);
235
236 if let Some(sender) = sender {
237 if sender
238 .send(ClientMessageWrapper {
239 message_type: MessageType::Completion,
240 message,
241 })
242 .is_err()
243 {
244 warn!("received completion for a dropped invocation");
245 self.invocations.remove_invocation(&invocation_id);
246 }
247 } else {
248 warn!("received completion with unknown id");
249 }
250
251 Ok(Command::None)
252 }
253
254 fn receive_stream_item(
255 &self,
256 invocation_id: Option<String>,
257 message: ClientMessage,
258 ) -> Result<Command, ClientError> {
259 let invocation_id = invocation_id.ok_or_else(|| {
260 ClientError::protocol_violation("received stream item without stream id")
261 })?;
262
263 let sender = {
264 let invocations = self.invocations.invocations.lock().unwrap(); invocations.get(&invocation_id).cloned()
266 };
267
268 if let Some(sender) = sender {
269 if sender
270 .send(ClientMessageWrapper {
271 message_type: MessageType::StreamItem,
272 message,
273 })
274 .is_err()
275 {
276 warn!("received stream item for a dropped invocation");
277 self.invocations.remove_stream_invocation(&invocation_id);
278 }
279 } else {
280 warn!("received stream item with unknown id");
281 }
282
283 Ok(Command::None)
284 }
285
286 fn receive_ping(&self) -> Result<Command, ClientError> {
287 debug!("ping received");
288 Ok(Command::None)
289 }
290
291 fn receive_close(&self) -> Result<Command, ClientError> {
292 info!("close received");
293 Ok(Command::Close)
294 }
295}
296
297impl SignalRClient {
298 pub fn builder(domain: impl ToString) -> ClientBuilder {
299 ClientBuilder::new(domain)
300 }
301
302 pub fn method(&self, method: impl ToString) -> InvocationBuilder<'_> {
303 InvocationBuilder::new(self, method)
304 }
305
306 pub(crate) fn new(invocations: &Invocations, transport_handle: Sender<ClientMessage>) -> Self {
307 SignalRClient {
308 invocations: invocations.to_owned(),
309 transport_handle,
310 }
311 }
312
313 pub(crate) fn get_transport_handle(&self) -> Sender<ClientMessage> {
314 self.transport_handle.clone()
315 }
316
317 pub(crate) async fn invoke_option<T>(
318 &self,
319 invocation_id: String,
320 message: ClientMessage,
321 streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
322 ) -> Result<Option<T>, ClientError>
323 where
324 T: DeserializeOwned,
325 {
326 let (tx, rx) = flume::bounded::<ClientMessageWrapper>(1);
327 self.invocations
328 .insert_invocation(invocation_id.to_owned(), tx);
329
330 if let Err(error) = self.send_message(message).await {
331 self.invocations.remove_invocation(&invocation_id);
332 return Err(error);
333 }
334
335 let upload = tokio::spawn(Self::send_streams(self.transport_handle.clone(), streams));
336
337 let result = rx.recv_async().await;
338 upload.abort();
339
340 self.invocations.remove_invocation(&invocation_id);
341
342 let completion = result
343 .map_err(ClientError::no_response)
344 .and_then(|message| {
345 message
346 .message
347 .deserialize::<Completion<T>>()
348 .map_err(ClientError::malformed_response)
349 })?;
350
351 event!(Level::DEBUG, "response received");
352
353 if completion.is_result() {
354 Ok(Some(completion.unwrap_result()))
355 } else if completion.is_error() {
356 Err(ClientError::result(completion.unwrap_error()))
357 } else {
358 Ok(None)
359 }
360 }
361
362 pub(crate) async fn invoke_stream<T>(
363 &self,
364 invocation_id: String,
365 message: ClientMessage,
366 streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
367 ) -> Result<ResponseStream<'_, T>, ClientError>
368 where
369 T: DeserializeOwned,
370 {
371 let (tx, rx) = flume::bounded::<ClientMessageWrapper>(100);
372 self.invocations
373 .insert_stream_invocation(invocation_id.to_owned(), tx);
374
375 if let Err(error) = self.send_message(message).await {
376 self.invocations.remove_stream_invocation(&invocation_id);
377 return Err(error);
378 }
379
380 let handle = tokio::spawn(Self::send_streams(self.transport_handle.clone(), streams));
381
382 let response_stream = ResponseStream {
383 items: rx.into_stream(),
384 invocation_id,
385 client: self,
386 upload: handle,
387 _phantom: Default::default(),
388 };
389
390 Ok(response_stream)
391 }
392
393 pub(crate) async fn send_message(&self, message: ClientMessage) -> Result<(), ClientError> {
394 self.transport_handle
395 .send_async(message)
396 .await
397 .map_err(ClientError::transport)?;
398
399 event!(Level::DEBUG, "message sent");
400
401 Ok(())
402 }
403
404 pub(crate) async fn send_streams(
405 transport_handle: Sender<ClientMessage>,
406 streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
407 ) -> Result<(), ClientError> {
408 let mut futures = FuturesUnordered::new();
409 for stream in streams.into_iter() {
410 futures.push(Self::send_stream_internal(&transport_handle, stream));
411 }
412
413 while let Some(result) = futures.next().await {
414 result?;
415 }
416
417 Ok(())
418 }
419
420 async fn send_stream_internal(
421 transport_handle: &Sender<ClientMessage>,
422 mut stream: Box<dyn Stream<Item = ClientMessage> + Unpin + Send>,
423 ) -> Result<(), ClientError> {
424 while let Some(item) = stream.next().await {
425 transport_handle
426 .send_async(item)
427 .await
428 .map_err(ClientError::transport)?;
429
430 event!(Level::TRACE, "stream item sent");
431 }
432
433 event!(Level::DEBUG, "stream sent");
434
435 Ok(())
436 }
437}
438
439impl Invocations {
440 fn insert_invocation(&self, id: String, sender: flume::Sender<ClientMessageWrapper>) {
441 let mut invocations = self.invocations.lock().unwrap();
442 (*invocations).insert(id, sender);
443 }
444
445 fn insert_stream_invocation(&self, id: String, sender: flume::Sender<ClientMessageWrapper>) {
446 let mut invocations = self.invocations.lock().unwrap();
447 (*invocations).insert(id, sender);
448 }
449
450 pub fn remove_invocation(&self, id: &String) -> Option<flume::Sender<ClientMessageWrapper>> {
451 let mut invocations = self.invocations.lock().unwrap();
452 (*invocations).remove(id)
453 }
454
455 pub fn remove_stream_invocation(&self, id: &String) {
456 let mut invocations = self.invocations.lock().unwrap();
457 (*invocations).remove(id);
458 }
459}
460
461impl<'a, T> Drop for ResponseStream<'a, T> {
462 fn drop(&mut self) {
463 self.client
464 .invocations
465 .remove_stream_invocation(&self.invocation_id);
466
467 self.upload.abort();
468 }
469}
470
471impl<'a, T> Unpin for ResponseStream<'a, T> {}
473
474impl<'a, T> Stream for ResponseStream<'a, T>
475where
476 T: DeserializeOwned,
477{
478 type Item = Result<T, ClientError>;
479
480 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
481 match self.items.poll_next_unpin(cx) {
482 Poll::Ready(Some(message_wrapper)) => match message_wrapper.message_type {
483 MessageType::StreamItem => {
484 let item = message_wrapper
485 .message
486 .deserialize::<StreamItem<T>>()
487 .map_err(ClientError::malformed_response)
488 .map(|item| item.item);
489 Poll::Ready(Some(item))
490 }
491 MessageType::Completion => {
492 let deserialized = message_wrapper.message.deserialize::<Completion<T>>();
493
494 match deserialized {
495 Ok(completion) => {
496 if completion.is_error() {
497 error!(
498 "invocation ended with error: {}",
499 completion.unwrap_error()
500 );
501 }
502 }
503 Err(error) => error!("completion deserialization error: {}", error),
504 }
505
506 Poll::Ready(None)
507 }
508 _ => unreachable!(),
509 },
510 Poll::Ready(None) => Poll::Ready(None),
511 Poll::Pending => Poll::Pending,
512 }
513 }
514}