sos_protocol/network_client/
websocket.rs

1//! Listen for change notifications on a websocket connection.
2use crate::{
3    network_client::{NetworkRetry, WebSocketRequest},
4    transfer::CancelReason,
5    Error, NetworkChangeEvent, Result, WireEncodeDecode,
6};
7use futures::{
8    stream::{Map, SplitStream},
9    Future, FutureExt, StreamExt,
10};
11use prost::bytes::Bytes;
12use sos_core::{AccountId, Origin};
13use sos_signer::ed25519::BoxedEd25519Signer;
14use std::pin::Pin;
15use tokio::{net::TcpStream, sync::watch, time::Duration};
16use tokio_tungstenite::{
17    connect_async,
18    tungstenite::{
19        self,
20        protocol::{
21            frame::{coding::CloseCode, Utf8Bytes},
22            CloseFrame, Message,
23        },
24    },
25    MaybeTlsStream, WebSocketStream,
26};
27
28use super::{bearer_prefix, encode_device_signature};
29
30/// Options used when listening for change notifications.
31#[derive(Clone)]
32pub struct ListenOptions {
33    /// Identifier for this connection.
34    ///
35    /// Should match the identifier used by the RPC
36    /// client so the server can ignore sending change notifications
37    /// to the caller.
38    pub(crate) connection_id: String,
39
40    /// Network retry state.
41    pub(crate) retry: NetworkRetry,
42}
43
44impl ListenOptions {
45    /// Create new listen options using the default retry
46    /// configuration.
47    pub fn new(connection_id: String) -> Result<Self> {
48        Ok(Self {
49            connection_id,
50            retry: NetworkRetry::new(16, 1000),
51        })
52    }
53
54    /// Create new listen options using a custom retry
55    /// configuration.
56    ///
57    pub fn new_retry(
58        connection_id: String,
59        retry: NetworkRetry,
60    ) -> Result<Self> {
61        Ok(Self {
62            connection_id,
63            retry,
64        })
65    }
66}
67
68/// Get the URI for a websocket changes connection.
69async fn request_bearer(
70    request: &mut WebSocketRequest,
71    device: &BoxedEd25519Signer,
72    connection_id: &str,
73) -> Result<String> {
74    //let endpoint = changes_endpoint_url(remote)?;
75    let sign_url = request.uri.path();
76
77    let device_signature =
78        encode_device_signature(device.sign(sign_url.as_bytes()).await?)
79            .await?;
80    let auth = bearer_prefix(&device_signature);
81
82    request
83        .uri
84        .query_pairs_mut()
85        .append_pair("connection_id", connection_id);
86
87    Ok(auth)
88}
89
90/// Type of stream created for websocket connections.
91pub type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
92
93/// Create the websocket connection and listen for events.
94pub async fn connect(
95    account_id: AccountId,
96    origin: Origin,
97    device: BoxedEd25519Signer,
98    connection_id: String,
99) -> Result<WsStream> {
100    let mut request = WebSocketRequest::new(
101        account_id,
102        origin.url(),
103        "api/v1/sync/changes",
104    )?;
105
106    let bearer =
107        request_bearer(&mut request, &device, &connection_id).await?;
108    request.set_bearer(bearer);
109
110    tracing::debug!(uri = %request.uri, "ws_client::connect");
111
112    let (ws_stream, _) = connect_async(request).await?;
113    Ok(ws_stream)
114}
115
116/// Read change messages from a websocket stream,
117/// and decode to change notifications that can
118/// be processed.
119pub fn changes(
120    stream: WsStream,
121) -> Map<
122    SplitStream<WsStream>,
123    impl FnMut(
124        std::result::Result<Message, tungstenite::Error>,
125    ) -> Result<
126        Pin<Box<dyn Future<Output = Result<NetworkChangeEvent>> + Send>>,
127    >,
128> {
129    let (_, read) = stream.split();
130    read.map(
131        move |message| -> Result<
132            Pin<Box<dyn Future<Output = Result<NetworkChangeEvent>> + Send>>,
133        > {
134            match message {
135                Ok(message) => Ok(Box::pin(async move {
136                    Ok(decode_notification(message).await?)
137                })),
138                Err(e) => Ok(Box::pin(async move { Err(e.into()) })),
139            }
140        },
141    )
142}
143
144async fn decode_notification(message: Message) -> Result<NetworkChangeEvent> {
145    match message {
146        Message::Binary(buffer) => {
147            let buf: Bytes = buffer.into();
148            let notification = NetworkChangeEvent::decode(buf).await?;
149            Ok(notification)
150        }
151        _ => Err(Error::NotBinaryWebsocketMessageType),
152    }
153}
154
155/// Handle to a websocket listener.
156#[derive(Clone)]
157pub struct WebSocketHandle {
158    notify: watch::Sender<()>,
159    cancel_retry: watch::Sender<CancelReason>,
160}
161
162impl WebSocketHandle {
163    /// Close the websocket.
164    pub async fn close(&self) {
165        tracing::debug!(
166            receivers = %self.notify.receiver_count(),
167            "ws_client::close");
168        if let Err(error) = self.notify.send(()) {
169            tracing::error!(error = ?error);
170        }
171
172        if let Err(error) = self.cancel_retry.send(CancelReason::Closed) {
173            tracing::error!(error = ?error);
174        }
175    }
176}
177
178/// Creates a websocket that listens for changes emitted by a remote
179/// server and invokes a handler with the change notifications.
180pub struct WebSocketChangeListener {
181    account_id: AccountId,
182    origin: Origin,
183    device: BoxedEd25519Signer,
184    options: ListenOptions,
185    shutdown: watch::Sender<()>,
186    cancel_retry: watch::Sender<CancelReason>,
187}
188
189impl WebSocketChangeListener {
190    /// Create a new websocket changes listener.
191    pub fn new(
192        account_id: AccountId,
193        origin: Origin,
194        device: BoxedEd25519Signer,
195        options: ListenOptions,
196    ) -> Self {
197        let (shutdown, _) = watch::channel(());
198        let (cancel_retry, _) = watch::channel(Default::default());
199        Self {
200            account_id,
201            origin,
202            device,
203            options,
204            shutdown,
205            cancel_retry,
206        }
207    }
208
209    /// Spawn a task to listen for changes notifications and invoke
210    /// the handler with the notifications.
211    pub fn spawn<F>(
212        self,
213        handler: impl Fn(NetworkChangeEvent) -> F + Send + Sync + 'static,
214    ) -> WebSocketHandle
215    where
216        F: Future<Output = ()> + Send + 'static,
217    {
218        let notify = self.shutdown.clone();
219        let cancel_retry = self.cancel_retry.clone();
220        tokio::task::spawn(async move {
221            let _ = self.connect_loop(&handler).await;
222        });
223        WebSocketHandle {
224            notify,
225            cancel_retry,
226        }
227    }
228
229    async fn listen<F>(
230        &self,
231        mut stream: WsStream,
232        handler: &(impl Fn(NetworkChangeEvent) -> F + Send + Sync + 'static),
233    ) -> Result<()>
234    where
235        F: Future<Output = ()> + Send + 'static,
236    {
237        tracing::debug!("ws_client::connected");
238
239        let mut shutdown_rx = self.shutdown.subscribe();
240        loop {
241            futures::select! {
242                _ = shutdown_rx.changed().fuse() => {
243                    tracing::debug!("ws_client::shutting_down");
244                    // Perform close handshake
245                    if let Err(error) = stream.close(Some(CloseFrame {
246                        code: CloseCode::Normal,
247                        reason: Utf8Bytes::from_static("closed"),
248                    })).await {
249                        tracing::warn!(
250                            error = ?error,
251                            "ws_client::websocket::close_error",
252                        );
253                    }
254                    tracing::debug!("ws_client::shutdown");
255                    return Ok(());
256                }
257                message = stream.next().fuse() => {
258                    if let Some(message) = message {
259                        match message {
260                            Ok(message) => {
261                                let notification = decode_notification(
262                                    message).await?;
263                                // Call the handler
264                                let future = handler(notification);
265                                future.await;
266                            }
267                            Err(e) => {
268                                tracing::error!(error = ?e);
269                                break;
270                            }
271                        }
272                    } else {
273                        break;
274                    }
275                }
276            }
277        }
278
279        tracing::debug!("ws_client::disconnected");
280        Ok(())
281    }
282
283    async fn stream(&self) -> Result<WsStream> {
284        connect(
285            self.account_id.clone(),
286            self.origin.clone(),
287            self.device.clone(),
288            self.options.connection_id.clone(),
289        )
290        .await
291    }
292
293    async fn connect_loop<F>(
294        &self,
295        handler: &(impl Fn(NetworkChangeEvent) -> F + Send + Sync + 'static),
296    ) -> Result<()>
297    where
298        F: Future<Output = ()> + Send + 'static,
299    {
300        let mut cancel_retry_rx = self.cancel_retry.subscribe();
301
302        loop {
303            tokio::select! {
304                _ = cancel_retry_rx.changed() => {
305                    tracing::debug!("ws_client::retry_canceled");
306                    return Ok(());
307                }
308                result = self.stream() => {
309                    match result {
310                        Ok(stream) => {
311                            self.options.retry.reset();
312                            if let Err(e) = self.listen(stream, handler).await {
313                                tracing::error!(
314                                    error = ?e,
315                                    "ws_client::listen_error");
316                            }
317                        }
318                        Err(e) => {
319                            tracing::error!(
320                                error = ?e,
321                                "ws_client::connect_error");
322                            let retries = self.options.retry.retries();
323                            if self.options.retry.is_exhausted(retries) {
324                                tracing::debug!(
325                                    maximum_retries = %self.options.retry.maximum_retries,
326                                    "wsclient::retry_attempts_exhausted");
327                                return Ok(());
328                            }
329                        }
330                    }
331                }
332            }
333
334            let retries = self.options.retry.retries();
335            let delay = self.options.retry.delay(retries)?;
336            let maximum = self.options.retry.maximum();
337            tracing::debug!(
338              retries = %retries,
339              delay = %delay,
340              maximum_retries = %maximum,
341              "ws_client::retry");
342
343            tokio::select! {
344                _ = tokio::time::sleep(Duration::from_millis(delay)) => {
345                  self.options.retry.increment();
346                }
347                _ = cancel_retry_rx.changed() => {
348                    tracing::debug!("ws_client::retry_canceled");
349                    return Ok(());
350                }
351            }
352        }
353    }
354}