sos_protocol/network_client/
websocket.rs1use crate::{
3 network_client::{NetworkRetry, WebSocketRequest},
4 transfer::CancelReason,
5 ChangeNotification, Error, Result, WireEncodeDecode,
6};
7use futures::{
8 stream::{Map, SplitStream},
9 Future, FutureExt, StreamExt,
10};
11use prost::bytes::Bytes;
12use sos_core::{AccountId, Origin};
13use sos_signer::ed25519::BoxedEd25519Signer;
14use std::pin::Pin;
15use tokio::{net::TcpStream, sync::watch, time::Duration};
16use tokio_tungstenite::{
17 connect_async,
18 tungstenite::{
19 self,
20 protocol::{
21 frame::{coding::CloseCode, Utf8Bytes},
22 CloseFrame, Message,
23 },
24 },
25 MaybeTlsStream, WebSocketStream,
26};
27
28use super::{bearer_prefix, encode_device_signature};
29
30#[derive(Clone)]
32pub struct ListenOptions {
33 pub(crate) connection_id: String,
39
40 pub(crate) retry: NetworkRetry,
42}
43
44impl ListenOptions {
45 pub fn new(connection_id: String) -> Result<Self> {
48 Ok(Self {
49 connection_id,
50 retry: NetworkRetry::new(16, 1000),
51 })
52 }
53
54 pub fn new_retry(
58 connection_id: String,
59 retry: NetworkRetry,
60 ) -> Result<Self> {
61 Ok(Self {
62 connection_id,
63 retry,
64 })
65 }
66}
67
68async fn request_bearer(
70 request: &mut WebSocketRequest,
71 device: &BoxedEd25519Signer,
72 connection_id: &str,
73) -> Result<String> {
74 let sign_url = request.uri.path();
76
77 let device_signature =
78 encode_device_signature(device.sign(sign_url.as_bytes()).await?)
79 .await?;
80 let auth = bearer_prefix(&device_signature);
81
82 request
83 .uri
84 .query_pairs_mut()
85 .append_pair("connection_id", connection_id);
86
87 Ok(auth)
88}
89
90pub type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
92
93pub async fn connect(
95 account_id: AccountId,
96 origin: Origin,
97 device: BoxedEd25519Signer,
98 connection_id: String,
99) -> Result<WsStream> {
100 let mut request = WebSocketRequest::new(
101 account_id,
102 origin.url(),
103 "api/v1/sync/changes",
104 )?;
105
106 let bearer =
107 request_bearer(&mut request, &device, &connection_id).await?;
108 request.set_bearer(bearer);
109
110 tracing::debug!(uri = %request.uri, "ws_client::connect");
111
112 let (ws_stream, _) = connect_async(request).await?;
113 Ok(ws_stream)
114}
115
116pub fn changes(
120 stream: WsStream,
121) -> Map<
122 SplitStream<WsStream>,
123 impl FnMut(
124 std::result::Result<Message, tungstenite::Error>,
125 ) -> Result<
126 Pin<Box<dyn Future<Output = Result<ChangeNotification>> + Send>>,
127 >,
128> {
129 let (_, read) = stream.split();
130 read.map(
131 move |message| -> Result<
132 Pin<Box<dyn Future<Output = Result<ChangeNotification>> + Send>>,
133 > {
134 match message {
135 Ok(message) => Ok(Box::pin(async move {
136 Ok(decode_notification(message).await?)
137 })),
138 Err(e) => Ok(Box::pin(async move { Err(e.into()) })),
139 }
140 },
141 )
142}
143
144async fn decode_notification(message: Message) -> Result<ChangeNotification> {
145 match message {
146 Message::Binary(buffer) => {
147 let buf: Bytes = buffer.into();
148 let notification = ChangeNotification::decode(buf).await?;
149 Ok(notification)
150 }
151 _ => Err(Error::NotBinaryWebsocketMessageType),
152 }
153}
154
155#[derive(Clone)]
157pub struct WebSocketHandle {
158 notify: watch::Sender<()>,
159 cancel_retry: watch::Sender<CancelReason>,
160}
161
162impl WebSocketHandle {
163 pub async fn close(&self) {
165 tracing::debug!(
166 receivers = %self.notify.receiver_count(),
167 "ws_client::close");
168 if let Err(error) = self.notify.send(()) {
169 tracing::error!(error = ?error);
170 }
171
172 if let Err(error) = self.cancel_retry.send(CancelReason::Closed) {
173 tracing::error!(error = ?error);
174 }
175 }
176}
177
178pub struct WebSocketChangeListener {
181 account_id: AccountId,
182 origin: Origin,
183 device: BoxedEd25519Signer,
184 options: ListenOptions,
185 shutdown: watch::Sender<()>,
186 cancel_retry: watch::Sender<CancelReason>,
187}
188
189impl WebSocketChangeListener {
190 pub fn new(
192 account_id: AccountId,
193 origin: Origin,
194 device: BoxedEd25519Signer,
195 options: ListenOptions,
196 ) -> Self {
197 let (shutdown, _) = watch::channel(());
198 let (cancel_retry, _) = watch::channel(Default::default());
199 Self {
200 account_id,
201 origin,
202 device,
203 options,
204 shutdown,
205 cancel_retry,
206 }
207 }
208
209 pub fn spawn<F>(
212 self,
213 handler: impl Fn(ChangeNotification) -> F + Send + Sync + 'static,
214 ) -> WebSocketHandle
215 where
216 F: Future<Output = ()> + Send + 'static,
217 {
218 let notify = self.shutdown.clone();
219 let cancel_retry = self.cancel_retry.clone();
220 tokio::task::spawn(async move {
221 let _ = self.connect_loop(&handler).await;
222 });
223 WebSocketHandle {
224 notify,
225 cancel_retry,
226 }
227 }
228
229 async fn listen<F>(
230 &self,
231 mut stream: WsStream,
232 handler: &(impl Fn(ChangeNotification) -> F + Send + Sync + 'static),
233 ) -> Result<()>
234 where
235 F: Future<Output = ()> + Send + 'static,
236 {
237 tracing::debug!("ws_client::connected");
238
239 let mut shutdown_rx = self.shutdown.subscribe();
240 loop {
241 futures::select! {
242 _ = shutdown_rx.changed().fuse() => {
243 tracing::debug!("ws_client::shutting_down");
244 if let Err(error) = stream.close(Some(CloseFrame {
246 code: CloseCode::Normal,
247 reason: Utf8Bytes::from_static("closed"),
248 })).await {
249 tracing::warn!(error = ?error);
250 }
251 tracing::debug!("ws_client::shutdown");
252 return Ok(());
253 }
254 message = stream.next().fuse() => {
255 if let Some(message) = message {
256 match message {
257 Ok(message) => {
258 let notification = decode_notification(
259 message).await?;
260 let future = handler(notification);
262 future.await;
263 }
264 Err(e) => {
265 tracing::error!(error = ?e);
266 break;
267 }
268 }
269 } else {
270 break;
271 }
272 }
273 }
274 }
275
276 tracing::debug!("ws_client::disconnected");
277 Ok(())
278 }
279
280 async fn stream(&self) -> Result<WsStream> {
281 connect(
282 self.account_id.clone(),
283 self.origin.clone(),
284 self.device.clone(),
285 self.options.connection_id.clone(),
286 )
287 .await
288 }
289
290 async fn connect_loop<F>(
291 &self,
292 handler: &(impl Fn(ChangeNotification) -> F + Send + Sync + 'static),
293 ) -> Result<()>
294 where
295 F: Future<Output = ()> + Send + 'static,
296 {
297 let mut cancel_retry_rx = self.cancel_retry.subscribe();
298
299 loop {
300 tokio::select! {
301 _ = cancel_retry_rx.changed() => {
302 tracing::debug!("ws_client::retry_canceled");
303 return Ok(());
304 }
305 result = self.stream() => {
306 match result {
307 Ok(stream) => {
308 self.options.retry.reset();
309 if let Err(e) = self.listen(stream, handler).await {
310 tracing::error!(
311 error = ?e,
312 "ws_client::listen_error");
313 }
314 }
315 Err(e) => {
316 tracing::error!(
317 error = ?e,
318 "ws_client::connect_error");
319 let retries = self.options.retry.retries();
320 if self.options.retry.is_exhausted(retries) {
321 tracing::debug!(
322 maximum_retries = %self.options.retry.maximum_retries,
323 "wsclient::retry_attempts_exhausted");
324 return Ok(());
325 }
326 }
327 }
328 }
329 }
330
331 let retries = self.options.retry.retries();
332 let delay = self.options.retry.delay(retries)?;
333 let maximum = self.options.retry.maximum();
334 tracing::debug!(
335 retries = %retries,
336 delay = %delay,
337 maximum_retries = %maximum,
338 "ws_client::retry");
339
340 tokio::select! {
341 _ = tokio::time::sleep(Duration::from_millis(delay)) => {
342 self.options.retry.increment();
343 }
344 _ = cancel_retry_rx.changed() => {
345 tracing::debug!("ws_client::retry_canceled");
346 return Ok(());
347 }
348 }
349 }
350 }
351}