signalrs_client_custom_auth/
lib.rs

1//! # SignalR client
2//!
3//! SignalR is an open-source protocol that simplifies adding real-time web functionality to apps.
4//! Real-time web functionality enables server-side code to push content to clients instantly.
5//!
//! This library is an open-source Rust implementation of this protocol's client side.
//! The protocol was originally developed at Microsoft for the .NET ecosystem. Read more about it in the [official documentation][`offical documentation`].
8//!
//! In technical terms, it is an RPC framework with bidirectional streaming capabilities.
10//!
11//! ## Why SignalR
12//!
13//! ### Ergonomics
14//!
//! It allows bidirectional communication with an ergonomic programming model.
//! In cases where real-time communication is required, it provides an easy-to-use framework that abstracts away the underlying transport layer.
//! Getting WebSockets right is not an easy task.
18//!
19//! ### Integration with existing services
20//!
//! Since SignalR originated in the .NET ecosystem, there are services that expose SignalR endpoints. This library allows easy integration with them.
//! This is especially true for internal tooling at companies that mostly use C#. Truth be told, that was the reason this library was created in the first place.
23//!  
24//! # Example
25//!
26//! ```rust, no_run
27//! use signalrs_client::SignalRClient;
28//!
29//! #[tokio::main]
30//! async fn main() -> anyhow::Result<()> {
31//!     let client = SignalRClient::builder("localhost")
32//!         .use_port(8080)
33//!         .use_hub("echo")
34//!         .build()
35//!         .await?;
36//!
37//!     let result = client
38//!         .method("echo")
39//!         .arg("message to the server")?
40//!         .invoke::<String>()
41//!         .await?;
42//!
43//! # Ok(())
44//! }
45//! ```
46//!
//! For more examples, see the examples folder in [`signalrs-client` examples].
48//!
49//! # Features of SignalR supported by `signalrs_client`
50//!
51//! SignalR as a protocol is defined by two documents:
//! - [HubProtocol][HubProtcol]
//! - [TransportProtocol]
54//!
55//! Those documents describe details and full capabilities of the protocol.
56//!
//! Unfortunately, `signalrs_client` supports only a subset of all features, especially regarding transports and message formats.
//! However, the set of supported features is big enough for this library to be usable in simple scenarios.
59//!
60//! ## Known supported features
61//!
62//! - calling hub methods with all possible return types
63//! - calling hub methods using both value and stream arguments
64//! - client-side hub supports value arguments
65//!
66//! ### Not (yet) supported features
67//!
//! - client-side hub with stream arguments
69//!
70//! ## Transport
71//!
//! SignalR defines three transports:
//! - WebSockets
//! - Server-Sent Events
//! - HTTP long polling
75//!
76//! This library only supports WebSockets now.
77//!
78//! ## Message encoding
79//!
80//! Two message encoding formats are allowed:
81//! - JSON
//! - MessagePack
83//!
84//! This library only supports JSON now.
85//!
86//! [`offical documentation`]: https://learn.microsoft.com/en-us/aspnet/core/signalr/introduction?view=aspnetcore-7.0
87//! [HubProtcol]: https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md
88//! [TransportProtocol]: https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/TransportProtocols.md
89//! [`signalrs-client` examples]: https://github.com/szarykott/signalrs/tree/main/lib/signalrs-client/examples
90
91#![deny(unsafe_code)]
92
93pub mod arguments;
94pub mod builder;
95pub mod error;
96pub mod hub;
97pub mod invocation;
98
99mod messages;
100mod protocol;
101mod stream_ext;
102mod transport;
103
104use self::{error::ClientError, invocation::InvocationBuilder};
105use crate::{
106    builder::ClientBuilder,
107    hub::Hub,
108    messages::ClientMessage,
109    protocol::{Completion, MessageType, StreamItem},
110};
111use flume::{r#async::RecvStream, Sender};
112use futures::{stream::FuturesUnordered, Stream, StreamExt};
113use serde::{de::DeserializeOwned, Deserialize};
114use std::{
115    collections::HashMap,
116    marker::PhantomData,
117    pin::Pin,
118    sync::Arc,
119    task::{Context, Poll},
120};
121use tokio::task::JoinHandle;
122use tracing::*;
123
124pub struct SignalRClient {
125    invocations: Invocations,
126    transport_handle: Sender<ClientMessage>,
127}
128
129pub(crate) struct TransportClientHandle {
130    invocations: Invocations,
131    hub: Option<Hub>,
132}
133
134#[derive(Default, Clone)]
135pub(crate) struct Invocations {
136    invocations: Arc<std::sync::Mutex<HashMap<String, Sender<ClientMessageWrapper>>>>,
137}
138
139#[derive(Deserialize)]
140struct RoutingData {
141    #[serde(rename = "invocationId")]
142    pub invocation_id: Option<String>,
143    #[serde(rename = "type")]
144    pub message_type: MessageType,
145}
146
/// Instruction returned by the message-routing functions of `TransportClientHandle`.
///
/// It is a plain unit-only enum, so the standard comparison/copy/debug traits are
/// derived to let callers compare and log it directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Command {
    /// No action is requested.
    None,
    /// A close message was received from the server.
    Close,
}
151
152pub struct ResponseStream<'a, T> {
153    items: RecvStream<'a, ClientMessageWrapper>,
154    invocation_id: String,
155    client: &'a SignalRClient,
156    upload: JoinHandle<Result<(), ClientError>>,
157    _phantom: PhantomData<T>,
158}
159
160pub(crate) struct ClientMessageWrapper {
161    message_type: MessageType,
162    message: ClientMessage,
163}
164
165pub(crate) fn new_client(
166    transport_handle: Sender<ClientMessage>,
167    hub: Option<Hub>,
168) -> (TransportClientHandle, SignalRClient) {
169    let invocations = Invocations::default();
170    let transport_client_handle = TransportClientHandle::new(&invocations, hub);
171    let client = SignalRClient::new(&invocations, transport_handle);
172
173    (transport_client_handle, client)
174}
175
176impl TransportClientHandle {
177    pub(crate) fn new(invocations: &Invocations, hub: Option<Hub>) -> Self {
178        TransportClientHandle {
179            invocations: invocations.to_owned(),
180            hub,
181        }
182    }
183
184    pub(crate) fn receive_messages(&self, messages: ClientMessage) -> Result<Command, ClientError> {
185        for message in messages.split() {
186            // TODO: Add aggregate error subtype or log here
187            // TODO: service close properly
188            self.receive_message(message)?;
189        }
190
191        Ok(Command::None)
192    }
193
194    pub(crate) fn receive_message(&self, message: ClientMessage) -> Result<Command, ClientError> {
195        let RoutingData {
196            invocation_id,
197            message_type,
198        } = message
199            .deserialize()
200            .map_err(ClientError::malformed_response)?;
201
202        return match message_type {
203            MessageType::Invocation => self.receive_invocation(message),
204            MessageType::Completion => self.receive_completion(invocation_id, message),
205            MessageType::StreamItem => self.receive_stream_item(invocation_id, message),
206            MessageType::Ping => self.receive_ping(),
207            MessageType::Close => self.receive_close(),
208            x => log_unsupported(x),
209        };
210
211        fn log_unsupported(message_type: MessageType) -> Result<Command, ClientError> {
212            warn!("received unsupported message type: {message_type}");
213            Ok(Command::None)
214        }
215    }
216
217    fn receive_invocation(&self, message: ClientMessage) -> Result<Command, ClientError> {
218        if let Some(hub) = &self.hub {
219            hub.call(message)?;
220        }
221
222        Ok(Command::None)
223    }
224
225    fn receive_completion(
226        &self,
227        invocation_id: Option<String>,
228        message: ClientMessage,
229    ) -> Result<Command, ClientError> {
230        let invocation_id = invocation_id.ok_or_else(|| {
231            ClientError::protocol_violation("received completion without invocation id")
232        })?;
233
234        let sender = self.invocations.remove_invocation(&invocation_id);
235
236        if let Some(sender) = sender {
237            if sender
238                .send(ClientMessageWrapper {
239                    message_type: MessageType::Completion,
240                    message,
241                })
242                .is_err()
243            {
244                warn!("received completion for a dropped invocation");
245                self.invocations.remove_invocation(&invocation_id);
246            }
247        } else {
248            warn!("received completion with unknown id");
249        }
250
251        Ok(Command::None)
252    }
253
254    fn receive_stream_item(
255        &self,
256        invocation_id: Option<String>,
257        message: ClientMessage,
258    ) -> Result<Command, ClientError> {
259        let invocation_id = invocation_id.ok_or_else(|| {
260            ClientError::protocol_violation("received stream item without stream id")
261        })?;
262
263        let sender = {
264            let invocations = self.invocations.invocations.lock().unwrap(); // TODO: can it be posioned, use parking_lot?
265            invocations.get(&invocation_id).cloned()
266        };
267
268        if let Some(sender) = sender {
269            if sender
270                .send(ClientMessageWrapper {
271                    message_type: MessageType::StreamItem,
272                    message,
273                })
274                .is_err()
275            {
276                warn!("received stream item for a dropped invocation");
277                self.invocations.remove_stream_invocation(&invocation_id);
278            }
279        } else {
280            warn!("received stream item with unknown id");
281        }
282
283        Ok(Command::None)
284    }
285
286    fn receive_ping(&self) -> Result<Command, ClientError> {
287        debug!("ping received");
288        Ok(Command::None)
289    }
290
291    fn receive_close(&self) -> Result<Command, ClientError> {
292        info!("close received");
293        Ok(Command::Close)
294    }
295}
296
297impl SignalRClient {
298    pub fn builder(domain: impl ToString) -> ClientBuilder {
299        ClientBuilder::new(domain)
300    }
301
302    pub fn method(&self, method: impl ToString) -> InvocationBuilder<'_> {
303        InvocationBuilder::new(self, method)
304    }
305
306    pub(crate) fn new(invocations: &Invocations, transport_handle: Sender<ClientMessage>) -> Self {
307        SignalRClient {
308            invocations: invocations.to_owned(),
309            transport_handle,
310        }
311    }
312
313    pub(crate) fn get_transport_handle(&self) -> Sender<ClientMessage> {
314        self.transport_handle.clone()
315    }
316
317    pub(crate) async fn invoke_option<T>(
318        &self,
319        invocation_id: String,
320        message: ClientMessage,
321        streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
322    ) -> Result<Option<T>, ClientError>
323    where
324        T: DeserializeOwned,
325    {
326        let (tx, rx) = flume::bounded::<ClientMessageWrapper>(1);
327        self.invocations
328            .insert_invocation(invocation_id.to_owned(), tx);
329
330        if let Err(error) = self.send_message(message).await {
331            self.invocations.remove_invocation(&invocation_id);
332            return Err(error);
333        }
334
335        let upload = tokio::spawn(Self::send_streams(self.transport_handle.clone(), streams));
336
337        let result = rx.recv_async().await;
338        upload.abort();
339
340        self.invocations.remove_invocation(&invocation_id);
341
342        let completion = result
343            .map_err(ClientError::no_response)
344            .and_then(|message| {
345                message
346                    .message
347                    .deserialize::<Completion<T>>()
348                    .map_err(ClientError::malformed_response)
349            })?;
350
351        event!(Level::DEBUG, "response received");
352
353        if completion.is_result() {
354            Ok(Some(completion.unwrap_result()))
355        } else if completion.is_error() {
356            Err(ClientError::result(completion.unwrap_error()))
357        } else {
358            Ok(None)
359        }
360    }
361
362    pub(crate) async fn invoke_stream<T>(
363        &self,
364        invocation_id: String,
365        message: ClientMessage,
366        streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
367    ) -> Result<ResponseStream<'_, T>, ClientError>
368    where
369        T: DeserializeOwned,
370    {
371        let (tx, rx) = flume::bounded::<ClientMessageWrapper>(100);
372        self.invocations
373            .insert_stream_invocation(invocation_id.to_owned(), tx);
374
375        if let Err(error) = self.send_message(message).await {
376            self.invocations.remove_stream_invocation(&invocation_id);
377            return Err(error);
378        }
379
380        let handle = tokio::spawn(Self::send_streams(self.transport_handle.clone(), streams));
381
382        let response_stream = ResponseStream {
383            items: rx.into_stream(),
384            invocation_id,
385            client: self,
386            upload: handle,
387            _phantom: Default::default(),
388        };
389
390        Ok(response_stream)
391    }
392
393    pub(crate) async fn send_message(&self, message: ClientMessage) -> Result<(), ClientError> {
394        self.transport_handle
395            .send_async(message)
396            .await
397            .map_err(ClientError::transport)?;
398
399        event!(Level::DEBUG, "message sent");
400
401        Ok(())
402    }
403
404    pub(crate) async fn send_streams(
405        transport_handle: Sender<ClientMessage>,
406        streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
407    ) -> Result<(), ClientError> {
408        let mut futures = FuturesUnordered::new();
409        for stream in streams.into_iter() {
410            futures.push(Self::send_stream_internal(&transport_handle, stream));
411        }
412
413        while let Some(result) = futures.next().await {
414            result?;
415        }
416
417        Ok(())
418    }
419
420    async fn send_stream_internal(
421        transport_handle: &Sender<ClientMessage>,
422        mut stream: Box<dyn Stream<Item = ClientMessage> + Unpin + Send>,
423    ) -> Result<(), ClientError> {
424        while let Some(item) = stream.next().await {
425            transport_handle
426                .send_async(item)
427                .await
428                .map_err(ClientError::transport)?;
429
430            event!(Level::TRACE, "stream item sent");
431        }
432
433        event!(Level::DEBUG, "stream sent");
434
435        Ok(())
436    }
437}
438
439impl Invocations {
440    fn insert_invocation(&self, id: String, sender: flume::Sender<ClientMessageWrapper>) {
441        let mut invocations = self.invocations.lock().unwrap();
442        (*invocations).insert(id, sender);
443    }
444
445    fn insert_stream_invocation(&self, id: String, sender: flume::Sender<ClientMessageWrapper>) {
446        let mut invocations = self.invocations.lock().unwrap();
447        (*invocations).insert(id, sender);
448    }
449
450    pub fn remove_invocation(&self, id: &String) -> Option<flume::Sender<ClientMessageWrapper>> {
451        let mut invocations = self.invocations.lock().unwrap();
452        (*invocations).remove(id)
453    }
454
455    pub fn remove_stream_invocation(&self, id: &String) {
456        let mut invocations = self.invocations.lock().unwrap();
457        (*invocations).remove(id);
458    }
459}
460
461impl<'a, T> Drop for ResponseStream<'a, T> {
462    fn drop(&mut self) {
463        self.client
464            .invocations
465            .remove_stream_invocation(&self.invocation_id);
466
467        self.upload.abort();
468    }
469}
470
471// took this hack from: https://users.rust-lang.org/t/cannot-assign-to-data-in-a-dereference-of-pin-mut-myfutureimpl-t/70887
472impl<'a, T> Unpin for ResponseStream<'a, T> {}
473
474impl<'a, T> Stream for ResponseStream<'a, T>
475where
476    T: DeserializeOwned,
477{
478    type Item = Result<T, ClientError>;
479
480    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
481        match self.items.poll_next_unpin(cx) {
482            Poll::Ready(Some(message_wrapper)) => match message_wrapper.message_type {
483                MessageType::StreamItem => {
484                    let item = message_wrapper
485                        .message
486                        .deserialize::<StreamItem<T>>()
487                        .map_err(ClientError::malformed_response)
488                        .map(|item| item.item);
489                    Poll::Ready(Some(item))
490                }
491                MessageType::Completion => {
492                    let deserialized = message_wrapper.message.deserialize::<Completion<T>>();
493
494                    match deserialized {
495                        Ok(completion) => {
496                            if completion.is_error() {
497                                error!(
498                                    "invocation ended with error: {}",
499                                    completion.unwrap_error()
500                                );
501                            }
502                        }
503                        Err(error) => error!("completion deserialization error: {}", error),
504                    }
505
506                    Poll::Ready(None)
507                }
508                _ => unreachable!(),
509            },
510            Poll::Ready(None) => Poll::Ready(None),
511            Poll::Pending => Poll::Pending,
512        }
513    }
514}