sos_protocol/network_client/
http.rs

1//! HTTP client implementation.
2use crate::{
3    constants::{
4        routes::v1::{
5            SYNC_ACCOUNT, SYNC_ACCOUNT_EVENTS, SYNC_ACCOUNT_STATUS,
6        },
7        MIME_TYPE_JSON, MIME_TYPE_PROTOBUF, X_SOS_ACCOUNT_ID,
8    },
9    DiffRequest, DiffResponse, Error, NetworkError, PatchRequest,
10    PatchResponse, Result, ScanRequest, ScanResponse, SyncClient,
11    WireEncodeDecode,
12};
13use async_trait::async_trait;
14use http::StatusCode;
15use reqwest::{
16    header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT},
17    RequestBuilder,
18};
19use serde_json::Value;
20use sos_core::{AccountId, Origin};
21use sos_signer::ed25519::BoxedEd25519Signer;
22use sos_sync::{CreateSet, SyncPacket, SyncStatus, UpdateSet};
23use std::{fmt, sync::OnceLock, time::Duration};
24use tracing::instrument;
25use url::Url;
26
27#[cfg(feature = "listen")]
28use futures::Future;
29
30use super::{bearer_prefix, encode_device_signature};
31
32#[cfg(feature = "listen")]
33use crate::{
34    network_client::websocket::{
35        ListenOptions, WebSocketChangeListener, WebSocketHandle,
36    },
37    ChangeNotification,
38};
39
40#[cfg(feature = "files")]
41use {
42    crate::transfer::{
43        FileSet, FileSyncClient, FileTransfersSet, ProgressChannel,
44    },
45    sos_core::ExternalFile,
46};
47
/// Process-wide user agent sent with requests once it has been set.
static REQUEST_USER_AGENT: OnceLock<String> = OnceLock::new();

/// Set the user agent for requests.
///
/// Only the first call has any effect; the value supplied by
/// later calls is discarded.
pub fn set_user_agent(user_agent: String) {
    // `set` hands the value back when the cell is already
    // initialized, ignoring it keeps the first value in place
    let _ = REQUEST_USER_AGENT.set(user_agent);
}
54
55/// Client that can synchronize with a server over HTTP(S).
56#[derive(Clone)]
57pub struct HttpClient {
58    account_id: AccountId,
59    origin: Origin,
60    device_signer: BoxedEd25519Signer,
61    client: reqwest::Client,
62    connection_id: String,
63}
64
65impl PartialEq for HttpClient {
66    fn eq(&self, other: &Self) -> bool {
67        self.origin == other.origin
68            && self.connection_id == other.connection_id
69    }
70}
71
72impl Eq for HttpClient {}
73
74impl fmt::Debug for HttpClient {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        f.debug_struct("HttpClient")
77            .field("url", self.origin.url())
78            .field("connection_id", &self.connection_id)
79            .finish()
80    }
81}
82
83impl HttpClient {
84    /// Create a new client.
85    pub fn new(
86        account_id: AccountId,
87        origin: Origin,
88        device_signer: BoxedEd25519Signer,
89        connection_id: String,
90    ) -> Result<Self> {
91        #[cfg(not(target_arch = "wasm32"))]
92        let client = reqwest::ClientBuilder::new()
93            .read_timeout(Duration::from_millis(15000))
94            .connect_timeout(Duration::from_millis(5000))
95            .build()?;
96
97        #[cfg(target_arch = "wasm32")]
98        let client = reqwest::ClientBuilder::new().build()?;
99
100        Ok(Self {
101            account_id,
102            origin,
103            device_signer,
104            client,
105            connection_id,
106        })
107    }
108
109    /// Device signing key.
110    pub fn device_signer(&self) -> &BoxedEd25519Signer {
111        &self.device_signer
112    }
113
114    /// Spawn a thread that listens for changes
115    /// from the remote server using a websocket
116    /// that performs automatic re-connection.
117    #[cfg(feature = "listen")]
118    pub fn listen<F>(
119        &self,
120        options: ListenOptions,
121        handler: impl Fn(ChangeNotification) -> F + Send + Sync + 'static,
122    ) -> WebSocketHandle
123    where
124        F: Future<Output = ()> + Send + 'static,
125    {
126        let listener = WebSocketChangeListener::new(
127            self.account_id.clone(),
128            self.origin.clone(),
129            self.device_signer.clone(),
130            options,
131        );
132        listener.spawn(handler)
133    }
134
135    /// Total number of websocket connections on remote.
136    pub async fn num_connections(server: &Url) -> Result<usize> {
137        let client = reqwest::Client::new();
138        let url = server.join("api/v1/sync/connections")?;
139        let response = client.get(url).send().await?;
140        let response = response.error_for_status()?;
141        Ok(response.json::<usize>().await?)
142    }
143
144    /// Build a URL including the connection identifier
145    /// in the query string.
146    fn build_url(&self, route: &str) -> Result<Url> {
147        let mut url = self.origin.url().join(route)?;
148        url.query_pairs_mut()
149            .append_pair("connection_id", &self.connection_id);
150        Ok(url)
151    }
152
153    /// Check if we are able to handle a response status code
154    /// and content type.
155    async fn check_response(
156        &self,
157        response: reqwest::Response,
158    ) -> Result<reqwest::Response> {
159        use reqwest::header::{self, HeaderValue};
160        let protobuf_type = HeaderValue::from_static(MIME_TYPE_PROTOBUF);
161        let status = response.status();
162        let content_type = response.headers().get(&header::CONTENT_TYPE);
163        match (status, content_type) {
164            // OK with the correct MIME type can be handled
165            (http::StatusCode::OK, Some(content_type)) => {
166                if content_type == &protobuf_type {
167                    Ok(response)
168                } else {
169                    Err(NetworkError::ContentType(
170                        content_type.to_str()?.to_owned(),
171                        MIME_TYPE_PROTOBUF.to_string(),
172                    )
173                    .into())
174                }
175            }
176            // Otherwise exit out early
177            _ => self.error_json(response).await,
178        }
179    }
180
181    /// Convert an error response that may be JSON
182    /// into an error.
183    async fn error_json(
184        &self,
185        response: reqwest::Response,
186    ) -> Result<reqwest::Response> {
187        use reqwest::header::{self, HeaderValue};
188
189        let status = response.status();
190        let json_type = HeaderValue::from_static(MIME_TYPE_JSON);
191        let content_type = response.headers().get(&header::CONTENT_TYPE);
192        if !status.is_success() {
193            if let Some(content_type) = content_type {
194                if content_type == json_type {
195                    let value: Value = response.json().await?;
196                    Err(NetworkError::ResponseJson(status, value).into())
197                } else {
198                    Err(NetworkError::ResponseCode(status).into())
199                }
200            } else {
201                Ok(response)
202            }
203        } else {
204            Ok(response)
205        }
206    }
207
208    /// Set headers for all requests.
209    async fn request_headers(
210        &self,
211        mut request: RequestBuilder,
212        sign_bytes: &[u8],
213    ) -> Result<RequestBuilder> {
214        let device_signature = encode_device_signature(
215            self.device_signer.sign(sign_bytes).await?,
216        )
217        .await?;
218        let auth = bearer_prefix(&device_signature);
219
220        request = request
221            .header(X_SOS_ACCOUNT_ID, self.account_id.to_string())
222            .header(AUTHORIZATION, auth);
223
224        if let Some(user_agent) = REQUEST_USER_AGENT.get() {
225            request = request.header(USER_AGENT, user_agent);
226        }
227
228        Ok(request)
229    }
230}
231
232#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
233#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
234impl SyncClient for HttpClient {
235    type Error = crate::Error;
236
237    fn origin(&self) -> &Origin {
238        &self.origin
239    }
240
241    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
242    async fn account_exists(&self) -> Result<bool> {
243        let url = self.build_url(SYNC_ACCOUNT)?;
244        let sign_url = url.path().to_owned();
245
246        tracing::debug!(url = %url, "http::account_exists");
247        let request = self.client.head(url);
248        let request =
249            self.request_headers(request, sign_url.as_bytes()).await?;
250        let response = request.send().await?;
251        let status = response.status();
252        tracing::debug!(status = %status, "http::account_exists");
253        let exists = match status {
254            StatusCode::OK => true,
255            StatusCode::NOT_FOUND => false,
256            _ => {
257                return Err(NetworkError::ResponseCode(status).into());
258            }
259        };
260        Ok(exists)
261    }
262
263    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
264    async fn create_account(&self, account: CreateSet) -> Result<()> {
265        let body = account.encode().await?;
266        let url = self.build_url(SYNC_ACCOUNT)?;
267
268        tracing::debug!(url = %url, "http::create_account");
269
270        let request = self
271            .client
272            .put(url)
273            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
274        let request = self.request_headers(request, &body).await?;
275        let response = request.body(body).send().await?;
276        let status = response.status();
277        tracing::debug!(status = %status, "http::create_account");
278        self.error_json(response).await?;
279        Ok(())
280    }
281
282    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
283    async fn update_account(&self, account: UpdateSet) -> Result<()> {
284        let body = account.encode().await?;
285        let url = self.build_url(SYNC_ACCOUNT)?;
286
287        tracing::debug!(url = %url, "http::update_account");
288
289        let request = self
290            .client
291            .post(url)
292            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
293        let request = self.request_headers(request, &body).await?;
294        let response = request.body(body).send().await?;
295        let status = response.status();
296        tracing::debug!(status = %status, "http::update_account");
297        self.error_json(response).await?;
298        Ok(())
299    }
300
301    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
302    async fn fetch_account(&self) -> Result<CreateSet> {
303        let url = self.build_url(SYNC_ACCOUNT)?;
304        let sign_url = url.path().to_owned();
305
306        tracing::debug!(url = %url, "http::fetch_account");
307
308        let request = self.client.get(url);
309        let request =
310            self.request_headers(request, sign_url.as_bytes()).await?;
311        let response = request.send().await?;
312        let status = response.status();
313        tracing::debug!(status = %status, "http::fetch_account");
314        let response = self.check_response(response).await?;
315        let buffer = response.bytes().await?;
316        Ok(CreateSet::decode(buffer).await?)
317    }
318
319    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
320    async fn delete_account(&self) -> Result<()> {
321        let url = self.build_url(SYNC_ACCOUNT)?;
322
323        let sign_url = url.path().to_owned();
324
325        tracing::debug!(url = %url, "http::delete_account");
326
327        let request = self.client.delete(url);
328        let request =
329            self.request_headers(request, sign_url.as_bytes()).await?;
330        let response = request.send().await?;
331        let status = response.status();
332        tracing::debug!(status = %status, "http::delete_account");
333        self.error_json(response).await?;
334        Ok(())
335    }
336
337    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
338    async fn sync_status(&self) -> Result<SyncStatus> {
339        let url = self.build_url(SYNC_ACCOUNT_STATUS)?;
340        let sign_url = url.path().to_owned();
341
342        tracing::debug!(url = %url, "http::sync_status");
343
344        let request = self.client.get(url);
345        let request =
346            self.request_headers(request, sign_url.as_bytes()).await?;
347        let response = request.send().await?;
348        let status = response.status();
349        tracing::debug!(status = %status, "http::sync_status");
350        let response = self.check_response(response).await?;
351        let buffer = response.bytes().await?;
352        Ok(SyncStatus::decode(buffer).await?)
353    }
354
355    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
356    async fn sync(&self, packet: SyncPacket) -> Result<SyncPacket> {
357        let body = packet.encode().await?;
358        let url = self.build_url(SYNC_ACCOUNT)?;
359        tracing::debug!(url = %url, "http::sync");
360
361        let request = self
362            .client
363            .patch(url)
364            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
365        let request = self.request_headers(request, &body).await?;
366        let response = request.body(body).send().await?;
367        let status = response.status();
368        tracing::debug!(status = %status, "http::sync");
369        let response = self.check_response(response).await?;
370        let buffer = response.bytes().await?;
371        Ok(SyncPacket::decode(buffer).await?)
372    }
373
374    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
375    async fn scan(&self, request: ScanRequest) -> Result<ScanResponse> {
376        let body = request.encode().await?;
377        let url = self.build_url(SYNC_ACCOUNT_EVENTS)?;
378
379        tracing::debug!(url = %url, "http::scan");
380
381        let request = self
382            .client
383            .get(url)
384            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
385        let request = self.request_headers(request, &body).await?;
386        let response = request.body(body).send().await?;
387        let status = response.status();
388        tracing::debug!(status = %status, "http::scan");
389        let response = self.check_response(response).await?;
390        let buffer = response.bytes().await?;
391        Ok(ScanResponse::decode(buffer).await?)
392    }
393
394    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
395    async fn diff(&self, request: DiffRequest) -> Result<DiffResponse> {
396        let body = request.encode().await?;
397        let url = self.build_url(SYNC_ACCOUNT_EVENTS)?;
398
399        tracing::debug!(url = %url, "http::diff");
400
401        let request = self
402            .client
403            .post(url)
404            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
405        let request = self.request_headers(request, &body).await?;
406        let response = request.body(body).send().await?;
407        let status = response.status();
408        tracing::debug!(status = %status, "http::diff");
409        let response = self.check_response(response).await?;
410        let buffer = response.bytes().await?;
411        Ok(DiffResponse::decode(buffer).await?)
412    }
413
414    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
415    async fn patch(&self, request: PatchRequest) -> Result<PatchResponse> {
416        let body = request.encode().await?;
417        let url = self.build_url(SYNC_ACCOUNT_EVENTS)?;
418        tracing::debug!(url = %url, "http::patch");
419        let request = self
420            .client
421            .patch(url)
422            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
423        let request = self.request_headers(request, &body).await?;
424        let response = request.body(body).send().await?;
425        let status = response.status();
426        tracing::debug!(status = %status, "http::patch");
427        let response = self.check_response(response).await?;
428        let buffer = response.bytes().await?;
429        Ok(PatchResponse::decode(buffer).await?)
430    }
431}
432
433#[cfg(feature = "files")]
434#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
435#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
436impl FileSyncClient for HttpClient {
437    type Error = crate::Error;
438
439    #[cfg_attr(
440        not(target_arch = "wasm32"),
441        instrument(skip(self, path, progress, cancel))
442    )]
443    async fn upload_file(
444        &self,
445        file_info: &ExternalFile,
446        path: &std::path::Path,
447        progress: ProgressChannel,
448        mut cancel: tokio::sync::watch::Receiver<
449            crate::transfer::CancelReason,
450        >,
451    ) -> Result<http::StatusCode> {
452        use futures::StreamExt;
453        use reqwest::{
454            header::{CONTENT_LENGTH, CONTENT_TYPE},
455            Body,
456        };
457        use sos_vfs as vfs;
458        use tokio::sync::mpsc;
459        use tokio_stream::wrappers::ReceiverStream;
460        use tokio_util::io::ReaderStream;
461
462        let url_path = format!("api/v1/sync/file/{}", file_info);
463        let url = self.build_url(&url_path)?;
464
465        tracing::debug!(url = %url, "http::upload_file");
466
467        let sign_url = url.path().to_owned();
468
469        let metadata = vfs::metadata(path).await?;
470        let file_size = metadata.len();
471        let file = vfs::File::open(path).await?;
472
473        let mut bytes_sent = 0;
474        if let Err(error) = progress.send((bytes_sent, Some(file_size))).await
475        {
476            tracing::warn!(error = ?error);
477        }
478
479        let mut reader_stream = ReaderStream::new(file);
480
481        let (tx, rx) = mpsc::channel(8);
482        tokio::task::spawn(async move {
483            loop {
484                tokio::select! {
485                  biased;
486                  _ = cancel.changed() => {
487                    let reason = cancel.borrow().clone();
488                    tracing::debug!(reason = ?reason, "upload::canceled");
489                    if let Err(e) = tx.send(Err(Error::TransferCanceled(reason))).await {
490                        tracing::warn!(error = %e);
491                    }
492                  }
493                  Some(chunk) = reader_stream.next() => {
494                    if let Ok(bytes) = &chunk {
495                        bytes_sent += bytes.len() as u64;
496                        if let Err(e) = progress.send((bytes_sent, Some(file_size))).await {
497                          tracing::warn!(error = %e);
498                        }
499                    }
500                    if let Err(e) = tx.send(chunk.map_err(Error::from)).await {
501                        tracing::error!(error = %e);
502                        break;
503                    }
504                  }
505                }
506            }
507        });
508
509        let progress_stream = ReceiverStream::new(rx);
510
511        // Use a client without the read timeout
512        // as this may be a long running request
513        let client = reqwest::ClientBuilder::new()
514            .connect_timeout(Duration::from_millis(5000))
515            .build()?;
516
517        let request = client
518            .put(url)
519            .header(CONTENT_LENGTH, file_size)
520            .header(CONTENT_TYPE, "application/octet-stream");
521        let request =
522            self.request_headers(request, sign_url.as_bytes()).await?;
523
524        let response = request
525            .body(Body::wrap_stream(progress_stream))
526            .send()
527            .await?;
528        let status = response.status();
529        tracing::debug!(status = %status, "http::upload_file");
530        if !status.is_success() && status != http::StatusCode::NOT_MODIFIED {
531            self.error_json(response).await?;
532        }
533        Ok(status)
534    }
535
536    #[cfg_attr(
537        not(target_arch = "wasm32"),
538        instrument(skip(self, path, progress, cancel))
539    )]
540    async fn download_file(
541        &self,
542        file_info: &ExternalFile,
543        path: &std::path::Path,
544        progress: ProgressChannel,
545        mut cancel: tokio::sync::watch::Receiver<
546            crate::transfer::CancelReason,
547        >,
548    ) -> Result<http::StatusCode> {
549        use sha2::{Digest, Sha256};
550        use sos_vfs as vfs;
551        use tokio::io::AsyncWriteExt;
552
553        let url_path = format!("api/v1/sync/file/{}", file_info);
554        let url = self.build_url(&url_path)?;
555
556        tracing::debug!(url = %url, "http::download_file");
557
558        let sign_url = url.path().to_owned();
559        let request = self.client.get(url);
560        let request =
561            self.request_headers(request, sign_url.as_bytes()).await?;
562        let mut response = request.send().await?;
563
564        let file_size = response.content_length();
565        let mut bytes_received = 0;
566        if let Err(error) = progress.send((bytes_received, file_size)).await {
567            tracing::warn!(error = ?error);
568        }
569
570        let mut download_path = path.to_path_buf();
571        download_path.set_extension("download");
572
573        let mut hasher = Sha256::new();
574        let mut file = vfs::File::create(&download_path).await?;
575
576        loop {
577            tokio::select! {
578                biased;
579                _ = cancel.changed() => {
580                  let reason = cancel.borrow().clone();
581                  vfs::remove_file(download_path).await?;
582                  tracing::debug!(reason = ?reason, "download::canceled");
583                  return Err(Error::TransferCanceled(reason));
584                }
585                chunk = response.chunk() => {
586                  if let Some(chunk) = chunk? {
587                    file.write_all(&chunk).await?;
588                    hasher.update(&chunk);
589
590                    bytes_received += chunk.len() as u64;
591                    if let Err(error) = progress.send((bytes_received, file_size)).await {
592                        tracing::warn!(error = ?error);
593                    }
594                  } else {
595                    break;
596                  }
597                }
598            }
599        }
600
601        file.flush().await?;
602
603        let digest = hasher.finalize();
604
605        let digest_valid =
606            digest.as_slice() == file_info.file_name().as_ref();
607        if !digest_valid {
608            vfs::remove_file(download_path).await?;
609            return Err(Error::FileChecksumMismatch(
610                file_info.file_name().to_string(),
611                hex::encode(digest.as_slice()),
612            ));
613        }
614
615        let status = response.status();
616        tracing::debug!(status = %status, "http::download_file");
617
618        if status == http::StatusCode::OK
619            && vfs::try_exists(&download_path).await?
620        {
621            vfs::rename(download_path, path).await?;
622        }
623
624        self.error_json(response).await?;
625        Ok(status)
626    }
627
628    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip(self)))]
629    async fn delete_file(
630        &self,
631        file_info: &ExternalFile,
632    ) -> Result<http::StatusCode> {
633        let url_path = format!("api/v1/sync/file/{}", file_info);
634        let url = self.build_url(&url_path)?;
635        let sign_url = url.path().to_owned();
636
637        tracing::debug!(url = %url, "http::delete_file");
638
639        let request = self.client.delete(url);
640        let request =
641            self.request_headers(request, sign_url.as_bytes()).await?;
642        let response = request.send().await?;
643        let status = response.status();
644        tracing::debug!(status = %status, "http::delete_file");
645        if !status.is_success() && status != http::StatusCode::NOT_FOUND {
646            self.error_json(response).await?;
647        }
648        Ok(status)
649    }
650
651    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip(self)))]
652    async fn move_file(
653        &self,
654        from: &ExternalFile,
655        to: &ExternalFile,
656    ) -> Result<http::StatusCode> {
657        let url_path = format!("api/v1/sync/file/{}", from);
658        let mut url = self.build_url(&url_path)?;
659
660        url.query_pairs_mut()
661            .append_pair("vault_id", &to.vault_id().to_string())
662            .append_pair("secret_id", &to.secret_id().to_string())
663            .append_pair("name", &to.file_name().to_string());
664
665        tracing::debug!(from = %from, to = %to, url = %url, "http::move_file");
666
667        let sign_url = url.path().to_owned();
668        let request = self.client.post(url);
669        let request =
670            self.request_headers(request, sign_url.as_bytes()).await?;
671        let response = request.send().await?;
672        let status = response.status();
673        tracing::debug!(status = %status, "http::move_file");
674        self.error_json(response).await?;
675        Ok(status)
676    }
677
678    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
679    async fn compare_files(
680        &self,
681        local_files: FileSet,
682    ) -> Result<FileTransfersSet> {
683        let url_path = format!("api/v1/sync/files");
684        let url = self.build_url(&url_path)?;
685        let sign_url = url.path().to_owned();
686        let body = local_files.encode().await?;
687
688        tracing::debug!(url = %url, "http::compare_files");
689
690        let request = self
691            .client
692            .post(url)
693            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
694        let request =
695            self.request_headers(request, sign_url.as_bytes()).await?;
696        let response = request.body(body).send().await?;
697        let status = response.status();
698        tracing::debug!(status = %status, "http::compare_files");
699        let response = self.check_response(response).await?;
700        let buffer = response.bytes().await?;
701        Ok(FileTransfersSet::decode(buffer).await?)
702    }
703}