signalrs_client/
lib.rs

1//! # SignalR client
2//!
3//! SignalR is an open-source protocol that simplifies adding real-time web functionality to apps.
4//! Real-time web functionality enables server-side code to push content to clients instantly.
5//!
6//! This library is an open source implementation of this protocol's client in Rust.
//! SignalR was originally developed at Microsoft for the .NET ecosystem. Read more about it in the [official documentation][`offical documentation`].
8//!
//! In technical terms, it is an RPC framework with bidirectional streaming capabilities.
10//!
11//! ## Why SignalR
12//!
13//! ### Ergonomics
14//!
//! It allows bidirectional communication with an ergonomic programming model.
//! In cases where real-time communication is required, it provides an easy-to-use framework that abstracts away the underlying transport layer.
//! Getting WebSockets right is not an easy task.
18//!
19//! ### Integration with existing services
20//!
//! Since SignalR originated in the .NET ecosystem, there are services that expose SignalR endpoints. This library allows easy integration with them.
//! This might be especially true for internal tooling at companies that mostly use C#. Truth be told, this was the reason this library was created in the first place.
23//!  
24//! # Example
25//!
26//! ```rust, no_run
27//! use signalrs_client::SignalRClient;
28//!
29//! #[tokio::main]
30//! async fn main() -> anyhow::Result<()> {
31//!     let client = SignalRClient::builder("localhost")
32//!         .use_port(8080)
33//!         .use_hub("echo")
34//!         .build()
35//!         .await?;
36//!
37//!     let result = client
38//!         .method("echo")
39//!         .arg("message to the server")?
40//!         .invoke::<String>()
41//!         .await?;
42//!
43//! # Ok(())
44//! }
45//! ```
46//!
//! For more examples, see the examples folder in [`signalrs-client` examples].
48//!
49//! # Features of SignalR supported by `signalrs_client`
50//!
51//! SignalR as a protocol is defined by two documents:
//! - [HubProtocol][HubProtcol]
53//! - [TransportProtocol]
54//!
55//! Those documents describe details and full capabilities of the protocol.
56//!
57//! Unfortunately, `signalrs_client` only supports a subset of all features, especially regarding supported transports and message formats.
//! However, the set of supported features is big enough for this library to be usable in simple scenarios.
59//!
60//! ## Known supported features
61//!
62//! - calling hub methods with all possible return types
63//! - calling hub methods using both value and stream arguments
64//! - client-side hub supports value arguments
65//!
66//! ### Not (yet) supported features
67//!
//! - client-side hub with stream arguments
69//!
70//! ## Transport
71//!
//! SignalR defines three transports:
//! - WebSockets
//! - Server-Sent Events
//! - HTTP long polling
75//!
76//! This library only supports WebSockets now.
77//!
78//! ## Message encoding
79//!
80//! Two message encoding formats are allowed:
81//! - JSON
82//! - Message Pack
83//!
84//! This library only supports JSON now.
85//!
86//! [`offical documentation`]: https://learn.microsoft.com/en-us/aspnet/core/signalr/introduction?view=aspnetcore-7.0
87//! [HubProtcol]: https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md
88//! [TransportProtocol]: https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/TransportProtocols.md
89//! [`signalrs-client` examples]: https://github.com/szarykott/signalrs/tree/main/lib/signalrs-client/examples
90
91#![deny(unsafe_code)]
92
93pub mod arguments;
94pub mod builder;
95pub mod error;
96pub mod hub;
97pub mod invocation;
98
99mod messages;
100mod protocol;
101mod stream_ext;
102mod transport;
103
104use self::{error::ClientError, invocation::InvocationBuilder};
105use crate::{
106    builder::ClientBuilder,
107    hub::Hub,
108    messages::ClientMessage,
109    protocol::{Completion, MessageType, StreamItem},
110};
111use flume::{r#async::RecvStream, Sender};
112use futures::{stream::FuturesUnordered, Stream, StreamExt};
113use serde::{de::DeserializeOwned, Deserialize};
114use std::{
115    collections::HashMap,
116    marker::PhantomData,
117    pin::Pin,
118    sync::Arc,
119    task::{Context, Poll},
120};
121use tokio::task::JoinHandle;
122use tracing::*;
123
124pub struct SignalRClient {
125    invocations: Invocations,
126    transport_handle: Sender<ClientMessage>,
127}
128
129pub(crate) struct TransportClientHandle {
130    invocations: Invocations,
131    hub: Option<Hub>,
132}
133
134#[derive(Default, Clone)]
135pub(crate) struct Invocations {
136    invocations: Arc<std::sync::Mutex<HashMap<String, Sender<ClientMessageWrapper>>>>,
137}
138
139#[derive(Deserialize)]
140struct RoutingData {
141    #[serde(rename = "invocationId")]
142    pub invocation_id: Option<String>,
143    #[serde(rename = "type")]
144    pub message_type: MessageType,
145}
146
/// Outcome of handling incoming messages.
pub(crate) enum Command {
    /// No further action is required.
    None,
    /// A `Close` message was received from the server.
    Close,
}
151
152pub struct ResponseStream<'a, T> {
153    items: RecvStream<'a, ClientMessageWrapper>,
154    invocation_id: String,
155    client: &'a SignalRClient,
156    upload: JoinHandle<Result<(), ClientError>>,
157    _phantom: PhantomData<T>,
158}
159
160pub(crate) struct ClientMessageWrapper {
161    message_type: MessageType,
162    message: ClientMessage,
163}
164
165pub(crate) fn new_client(
166    transport_handle: Sender<ClientMessage>,
167    hub: Option<Hub>,
168) -> (TransportClientHandle, SignalRClient) {
169    let invocations = Invocations::default();
170    let transport_client_handle = TransportClientHandle::new(&invocations, hub);
171    let client = SignalRClient::new(&invocations, transport_handle);
172
173    (transport_client_handle, client)
174}
175
176impl TransportClientHandle {
177    pub(crate) fn new(invocations: &Invocations, hub: Option<Hub>) -> Self {
178        TransportClientHandle {
179            invocations: invocations.to_owned(),
180            hub,
181        }
182    }
183
184    pub(crate) fn receive_messages(&self, messages: ClientMessage) -> Result<Command, ClientError> {
185        for message in messages.split() {
186            // TODO: Add aggregate error subtype or log here
187            // TODO: service close properly
188            self.receive_message(message)?;
189        }
190
191        Ok(Command::None)
192    }
193
194    pub(crate) fn receive_message(&self, message: ClientMessage) -> Result<Command, ClientError> {
195        let RoutingData {
196            invocation_id,
197            message_type,
198        } = message
199            .deserialize()
200            .map_err(|error| ClientError::malformed_response(error))?;
201
202        return match message_type {
203            MessageType::Invocation => self.receive_invocation(message),
204            MessageType::Completion => self.receive_completion(invocation_id, message),
205            MessageType::StreamItem => self.receive_stream_item(invocation_id, message),
206            MessageType::Ping => self.receive_ping(),
207            MessageType::Close => self.receive_close(),
208            x => log_unsupported(x),
209        };
210
211        fn log_unsupported(message_type: MessageType) -> Result<Command, ClientError> {
212            warn!("received unsupported message type: {message_type}");
213            Ok(Command::None)
214        }
215    }
216
217    fn receive_invocation(&self, message: ClientMessage) -> Result<Command, ClientError> {
218        if let Some(hub) = &self.hub {
219            hub.call(message)?;
220        }
221
222        Ok(Command::None)
223    }
224
225    fn receive_completion(
226        &self,
227        invocation_id: Option<String>,
228        message: ClientMessage,
229    ) -> Result<Command, ClientError> {
230        let invocation_id = invocation_id.ok_or_else(|| {
231            ClientError::protocol_violation("received completion without invocation id")
232        })?;
233
234        let sender = self.invocations.remove_invocation(&invocation_id);
235
236        if let Some(sender) = sender {
237            if let Err(_) = sender.send(ClientMessageWrapper {
238                message_type: MessageType::Completion,
239                message,
240            }) {
241                warn!("received completion for a dropped invocation");
242                self.invocations.remove_invocation(&invocation_id);
243            }
244        } else {
245            warn!("received completion with unknown id");
246        }
247
248        Ok(Command::None)
249    }
250
251    fn receive_stream_item(
252        &self,
253        invocation_id: Option<String>,
254        message: ClientMessage,
255    ) -> Result<Command, ClientError> {
256        let invocation_id = invocation_id.ok_or_else(|| {
257            ClientError::protocol_violation("received stream item without stream id")
258        })?;
259
260        let sender = {
261            let invocations = self.invocations.invocations.lock().unwrap(); // TODO: can it be posioned, use parking_lot?
262            invocations
263                .get(&invocation_id)
264                .and_then(|sender| Some(sender.clone()))
265        };
266
267        if let Some(sender) = sender {
268            if let Err(_) = sender.send(ClientMessageWrapper {
269                message_type: MessageType::StreamItem,
270                message,
271            }) {
272                warn!("received stream item for a dropped invocation");
273                self.invocations.remove_stream_invocation(&invocation_id);
274            }
275        } else {
276            warn!("received stream item with unknown id");
277        }
278
279        Ok(Command::None)
280    }
281
282    fn receive_ping(&self) -> Result<Command, ClientError> {
283        debug!("ping received");
284        Ok(Command::None)
285    }
286
287    fn receive_close(&self) -> Result<Command, ClientError> {
288        info!("close received");
289        Ok(Command::Close)
290    }
291}
292
293impl SignalRClient {
294    pub fn builder(domain: impl ToString) -> ClientBuilder {
295        ClientBuilder::new(domain)
296    }
297
298    pub fn method<'a>(&'a self, method: impl ToString) -> InvocationBuilder<'a> {
299        InvocationBuilder::new(self, method)
300    }
301
302    pub(crate) fn new(invocations: &Invocations, transport_handle: Sender<ClientMessage>) -> Self {
303        SignalRClient {
304            invocations: invocations.to_owned(),
305            transport_handle,
306        }
307    }
308
309    pub(crate) fn get_transport_handle(&self) -> Sender<ClientMessage> {
310        self.transport_handle.clone()
311    }
312
313    pub(crate) async fn invoke_option<T>(
314        &self,
315        invocation_id: String,
316        message: ClientMessage,
317        streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
318    ) -> Result<Option<T>, ClientError>
319    where
320        T: DeserializeOwned,
321    {
322        let (tx, rx) = flume::bounded::<ClientMessageWrapper>(1);
323        self.invocations
324            .insert_invocation(invocation_id.to_owned(), tx);
325
326        if let Err(error) = self.send_message(message).await {
327            self.invocations.remove_invocation(&invocation_id);
328            return Err(error);
329        }
330
331        let upload = tokio::spawn(Self::send_streams(self.transport_handle.clone(), streams));
332
333        let result = rx.recv_async().await;
334        upload.abort();
335
336        self.invocations.remove_invocation(&invocation_id);
337
338        let completion = result
339            .map_err(|error| ClientError::no_response(error))
340            .and_then(|message| {
341                message
342                    .message
343                    .deserialize::<Completion<T>>()
344                    .map_err(|error| ClientError::malformed_response(error))
345            })?;
346
347        event!(Level::DEBUG, "response received");
348
349        if completion.is_result() {
350            Ok(Some(completion.unwrap_result()))
351        } else if completion.is_error() {
352            Err(ClientError::result(completion.unwrap_error()))
353        } else {
354            Ok(None)
355        }
356    }
357
358    pub(crate) async fn invoke_stream<'a, T>(
359        &'a self,
360        invocation_id: String,
361        message: ClientMessage,
362        streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
363    ) -> Result<ResponseStream<'a, T>, ClientError>
364    where
365        T: DeserializeOwned,
366    {
367        let (tx, rx) = flume::bounded::<ClientMessageWrapper>(100);
368        self.invocations
369            .insert_stream_invocation(invocation_id.to_owned(), tx);
370
371        if let Err(error) = self.send_message(message).await {
372            self.invocations.remove_stream_invocation(&invocation_id);
373            return Err(error);
374        }
375
376        let handle = tokio::spawn(Self::send_streams(self.transport_handle.clone(), streams));
377
378        let response_stream = ResponseStream {
379            items: rx.into_stream(),
380            invocation_id,
381            client: &self,
382            upload: handle,
383            _phantom: Default::default(),
384        };
385
386        Ok(response_stream)
387    }
388
389    pub(crate) async fn send_message(&self, message: ClientMessage) -> Result<(), ClientError> {
390        self.transport_handle
391            .send_async(message)
392            .await
393            .map_err(|e| ClientError::transport(e))?;
394
395        event!(Level::DEBUG, "message sent");
396
397        Ok(())
398    }
399
400    pub(crate) async fn send_streams(
401        transport_handle: Sender<ClientMessage>,
402        streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
403    ) -> Result<(), ClientError> {
404        let mut futures = FuturesUnordered::new();
405        for stream in streams.into_iter() {
406            futures.push(Self::send_stream_internal(&transport_handle, stream));
407        }
408
409        while let Some(result) = futures.next().await {
410            result?;
411        }
412
413        Ok(())
414    }
415
416    async fn send_stream_internal(
417        transport_handle: &Sender<ClientMessage>,
418        mut stream: Box<dyn Stream<Item = ClientMessage> + Unpin + Send>,
419    ) -> Result<(), ClientError> {
420        while let Some(item) = stream.next().await {
421            transport_handle
422                .send_async(item)
423                .await
424                .map_err(|e| ClientError::transport(e))?;
425
426            event!(Level::TRACE, "stream item sent");
427        }
428
429        event!(Level::DEBUG, "stream sent");
430
431        Ok(())
432    }
433}
434
435impl Invocations {
436    fn insert_invocation(&self, id: String, sender: flume::Sender<ClientMessageWrapper>) {
437        let mut invocations = self.invocations.lock().unwrap();
438        (*invocations).insert(id, sender);
439    }
440
441    fn insert_stream_invocation(&self, id: String, sender: flume::Sender<ClientMessageWrapper>) {
442        let mut invocations = self.invocations.lock().unwrap();
443        (*invocations).insert(id, sender);
444    }
445
446    pub fn remove_invocation(&self, id: &String) -> Option<flume::Sender<ClientMessageWrapper>> {
447        let mut invocations = self.invocations.lock().unwrap();
448        (*invocations).remove(id)
449    }
450
451    pub fn remove_stream_invocation(&self, id: &String) {
452        let mut invocations = self.invocations.lock().unwrap();
453        (*invocations).remove(id);
454    }
455}
456
457impl<'a, T> Drop for ResponseStream<'a, T> {
458    fn drop(&mut self) {
459        self.client
460            .invocations
461            .remove_stream_invocation(&self.invocation_id);
462
463        self.upload.abort();
464    }
465}
466
467// took this hack from: https://users.rust-lang.org/t/cannot-assign-to-data-in-a-dereference-of-pin-mut-myfutureimpl-t/70887
468impl<'a, T> Unpin for ResponseStream<'a, T> {}
469
470impl<'a, T> Stream for ResponseStream<'a, T>
471where
472    T: DeserializeOwned,
473{
474    type Item = Result<T, ClientError>;
475
476    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
477        match self.items.poll_next_unpin(cx) {
478            Poll::Ready(Some(message_wrapper)) => match message_wrapper.message_type {
479                MessageType::StreamItem => {
480                    let item = message_wrapper
481                        .message
482                        .deserialize::<StreamItem<T>>()
483                        .map_err(|e| ClientError::malformed_response(e))
484                        .and_then(|item| Ok(item.item));
485                    Poll::Ready(Some(item))
486                }
487                MessageType::Completion => {
488                    let deserialized = message_wrapper.message.deserialize::<Completion<T>>();
489
490                    match deserialized {
491                        Ok(completion) => {
492                            if completion.is_error() {
493                                error!(
494                                    "invocation ended with error: {}",
495                                    completion.unwrap_error()
496                                );
497                            }
498                        }
499                        Err(error) => error!("completion deserialization error: {}", error),
500                    }
501
502                    Poll::Ready(None)
503                }
504                _ => unreachable!(),
505            },
506            Poll::Ready(None) => Poll::Ready(None),
507            Poll::Pending => Poll::Pending,
508        }
509    }
510}