sos_protocol/network_client/
websocket.rs1use crate::{
3 network_client::{NetworkRetry, WebSocketRequest},
4 transfer::CancelReason,
5 Error, NetworkChangeEvent, Result, WireEncodeDecode,
6};
7use futures::{
8 stream::{Map, SplitStream},
9 Future, FutureExt, StreamExt,
10};
11use prost::bytes::Bytes;
12use sos_core::{AccountId, Origin};
13use sos_signer::ed25519::BoxedEd25519Signer;
14use std::pin::Pin;
15use tokio::{net::TcpStream, sync::watch, time::Duration};
16use tokio_tungstenite::{
17 connect_async,
18 tungstenite::{
19 self,
20 protocol::{
21 frame::{coding::CloseCode, Utf8Bytes},
22 CloseFrame, Message,
23 },
24 },
25 MaybeTlsStream, WebSocketStream,
26};
27
28use super::{bearer_prefix, encode_device_signature};
29
30#[derive(Clone)]
32pub struct ListenOptions {
33 pub(crate) connection_id: String,
39
40 pub(crate) retry: NetworkRetry,
42}
43
44impl ListenOptions {
45 pub fn new(connection_id: String) -> Result<Self> {
48 Ok(Self {
49 connection_id,
50 retry: NetworkRetry::new(16, 1000),
51 })
52 }
53
54 pub fn new_retry(
58 connection_id: String,
59 retry: NetworkRetry,
60 ) -> Result<Self> {
61 Ok(Self {
62 connection_id,
63 retry,
64 })
65 }
66}
67
68async fn request_bearer(
70 request: &mut WebSocketRequest,
71 device: &BoxedEd25519Signer,
72 connection_id: &str,
73) -> Result<String> {
74 let sign_url = request.uri.path();
76
77 let device_signature =
78 encode_device_signature(device.sign(sign_url.as_bytes()).await?)
79 .await?;
80 let auth = bearer_prefix(&device_signature);
81
82 request
83 .uri
84 .query_pairs_mut()
85 .append_pair("connection_id", connection_id);
86
87 Ok(auth)
88}
89
90pub type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
92
93pub async fn connect(
95 account_id: AccountId,
96 origin: Origin,
97 device: BoxedEd25519Signer,
98 connection_id: String,
99) -> Result<WsStream> {
100 let mut request = WebSocketRequest::new(
101 account_id,
102 origin.url(),
103 "api/v1/sync/changes",
104 )?;
105
106 let bearer =
107 request_bearer(&mut request, &device, &connection_id).await?;
108 request.set_bearer(bearer);
109
110 tracing::debug!(uri = %request.uri, "ws_client::connect");
111
112 let (ws_stream, _) = connect_async(request).await?;
113 Ok(ws_stream)
114}
115
116pub fn changes(
120 stream: WsStream,
121) -> Map<
122 SplitStream<WsStream>,
123 impl FnMut(
124 std::result::Result<Message, tungstenite::Error>,
125 ) -> Result<
126 Pin<Box<dyn Future<Output = Result<NetworkChangeEvent>> + Send>>,
127 >,
128> {
129 let (_, read) = stream.split();
130 read.map(
131 move |message| -> Result<
132 Pin<Box<dyn Future<Output = Result<NetworkChangeEvent>> + Send>>,
133 > {
134 match message {
135 Ok(message) => Ok(Box::pin(async move {
136 Ok(decode_notification(message).await?)
137 })),
138 Err(e) => Ok(Box::pin(async move { Err(e.into()) })),
139 }
140 },
141 )
142}
143
144async fn decode_notification(message: Message) -> Result<NetworkChangeEvent> {
145 match message {
146 Message::Binary(buffer) => {
147 let buf: Bytes = buffer.into();
148 let notification = NetworkChangeEvent::decode(buf).await?;
149 Ok(notification)
150 }
151 _ => Err(Error::NotBinaryWebsocketMessageType),
152 }
153}
154
155#[derive(Clone)]
157pub struct WebSocketHandle {
158 notify: watch::Sender<()>,
159 cancel_retry: watch::Sender<CancelReason>,
160}
161
162impl WebSocketHandle {
163 pub async fn close(&self) {
165 tracing::debug!(
166 receivers = %self.notify.receiver_count(),
167 "ws_client::close");
168 if let Err(error) = self.notify.send(()) {
169 tracing::error!(error = ?error);
170 }
171
172 if let Err(error) = self.cancel_retry.send(CancelReason::Closed) {
173 tracing::error!(error = ?error);
174 }
175 }
176}
177
178pub struct WebSocketChangeListener {
181 account_id: AccountId,
182 origin: Origin,
183 device: BoxedEd25519Signer,
184 options: ListenOptions,
185 shutdown: watch::Sender<()>,
186 cancel_retry: watch::Sender<CancelReason>,
187}
188
189impl WebSocketChangeListener {
190 pub fn new(
192 account_id: AccountId,
193 origin: Origin,
194 device: BoxedEd25519Signer,
195 options: ListenOptions,
196 ) -> Self {
197 let (shutdown, _) = watch::channel(());
198 let (cancel_retry, _) = watch::channel(Default::default());
199 Self {
200 account_id,
201 origin,
202 device,
203 options,
204 shutdown,
205 cancel_retry,
206 }
207 }
208
209 pub fn spawn<F>(
212 self,
213 handler: impl Fn(NetworkChangeEvent) -> F + Send + Sync + 'static,
214 ) -> WebSocketHandle
215 where
216 F: Future<Output = ()> + Send + 'static,
217 {
218 let notify = self.shutdown.clone();
219 let cancel_retry = self.cancel_retry.clone();
220 tokio::task::spawn(async move {
221 let _ = self.connect_loop(&handler).await;
222 });
223 WebSocketHandle {
224 notify,
225 cancel_retry,
226 }
227 }
228
229 async fn listen<F>(
230 &self,
231 mut stream: WsStream,
232 handler: &(impl Fn(NetworkChangeEvent) -> F + Send + Sync + 'static),
233 ) -> Result<()>
234 where
235 F: Future<Output = ()> + Send + 'static,
236 {
237 tracing::debug!("ws_client::connected");
238
239 let mut shutdown_rx = self.shutdown.subscribe();
240 loop {
241 futures::select! {
242 _ = shutdown_rx.changed().fuse() => {
243 tracing::debug!("ws_client::shutting_down");
244 if let Err(error) = stream.close(Some(CloseFrame {
246 code: CloseCode::Normal,
247 reason: Utf8Bytes::from_static("closed"),
248 })).await {
249 tracing::warn!(
250 error = ?error,
251 "ws_client::websocket::close_error",
252 );
253 }
254 tracing::debug!("ws_client::shutdown");
255 return Ok(());
256 }
257 message = stream.next().fuse() => {
258 if let Some(message) = message {
259 match message {
260 Ok(message) => {
261 let notification = decode_notification(
262 message).await?;
263 let future = handler(notification);
265 future.await;
266 }
267 Err(e) => {
268 tracing::error!(error = ?e);
269 break;
270 }
271 }
272 } else {
273 break;
274 }
275 }
276 }
277 }
278
279 tracing::debug!("ws_client::disconnected");
280 Ok(())
281 }
282
283 async fn stream(&self) -> Result<WsStream> {
284 connect(
285 self.account_id.clone(),
286 self.origin.clone(),
287 self.device.clone(),
288 self.options.connection_id.clone(),
289 )
290 .await
291 }
292
293 async fn connect_loop<F>(
294 &self,
295 handler: &(impl Fn(NetworkChangeEvent) -> F + Send + Sync + 'static),
296 ) -> Result<()>
297 where
298 F: Future<Output = ()> + Send + 'static,
299 {
300 let mut cancel_retry_rx = self.cancel_retry.subscribe();
301
302 loop {
303 tokio::select! {
304 _ = cancel_retry_rx.changed() => {
305 tracing::debug!("ws_client::retry_canceled");
306 return Ok(());
307 }
308 result = self.stream() => {
309 match result {
310 Ok(stream) => {
311 self.options.retry.reset();
312 if let Err(e) = self.listen(stream, handler).await {
313 tracing::error!(
314 error = ?e,
315 "ws_client::listen_error");
316 }
317 }
318 Err(e) => {
319 tracing::error!(
320 error = ?e,
321 "ws_client::connect_error");
322 let retries = self.options.retry.retries();
323 if self.options.retry.is_exhausted(retries) {
324 tracing::debug!(
325 maximum_retries = %self.options.retry.maximum_retries,
326 "wsclient::retry_attempts_exhausted");
327 return Ok(());
328 }
329 }
330 }
331 }
332 }
333
334 let retries = self.options.retry.retries();
335 let delay = self.options.retry.delay(retries)?;
336 let maximum = self.options.retry.maximum();
337 tracing::debug!(
338 retries = %retries,
339 delay = %delay,
340 maximum_retries = %maximum,
341 "ws_client::retry");
342
343 tokio::select! {
344 _ = tokio::time::sleep(Duration::from_millis(delay)) => {
345 self.options.retry.increment();
346 }
347 _ = cancel_retry_rx.changed() => {
348 tracing::debug!("ws_client::retry_canceled");
349 return Ok(());
350 }
351 }
352 }
353 }
354}