sos_protocol/network_client/
mod.rs

1//! HTTP transport trait and implementations.
2use super::{Error, Result};
3use crate::transfer::CancelReason;
4use sos_core::encode;
5use sos_signer::ed25519::{
6    BinaryEd25519Signature, Signature as Ed25519Signature,
7};
8use std::{
9    future::Future,
10    sync::{
11        atomic::{AtomicU32, Ordering},
12        Arc,
13    },
14    time::Duration,
15};
16use tokio::sync::watch;
17
18mod http;
19#[cfg(feature = "listen")]
20mod websocket;
21
22pub use self::http::{set_user_agent, HttpClient};
23
24#[cfg(feature = "listen")]
25pub use websocket::{changes, connect, ListenOptions, WebSocketHandle};
26
/// Network retry state and logic for exponential backoff.
///
/// The retry counter lives behind an [`Arc`], so clones created with
/// `Clone` share the same counter; `clone_reset` produces a copy with an
/// independent counter.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Clone, Debug)]
pub struct NetworkRetry {
    /// Shared retry counter.
    retries: Arc<AtomicU32>,
    /// Reconnect interval, the base backoff interval in milliseconds.
    pub reconnect_interval: u16,
    /// Maximum number of retries.
    pub maximum_retries: u32,
}
37
38#[cfg(not(target_arch = "wasm32"))]
39impl Default for NetworkRetry {
40    fn default() -> Self {
41        Self::new(4, 1000)
42    }
43}
44
45#[cfg(not(target_arch = "wasm32"))]
46impl NetworkRetry {
47    /// Create a new network retry.
48    ///
49    /// The reconnect interval is a *base interval* in milliseconds
50    /// for the exponential backoff so use a small value such as
51    /// `1000` or `2000`.
52    pub fn new(maximum_retries: u32, reconnect_interval: u16) -> Self {
53        Self {
54            retries: Arc::new(AtomicU32::from(1)),
55            reconnect_interval,
56            maximum_retries,
57        }
58    }
59
60    /// Exponential backoff millisecond delay for a retry counter.
61    pub fn delay(&self, retries: u32) -> Result<u64> {
62        let factor = 2u64.checked_pow(retries).ok_or(Error::RetryOverflow)?;
63        Ok(self.reconnect_interval as u64 * factor)
64    }
65
66    /// Current number of retries.
67    pub fn retries(&self) -> u32 {
68        self.retries.load(Ordering::SeqCst)
69    }
70
71    /// Maximum number of retries.
72    pub fn maximum(&self) -> u32 {
73        self.maximum_retries
74    }
75
76    /// Reset retries counter.
77    pub fn reset(&self) {
78        self.retries.store(1, Ordering::SeqCst)
79    }
80
81    /// Clone of this network retry with the retry counter reset.
82    pub fn clone_reset(&self) -> Self {
83        Self {
84            retries: Arc::new(AtomicU32::from(1)),
85            reconnect_interval: self.reconnect_interval,
86            maximum_retries: self.maximum_retries,
87        }
88    }
89
90    /// Increment for next retry attempt.
91    pub fn increment(&self) -> u32 {
92        self.retries.fetch_add(1, Ordering::SeqCst)
93    }
94
95    /// Determine if retry attempts are exhausted.
96    pub fn is_exhausted(&self, retries: u32) -> bool {
97        retries > self.maximum_retries
98    }
99
100    /// Wait and then retry.
101    pub async fn wait_and_retry<D, T, F>(
102        &self,
103        id: D,
104        retries: u32,
105        callback: F,
106        mut cancel: watch::Receiver<CancelReason>,
107    ) -> Result<T>
108    where
109        D: std::fmt::Display,
110        F: Future<Output = T>,
111    {
112        let delay = self.delay(retries)?;
113        tracing::debug!(
114            id = %id,
115            delay = %delay,
116            retries = %retries,
117            maximum_retries = %self.maximum_retries,
118            "retry",
119        );
120
121        tokio::select! {
122            _ = cancel.changed() => {
123                let reason = cancel.borrow();
124                tracing::debug!(id = %id, "retry::canceled");
125                Err(Error::RetryCanceled(reason.clone()))
126            }
127            _ = tokio::time::sleep(Duration::from_millis(delay)) => {
128                Ok(callback.await)
129            }
130        }
131    }
132}
133
134pub(crate) async fn encode_device_signature(
135    signature: Ed25519Signature,
136) -> Result<String> {
137    let signature: BinaryEd25519Signature = signature.into();
138    Ok(bs58::encode(encode(&signature).await?).into_string())
139}
140
/// Build an authorization value by prefixing `device_signature` with
/// the `Bearer ` scheme.
pub(crate) fn bearer_prefix(device_signature: &str) -> String {
    ["Bearer ", device_signature].concat()
}
144
#[cfg(any(feature = "listen", feature = "pairing"))]
mod websocket_request {
    use crate::constants::X_SOS_ACCOUNT_ID;

    use super::Result;
    use sos_core::AccountId;
    use tokio_tungstenite::tungstenite::{
        self, client::IntoClientRequest, handshake::client::generate_key,
    };
    use url::{ParseError, Url};

    /// Build a websocket connection request.
    pub struct WebSocketRequest {
        /// Account identifier.
        pub account_id: AccountId,
        /// Remote URI (scheme rewritten to `ws` or `wss`).
        pub uri: Url,
        /// Remote host name as returned by `Url::host_str` (no port).
        pub host: String,
        /// Bearer authentication.
        pub bearer: Option<String>,
        /// URL origin.
        pub origin: url::Origin,
    }

    impl WebSocketRequest {
        /// Create a new websocket request.
        ///
        /// Joins `path` onto `url` and rewrites the scheme of the result
        /// from `http` to `ws` or from `https` to `wss`. The origin and
        /// host are taken from `url` itself.
        ///
        /// # Errors
        ///
        /// Returns an error if `url` has no host, or if joining `path`
        /// onto `url` fails.
        ///
        /// # Panics
        ///
        /// Panics if the joined URL's scheme is neither `http` nor `https`.
        pub fn new(
            account_id: AccountId,
            url: &Url,
            path: &str,
        ) -> Result<Self> {
            let origin = url.origin();
            // A host-less URL cannot be connected to; report it through
            // the Result instead of panicking.
            let host = url
                .host_str()
                .ok_or(ParseError::EmptyHost)?
                .to_string();

            let mut uri = url.join(path)?;
            let scheme = match uri.scheme() {
                "http" => "ws",
                "https" => "wss",
                _ => panic!("bad url scheme for websocket, requires http(s)"),
            };

            uri.set_scheme(scheme)
                .expect("failed to set websocket scheme");

            Ok(Self {
                account_id,
                host,
                uri,
                origin,
                bearer: None,
            })
        }

        /// Set bearer authorization.
        pub fn set_bearer(&mut self, bearer: String) {
            self.bearer = Some(bearer);
        }
    }

    impl IntoClientRequest for WebSocketRequest {
        /// Convert into an HTTP upgrade request carrying the websocket
        /// handshake headers, the optional bearer authorization and the
        /// account identifier header.
        fn into_client_request(
            self,
        ) -> std::result::Result<http::Request<()>, tungstenite::Error>
        {
            let origin = self.origin.unicode_serialization();
            // The Host header must include the port when it is not the
            // scheme's default (RFC 7230 §5.4); `Url::port` yields `None`
            // for default ports, so only explicit ports are appended.
            let host = match self.uri.port() {
                Some(port) => format!("{}:{}", self.host, port),
                None => self.host,
            };
            let mut request =
                http::Request::builder().uri(self.uri.to_string());
            if let Some(bearer) = self.bearer {
                request = request.header("authorization", bearer);
            }
            request = request
                .header("sec-websocket-key", generate_key())
                .header("sec-websocket-version", "13")
                .header("host", host)
                .header("origin", origin)
                .header("connection", "keep-alive, Upgrade")
                .header(X_SOS_ACCOUNT_ID, self.account_id.to_string())
                .header("upgrade", "websocket");
            Ok(request.body(())?)
        }
    }
}
230
231#[cfg(any(feature = "listen", feature = "pairing"))]
232pub use websocket_request::WebSocketRequest;