1#![deny(unsafe_code)]
92
93pub mod arguments;
94pub mod builder;
95pub mod error;
96pub mod hub;
97pub mod invocation;
98
99mod messages;
100mod protocol;
101mod stream_ext;
102mod transport;
103
104use self::{error::ClientError, invocation::InvocationBuilder};
105use crate::{
106 builder::ClientBuilder,
107 hub::Hub,
108 messages::ClientMessage,
109 protocol::{Completion, MessageType, StreamItem},
110};
111use flume::{r#async::RecvStream, Sender};
112use futures::{stream::FuturesUnordered, Stream, StreamExt};
113use serde::{de::DeserializeOwned, Deserialize};
114use std::{
115 collections::HashMap,
116 marker::PhantomData,
117 pin::Pin,
118 sync::Arc,
119 task::{Context, Poll},
120};
121use tokio::task::JoinHandle;
122use tracing::*;
123
124pub struct SignalRClient {
125 invocations: Invocations,
126 transport_handle: Sender<ClientMessage>,
127}
128
129pub(crate) struct TransportClientHandle {
130 invocations: Invocations,
131 hub: Option<Hub>,
132}
133
134#[derive(Default, Clone)]
135pub(crate) struct Invocations {
136 invocations: Arc<std::sync::Mutex<HashMap<String, Sender<ClientMessageWrapper>>>>,
137}
138
139#[derive(Deserialize)]
140struct RoutingData {
141 #[serde(rename = "invocationId")]
142 pub invocation_id: Option<String>,
143 #[serde(rename = "type")]
144 pub message_type: MessageType,
145}
146
/// Instruction handed back to the transport after incoming messages have been
/// processed.
///
/// Derives the common traits so callers can compare (`==`), copy and log commands.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Command {
    /// Nothing further to do.
    None,
    /// The server sent a close message.
    Close,
}
151
152pub struct ResponseStream<'a, T> {
153 items: RecvStream<'a, ClientMessageWrapper>,
154 invocation_id: String,
155 client: &'a SignalRClient,
156 upload: JoinHandle<Result<(), ClientError>>,
157 _phantom: PhantomData<T>,
158}
159
160pub(crate) struct ClientMessageWrapper {
161 message_type: MessageType,
162 message: ClientMessage,
163}
164
165pub(crate) fn new_client(
166 transport_handle: Sender<ClientMessage>,
167 hub: Option<Hub>,
168) -> (TransportClientHandle, SignalRClient) {
169 let invocations = Invocations::default();
170 let transport_client_handle = TransportClientHandle::new(&invocations, hub);
171 let client = SignalRClient::new(&invocations, transport_handle);
172
173 (transport_client_handle, client)
174}
175
176impl TransportClientHandle {
177 pub(crate) fn new(invocations: &Invocations, hub: Option<Hub>) -> Self {
178 TransportClientHandle {
179 invocations: invocations.to_owned(),
180 hub,
181 }
182 }
183
184 pub(crate) fn receive_messages(&self, messages: ClientMessage) -> Result<Command, ClientError> {
185 for message in messages.split() {
186 self.receive_message(message)?;
189 }
190
191 Ok(Command::None)
192 }
193
194 pub(crate) fn receive_message(&self, message: ClientMessage) -> Result<Command, ClientError> {
195 let RoutingData {
196 invocation_id,
197 message_type,
198 } = message
199 .deserialize()
200 .map_err(|error| ClientError::malformed_response(error))?;
201
202 return match message_type {
203 MessageType::Invocation => self.receive_invocation(message),
204 MessageType::Completion => self.receive_completion(invocation_id, message),
205 MessageType::StreamItem => self.receive_stream_item(invocation_id, message),
206 MessageType::Ping => self.receive_ping(),
207 MessageType::Close => self.receive_close(),
208 x => log_unsupported(x),
209 };
210
211 fn log_unsupported(message_type: MessageType) -> Result<Command, ClientError> {
212 warn!("received unsupported message type: {message_type}");
213 Ok(Command::None)
214 }
215 }
216
217 fn receive_invocation(&self, message: ClientMessage) -> Result<Command, ClientError> {
218 if let Some(hub) = &self.hub {
219 hub.call(message)?;
220 }
221
222 Ok(Command::None)
223 }
224
225 fn receive_completion(
226 &self,
227 invocation_id: Option<String>,
228 message: ClientMessage,
229 ) -> Result<Command, ClientError> {
230 let invocation_id = invocation_id.ok_or_else(|| {
231 ClientError::protocol_violation("received completion without invocation id")
232 })?;
233
234 let sender = self.invocations.remove_invocation(&invocation_id);
235
236 if let Some(sender) = sender {
237 if let Err(_) = sender.send(ClientMessageWrapper {
238 message_type: MessageType::Completion,
239 message,
240 }) {
241 warn!("received completion for a dropped invocation");
242 self.invocations.remove_invocation(&invocation_id);
243 }
244 } else {
245 warn!("received completion with unknown id");
246 }
247
248 Ok(Command::None)
249 }
250
251 fn receive_stream_item(
252 &self,
253 invocation_id: Option<String>,
254 message: ClientMessage,
255 ) -> Result<Command, ClientError> {
256 let invocation_id = invocation_id.ok_or_else(|| {
257 ClientError::protocol_violation("received stream item without stream id")
258 })?;
259
260 let sender = {
261 let invocations = self.invocations.invocations.lock().unwrap(); invocations
263 .get(&invocation_id)
264 .and_then(|sender| Some(sender.clone()))
265 };
266
267 if let Some(sender) = sender {
268 if let Err(_) = sender.send(ClientMessageWrapper {
269 message_type: MessageType::StreamItem,
270 message,
271 }) {
272 warn!("received stream item for a dropped invocation");
273 self.invocations.remove_stream_invocation(&invocation_id);
274 }
275 } else {
276 warn!("received stream item with unknown id");
277 }
278
279 Ok(Command::None)
280 }
281
282 fn receive_ping(&self) -> Result<Command, ClientError> {
283 debug!("ping received");
284 Ok(Command::None)
285 }
286
287 fn receive_close(&self) -> Result<Command, ClientError> {
288 info!("close received");
289 Ok(Command::Close)
290 }
291}
292
293impl SignalRClient {
294 pub fn builder(domain: impl ToString) -> ClientBuilder {
295 ClientBuilder::new(domain)
296 }
297
298 pub fn method<'a>(&'a self, method: impl ToString) -> InvocationBuilder<'a> {
299 InvocationBuilder::new(self, method)
300 }
301
302 pub(crate) fn new(invocations: &Invocations, transport_handle: Sender<ClientMessage>) -> Self {
303 SignalRClient {
304 invocations: invocations.to_owned(),
305 transport_handle,
306 }
307 }
308
309 pub(crate) fn get_transport_handle(&self) -> Sender<ClientMessage> {
310 self.transport_handle.clone()
311 }
312
313 pub(crate) async fn invoke_option<T>(
314 &self,
315 invocation_id: String,
316 message: ClientMessage,
317 streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
318 ) -> Result<Option<T>, ClientError>
319 where
320 T: DeserializeOwned,
321 {
322 let (tx, rx) = flume::bounded::<ClientMessageWrapper>(1);
323 self.invocations
324 .insert_invocation(invocation_id.to_owned(), tx);
325
326 if let Err(error) = self.send_message(message).await {
327 self.invocations.remove_invocation(&invocation_id);
328 return Err(error);
329 }
330
331 let upload = tokio::spawn(Self::send_streams(self.transport_handle.clone(), streams));
332
333 let result = rx.recv_async().await;
334 upload.abort();
335
336 self.invocations.remove_invocation(&invocation_id);
337
338 let completion = result
339 .map_err(|error| ClientError::no_response(error))
340 .and_then(|message| {
341 message
342 .message
343 .deserialize::<Completion<T>>()
344 .map_err(|error| ClientError::malformed_response(error))
345 })?;
346
347 event!(Level::DEBUG, "response received");
348
349 if completion.is_result() {
350 Ok(Some(completion.unwrap_result()))
351 } else if completion.is_error() {
352 Err(ClientError::result(completion.unwrap_error()))
353 } else {
354 Ok(None)
355 }
356 }
357
358 pub(crate) async fn invoke_stream<'a, T>(
359 &'a self,
360 invocation_id: String,
361 message: ClientMessage,
362 streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
363 ) -> Result<ResponseStream<'a, T>, ClientError>
364 where
365 T: DeserializeOwned,
366 {
367 let (tx, rx) = flume::bounded::<ClientMessageWrapper>(100);
368 self.invocations
369 .insert_stream_invocation(invocation_id.to_owned(), tx);
370
371 if let Err(error) = self.send_message(message).await {
372 self.invocations.remove_stream_invocation(&invocation_id);
373 return Err(error);
374 }
375
376 let handle = tokio::spawn(Self::send_streams(self.transport_handle.clone(), streams));
377
378 let response_stream = ResponseStream {
379 items: rx.into_stream(),
380 invocation_id,
381 client: &self,
382 upload: handle,
383 _phantom: Default::default(),
384 };
385
386 Ok(response_stream)
387 }
388
389 pub(crate) async fn send_message(&self, message: ClientMessage) -> Result<(), ClientError> {
390 self.transport_handle
391 .send_async(message)
392 .await
393 .map_err(|e| ClientError::transport(e))?;
394
395 event!(Level::DEBUG, "message sent");
396
397 Ok(())
398 }
399
400 pub(crate) async fn send_streams(
401 transport_handle: Sender<ClientMessage>,
402 streams: Vec<Box<dyn Stream<Item = ClientMessage> + Unpin + Send>>,
403 ) -> Result<(), ClientError> {
404 let mut futures = FuturesUnordered::new();
405 for stream in streams.into_iter() {
406 futures.push(Self::send_stream_internal(&transport_handle, stream));
407 }
408
409 while let Some(result) = futures.next().await {
410 result?;
411 }
412
413 Ok(())
414 }
415
416 async fn send_stream_internal(
417 transport_handle: &Sender<ClientMessage>,
418 mut stream: Box<dyn Stream<Item = ClientMessage> + Unpin + Send>,
419 ) -> Result<(), ClientError> {
420 while let Some(item) = stream.next().await {
421 transport_handle
422 .send_async(item)
423 .await
424 .map_err(|e| ClientError::transport(e))?;
425
426 event!(Level::TRACE, "stream item sent");
427 }
428
429 event!(Level::DEBUG, "stream sent");
430
431 Ok(())
432 }
433}
434
435impl Invocations {
436 fn insert_invocation(&self, id: String, sender: flume::Sender<ClientMessageWrapper>) {
437 let mut invocations = self.invocations.lock().unwrap();
438 (*invocations).insert(id, sender);
439 }
440
441 fn insert_stream_invocation(&self, id: String, sender: flume::Sender<ClientMessageWrapper>) {
442 let mut invocations = self.invocations.lock().unwrap();
443 (*invocations).insert(id, sender);
444 }
445
446 pub fn remove_invocation(&self, id: &String) -> Option<flume::Sender<ClientMessageWrapper>> {
447 let mut invocations = self.invocations.lock().unwrap();
448 (*invocations).remove(id)
449 }
450
451 pub fn remove_stream_invocation(&self, id: &String) {
452 let mut invocations = self.invocations.lock().unwrap();
453 (*invocations).remove(id);
454 }
455}
456
457impl<'a, T> Drop for ResponseStream<'a, T> {
458 fn drop(&mut self) {
459 self.client
460 .invocations
461 .remove_stream_invocation(&self.invocation_id);
462
463 self.upload.abort();
464 }
465}
466
467impl<'a, T> Unpin for ResponseStream<'a, T> {}
469
470impl<'a, T> Stream for ResponseStream<'a, T>
471where
472 T: DeserializeOwned,
473{
474 type Item = Result<T, ClientError>;
475
476 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
477 match self.items.poll_next_unpin(cx) {
478 Poll::Ready(Some(message_wrapper)) => match message_wrapper.message_type {
479 MessageType::StreamItem => {
480 let item = message_wrapper
481 .message
482 .deserialize::<StreamItem<T>>()
483 .map_err(|e| ClientError::malformed_response(e))
484 .and_then(|item| Ok(item.item));
485 Poll::Ready(Some(item))
486 }
487 MessageType::Completion => {
488 let deserialized = message_wrapper.message.deserialize::<Completion<T>>();
489
490 match deserialized {
491 Ok(completion) => {
492 if completion.is_error() {
493 error!(
494 "invocation ended with error: {}",
495 completion.unwrap_error()
496 );
497 }
498 }
499 Err(error) => error!("completion deserialization error: {}", error),
500 }
501
502 Poll::Ready(None)
503 }
504 _ => unreachable!(),
505 },
506 Poll::Ready(None) => Poll::Ready(None),
507 Poll::Pending => Poll::Pending,
508 }
509 }
510}