sos_protocol/network_client/
http.rs

1//! HTTP client implementation.
2use crate::{
3    constants::{
4        routes::v1::{
5            SYNC_ACCOUNT, SYNC_ACCOUNT_EVENTS, SYNC_ACCOUNT_STATUS,
6        },
7        MIME_TYPE_JSON, MIME_TYPE_PROTOBUF, X_SOS_ACCOUNT_ID,
8    },
9    DiffRequest, DiffResponse, Error, NetworkError, PatchRequest,
10    PatchResponse, Result, ScanRequest, ScanResponse, SyncClient,
11    WireEncodeDecode,
12};
13use async_trait::async_trait;
14use http::StatusCode;
15use reqwest::{
16    header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT},
17    RequestBuilder,
18};
19use serde_json::Value;
20use sos_core::{AccountId, Origin};
21use sos_signer::ed25519::BoxedEd25519Signer;
22use sos_sync::{CreateSet, SyncPacket, SyncStatus, UpdateSet};
23use std::{fmt, sync::OnceLock, time::Duration};
24use tracing::instrument;
25use url::Url;
26
27#[cfg(feature = "listen")]
28use futures::Future;
29
30use super::{bearer_prefix, encode_device_signature};
31
32#[cfg(feature = "listen")]
33use crate::{
34    network_client::websocket::{
35        ListenOptions, WebSocketChangeListener, WebSocketHandle,
36    },
37    NetworkChangeEvent,
38};
39
40#[cfg(feature = "files")]
41use {
42    crate::transfer::{
43        FileSet, FileSyncClient, FileTransfersSet, ProgressChannel,
44    },
45    sos_core::ExternalFile,
46};
47
/// User agent attached to requests built by [`HttpClient`].
///
/// Written at most once, by [`set_user_agent`]; while it is unset no
/// `User-Agent` header is added by this module.
static REQUEST_USER_AGENT: OnceLock<String> = OnceLock::new();

/// Set the user agent for requests.
///
/// Only the first call has any effect: the value lives in a
/// process-wide [`OnceLock`], so later calls are silently ignored.
pub fn set_user_agent(user_agent: String) {
    // `set` returns `Err` when a value is already stored; ignoring it
    // gives the "first call wins" behaviour documented above.
    let _ = REQUEST_USER_AGENT.set(user_agent);
}
54
55/// Client that can synchronize with a server over HTTP(S).
56#[derive(Clone)]
57pub struct HttpClient {
58    account_id: AccountId,
59    origin: Origin,
60    device_signer: BoxedEd25519Signer,
61    client: reqwest::Client,
62    connection_id: String,
63}
64
65impl PartialEq for HttpClient {
66    fn eq(&self, other: &Self) -> bool {
67        self.origin == other.origin
68            && self.connection_id == other.connection_id
69    }
70}
71
72impl Eq for HttpClient {}
73
74impl fmt::Debug for HttpClient {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        f.debug_struct("HttpClient")
77            .field("url", self.origin.url())
78            .field("connection_id", &self.connection_id)
79            .finish()
80    }
81}
82
83impl HttpClient {
84    /// Create a new client.
85    pub fn new(
86        account_id: AccountId,
87        origin: Origin,
88        device_signer: BoxedEd25519Signer,
89        connection_id: String,
90    ) -> Result<Self> {
91        #[cfg(not(target_arch = "wasm32"))]
92        let client = reqwest::ClientBuilder::new()
93            .read_timeout(Duration::from_millis(15000))
94            .connect_timeout(Duration::from_millis(5000))
95            .build()?;
96
97        #[cfg(target_arch = "wasm32")]
98        let client = reqwest::ClientBuilder::new().build()?;
99
100        Ok(Self {
101            account_id,
102            origin,
103            device_signer,
104            client,
105            connection_id,
106        })
107    }
108
109    /// Device signing key.
110    pub fn device_signer(&self) -> &BoxedEd25519Signer {
111        &self.device_signer
112    }
113
114    /// Spawn a thread that listens for changes
115    /// from the remote server using a websocket
116    /// that performs automatic re-connection.
117    #[cfg(feature = "listen")]
118    pub fn listen<F>(
119        &self,
120        options: ListenOptions,
121        handler: impl Fn(NetworkChangeEvent) -> F + Send + Sync + 'static,
122    ) -> WebSocketHandle
123    where
124        F: Future<Output = ()> + Send + 'static,
125    {
126        let listener = WebSocketChangeListener::new(
127            self.account_id.clone(),
128            self.origin.clone(),
129            self.device_signer.clone(),
130            options,
131        );
132        listener.spawn(handler)
133    }
134
135    /// Total number of websocket connections on remote.
136    pub async fn num_connections(server: &Url) -> Result<usize> {
137        let client = reqwest::Client::new();
138        let url = server.join("api/v1/sync/connections")?;
139        let response = client.get(url).send().await?;
140        let response = response.error_for_status()?;
141        Ok(response.json::<usize>().await?)
142    }
143
144    /// Build a URL including the connection identifier
145    /// in the query string.
146    fn build_url(&self, route: &str) -> Result<Url> {
147        let mut url = self.origin.url().join(route)?;
148        url.query_pairs_mut()
149            .append_pair("connection_id", &self.connection_id);
150        Ok(url)
151    }
152
153    /// Check if we are able to handle a response status code
154    /// and content type.
155    async fn check_response(
156        &self,
157        response: reqwest::Response,
158    ) -> Result<reqwest::Response> {
159        use reqwest::header::{self, HeaderValue};
160        let protobuf_type = HeaderValue::from_static(MIME_TYPE_PROTOBUF);
161        let status = response.status();
162        let content_type = response.headers().get(&header::CONTENT_TYPE);
163        match (status, content_type) {
164            // OK with the correct MIME type can be handled
165            (http::StatusCode::OK, Some(content_type)) => {
166                if content_type == &protobuf_type {
167                    Ok(response)
168                } else {
169                    Err(NetworkError::ContentType(
170                        content_type.to_str()?.to_owned(),
171                        MIME_TYPE_PROTOBUF.to_string(),
172                    )
173                    .into())
174                }
175            }
176            // Otherwise exit out early
177            _ => self.error_json(response).await,
178        }
179    }
180
181    /// Convert an error response that may be JSON
182    /// into an error.
183    async fn error_json(
184        &self,
185        response: reqwest::Response,
186    ) -> Result<reqwest::Response> {
187        use reqwest::header::{self, HeaderValue};
188
189        let status = response.status();
190        let json_type = HeaderValue::from_static(MIME_TYPE_JSON);
191        let content_type = response.headers().get(&header::CONTENT_TYPE);
192        if !status.is_success() {
193            if let Some(content_type) = content_type {
194                if content_type == json_type {
195                    let value: Value = response.json().await?;
196                    Err(NetworkError::ResponseJson(status, value).into())
197                } else {
198                    Err(NetworkError::ResponseCode(status).into())
199                }
200            } else {
201                Ok(response)
202            }
203        } else {
204            Ok(response)
205        }
206    }
207
208    /// Set headers for all requests.
209    async fn request_headers(
210        &self,
211        mut request: RequestBuilder,
212        sign_bytes: &[u8],
213    ) -> Result<RequestBuilder> {
214        let device_signature = encode_device_signature(
215            self.device_signer.sign(sign_bytes).await?,
216        )
217        .await?;
218        let auth = bearer_prefix(&device_signature);
219
220        request = request
221            .header(X_SOS_ACCOUNT_ID, self.account_id.to_string())
222            .header(AUTHORIZATION, auth);
223
224        if let Some(user_agent) = REQUEST_USER_AGENT.get() {
225            request = request.header(USER_AGENT, user_agent);
226        }
227
228        Ok(request)
229    }
230}
231
232#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
233#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
234impl SyncClient for HttpClient {
235    type Error = crate::Error;
236
237    fn origin(&self) -> &Origin {
238        &self.origin
239    }
240
241    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
242    async fn account_exists(&self) -> Result<bool> {
243        let url = self.build_url(SYNC_ACCOUNT)?;
244        let sign_url = url.path().to_owned();
245
246        tracing::debug!(url = %url, "http::account_exists");
247        let request = self.client.head(url);
248        let request =
249            self.request_headers(request, sign_url.as_bytes()).await?;
250        let response = request.send().await?;
251        let status = response.status();
252        tracing::debug!(status = %status, "http::account_exists");
253        let exists = match status {
254            StatusCode::OK => true,
255            StatusCode::NOT_FOUND => false,
256            _ => {
257                return Err(NetworkError::ResponseCode(status).into());
258            }
259        };
260        Ok(exists)
261    }
262
263    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
264    async fn create_account(&self, account: CreateSet) -> Result<()> {
265        let body = account.encode().await?;
266        let url = self.build_url(SYNC_ACCOUNT)?;
267
268        tracing::debug!(url = %url, "http::create_account");
269
270        let request = self
271            .client
272            .put(url)
273            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
274        let request = self.request_headers(request, &body).await?;
275        let response = request.body(body).send().await?;
276        let status = response.status();
277        tracing::debug!(status = %status, "http::create_account");
278        self.error_json(response).await?;
279        Ok(())
280    }
281
282    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
283    async fn update_account(&self, account: UpdateSet) -> Result<()> {
284        let body = account.encode().await?;
285        let url = self.build_url(SYNC_ACCOUNT)?;
286
287        tracing::debug!(url = %url, "http::update_account");
288
289        let request = self
290            .client
291            .post(url)
292            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
293        let request = self.request_headers(request, &body).await?;
294        let response = request.body(body).send().await?;
295        let status = response.status();
296        tracing::debug!(status = %status, "http::update_account");
297        self.error_json(response).await?;
298        Ok(())
299    }
300
301    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
302    async fn fetch_account(&self) -> Result<CreateSet> {
303        let url = self.build_url(SYNC_ACCOUNT)?;
304        let sign_url = url.path().to_owned();
305
306        tracing::debug!(url = %url, "http::fetch_account");
307
308        let request = self.client.get(url);
309        let request =
310            self.request_headers(request, sign_url.as_bytes()).await?;
311        let response = request.send().await?;
312        let status = response.status();
313        tracing::debug!(status = %status, "http::fetch_account");
314        let response = self.check_response(response).await?;
315        let buffer = response.bytes().await?;
316        Ok(CreateSet::decode(buffer).await?)
317    }
318
319    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
320    async fn delete_account(&self) -> Result<()> {
321        let url = self.build_url(SYNC_ACCOUNT)?;
322
323        let sign_url = url.path().to_owned();
324
325        tracing::debug!(url = %url, "http::delete_account");
326
327        let request = self.client.delete(url);
328        let request =
329            self.request_headers(request, sign_url.as_bytes()).await?;
330        let response = request.send().await?;
331        let status = response.status();
332        tracing::debug!(status = %status, "http::delete_account");
333        self.error_json(response).await?;
334        Ok(())
335    }
336
337    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
338    async fn sync_status(&self) -> Result<SyncStatus> {
339        let url = self.build_url(SYNC_ACCOUNT_STATUS)?;
340        let sign_url = url.path().to_owned();
341
342        tracing::debug!(url = %url, "http::sync_status");
343
344        let request = self.client.get(url);
345        let request =
346            self.request_headers(request, sign_url.as_bytes()).await?;
347        let response = request.send().await?;
348        let status = response.status();
349        tracing::debug!(status = %status, "http::sync_status");
350        let response = self.check_response(response).await?;
351        let buffer = response.bytes().await?;
352        Ok(SyncStatus::decode(buffer).await?)
353    }
354
355    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
356    async fn sync(&self, packet: SyncPacket) -> Result<SyncPacket> {
357        let body = packet.encode().await?;
358        let url = self.build_url(SYNC_ACCOUNT)?;
359        tracing::debug!(url = %url, "http::sync");
360
361        let request = self
362            .client
363            .patch(url)
364            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
365        let request = self.request_headers(request, &body).await?;
366        let response = request.body(body).send().await?;
367        let status = response.status();
368        tracing::debug!(status = %status, "http::sync");
369        let response = self.check_response(response).await?;
370        let buffer = response.bytes().await?;
371        Ok(SyncPacket::decode(buffer).await?)
372    }
373
374    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
375    async fn scan(&self, request: ScanRequest) -> Result<ScanResponse> {
376        let body = request.encode().await?;
377        let url = self.build_url(SYNC_ACCOUNT_EVENTS)?;
378
379        tracing::debug!(url = %url, "http::scan");
380
381        let request = self
382            .client
383            .get(url)
384            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
385        let request = self.request_headers(request, &body).await?;
386        let response = request.body(body).send().await?;
387        let status = response.status();
388        tracing::debug!(status = %status, "http::scan");
389        let response = self.check_response(response).await?;
390        let buffer = response.bytes().await?;
391        Ok(ScanResponse::decode(buffer).await?)
392    }
393
394    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
395    async fn diff(&self, request: DiffRequest) -> Result<DiffResponse> {
396        let body = request.encode().await?;
397        let url = self.build_url(SYNC_ACCOUNT_EVENTS)?;
398
399        tracing::debug!(url = %url, "http::diff");
400
401        let request = self
402            .client
403            .post(url)
404            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
405        let request = self.request_headers(request, &body).await?;
406        let response = request.body(body).send().await?;
407        let status = response.status();
408        tracing::debug!(status = %status, "http::diff");
409        let response = self.check_response(response).await?;
410        let buffer = response.bytes().await?;
411        Ok(DiffResponse::decode(buffer).await?)
412    }
413
414    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
415    async fn patch(&self, request: PatchRequest) -> Result<PatchResponse> {
416        let body = request.encode().await?;
417        let url = self.build_url(SYNC_ACCOUNT_EVENTS)?;
418        tracing::debug!(url = %url, "http::patch");
419        let request = self
420            .client
421            .patch(url)
422            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
423        let request = self.request_headers(request, &body).await?;
424        let response = request.body(body).send().await?;
425        let status = response.status();
426        tracing::debug!(status = %status, "http::patch");
427        let response = self.check_response(response).await?;
428        let buffer = response.bytes().await?;
429        Ok(PatchResponse::decode(buffer).await?)
430    }
431}
432
#[cfg(feature = "files")]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl FileSyncClient for HttpClient {
    type Error = crate::Error;

    /// Upload an external file with a signed, streaming `PUT`.
    ///
    /// The file at `path` is read on a spawned task and forwarded to the
    /// request body through a bounded channel. `progress` receives the
    /// initial `(0, Some(size))` and an update after every chunk read.
    /// When `cancel` changes to a reason other than the default, the body
    /// stream is terminated with [`Error::TransferCanceled`].
    ///
    /// Only the URL path is signed. Success statuses and
    /// `304 Not Modified` are returned as-is; any other status is passed
    /// to `error_json`.
    #[cfg_attr(
        not(target_arch = "wasm32"),
        instrument(skip(self, path, progress, cancel))
    )]
    async fn upload_file(
        &self,
        file_info: &ExternalFile,
        path: &std::path::Path,
        progress: ProgressChannel,
        mut cancel: tokio::sync::watch::Receiver<
            crate::transfer::CancelReason,
        >,
    ) -> Result<http::StatusCode> {
        use futures::StreamExt;
        use reqwest::{
            header::{CONTENT_LENGTH, CONTENT_TYPE},
            Body,
        };
        use sos_vfs as vfs;
        use tokio::sync::mpsc;
        use tokio_stream::wrappers::ReceiverStream;
        use tokio_util::io::ReaderStream;

        let url_path = format!("api/v1/sync/file/{}", file_info);
        let url = self.build_url(&url_path)?;

        tracing::debug!(url = %url, "http::upload_file");

        let sign_url = url.path().to_owned();

        let metadata = vfs::metadata(path).await?;
        let file_size = metadata.len();
        let file = vfs::File::open(path).await?;

        let mut bytes_sent = 0;
        if let Err(error) = progress.send((bytes_sent, Some(file_size))).await
        {
            tracing::warn!(
                error = ?error,
                "http::progress_send_initial_size",
            );
        }

        let (tx, rx) = mpsc::channel(128);
        tokio::task::spawn(async move {
            let mut reader_stream = ReaderStream::new(file);
            let upload_channel = tx.clone();
            loop {
                tokio::select! {
                    biased;
                    // Match only `Ok`: once the sender is dropped
                    // `changed()` resolves to `Err` immediately, and a
                    // catch-all pattern would win this biased select on
                    // every iteration and spin forever. A failed match
                    // just disables the branch for this iteration.
                    Ok(()) = cancel.changed() => {
                        let reason = cancel.borrow_and_update().clone();
                        if reason != crate::transfer::CancelReason::default() {
                            tracing::debug!(
                                reason = ?reason,
                                "upload::canceled",
                            );
                            if let Err(error) = upload_channel
                                .send(Err(Error::TransferCanceled(reason)))
                                .await
                            {
                                tracing::warn!(
                                    error = %error,
                                    "http::send_transfer_canceled",
                                );
                            }

                            break;
                        }
                    }
                    chunk = reader_stream.next() => {
                        // End of file: leave the loop so the senders are
                        // dropped and the request body stream terminates
                        // instead of this task lingering.
                        let Some(chunk) = chunk else {
                            break;
                        };
                        if let Ok(bytes) = &chunk {
                            bytes_sent += bytes.len() as u64;
                            if let Err(error) = progress
                                .send((bytes_sent, Some(file_size)))
                                .await
                            {
                                tracing::warn!(
                                    error = %error,
                                    "http::send_transfer_progress_update",
                                );
                            }
                        }
                        if let Err(error) = upload_channel
                            .send(chunk.map_err(Error::from))
                            .await
                        {
                            tracing::error!(
                                error = %error,
                                "http::send_transfer_chunk",
                            );
                            break;
                        }
                    }
                }
            }
        });

        let upload_stream = ReceiverStream::new(rx);

        // Use a client without the read timeout
        // as this may be a long running request
        let client = reqwest::ClientBuilder::new()
            .connect_timeout(Duration::from_millis(5000))
            .build()?;

        let request = client
            .put(url)
            .header(CONTENT_LENGTH, file_size)
            .header(CONTENT_TYPE, "application/octet-stream");
        let request =
            self.request_headers(request, sign_url.as_bytes()).await?;

        let response = request
            .body(Body::wrap_stream(upload_stream))
            .send()
            .await?;
        let status = response.status();
        tracing::debug!(status = %status, "http::upload_file");
        if !status.is_success() && status != http::StatusCode::NOT_MODIFIED {
            self.error_json(response).await?;
        }
        Ok(status)
    }

    /// Download an external file with a signed `GET` (the URL path is
    /// signed).
    ///
    /// A non-success status is reported as an error before anything is
    /// written to disk. Otherwise the body is streamed into `path` with
    /// its extension set to `download`, hashed with SHA-256, and compared
    /// to the file name of `file_info`. A mismatch deletes the partial
    /// file and returns [`Error::FileChecksumMismatch`]. On `200 OK` the
    /// verified file is renamed to `path`.
    ///
    /// Any completion of `cancel.changed()` removes the partial file and
    /// returns [`Error::TransferCanceled`]. Per the tokio docs this
    /// includes the sender being dropped.
    #[cfg_attr(
        not(target_arch = "wasm32"),
        instrument(skip(self, path, progress, cancel))
    )]
    async fn download_file(
        &self,
        file_info: &ExternalFile,
        path: &std::path::Path,
        progress: ProgressChannel,
        mut cancel: tokio::sync::watch::Receiver<
            crate::transfer::CancelReason,
        >,
    ) -> Result<http::StatusCode> {
        use sha2::{Digest, Sha256};
        use sos_vfs as vfs;
        use tokio::io::AsyncWriteExt;

        let url_path = format!("api/v1/sync/file/{}", file_info);
        let url = self.build_url(&url_path)?;

        tracing::debug!(url = %url, "http::download_file");

        let sign_url = url.path().to_owned();
        let request = self.client.get(url);
        let request =
            self.request_headers(request, sign_url.as_bytes()).await?;
        let mut response = request.send().await?;

        let status = response.status();
        tracing::debug!(status = %status, "http::download_file");

        // Reject failures before creating the download file; otherwise
        // the error body would be hashed and surface as a checksum
        // mismatch instead of the real network error.
        if !status.is_success() {
            self.error_json(response).await?;
            return Err(NetworkError::ResponseCode(status).into());
        }

        let file_size = response.content_length();
        let mut bytes_received = 0;
        if let Err(error) = progress.send((bytes_received, file_size)).await {
            tracing::warn!(error = ?error);
        }

        let mut download_path = path.to_path_buf();
        download_path.set_extension("download");

        let mut hasher = Sha256::new();
        let mut file = vfs::File::create(&download_path).await?;

        loop {
            tokio::select! {
                biased;
                _ = cancel.changed() => {
                    let reason = cancel.borrow().clone();
                    // Close our handle before removing the partial file.
                    drop(file);
                    vfs::remove_file(download_path).await?;
                    tracing::debug!(reason = ?reason, "download::canceled");
                    return Err(Error::TransferCanceled(reason));
                }
                chunk = response.chunk() => {
                    if let Some(chunk) = chunk? {
                        file.write_all(&chunk).await?;
                        hasher.update(&chunk);

                        bytes_received += chunk.len() as u64;
                        if let Err(error) = progress.send((bytes_received, file_size)).await {
                            tracing::warn!(error = ?error);
                        }
                    } else {
                        break;
                    }
                }
            }
        }

        file.flush().await?;
        // All writes are flushed; close the handle before the file is
        // removed or renamed below.
        drop(file);

        let digest = hasher.finalize();

        let digest_valid =
            digest.as_slice() == file_info.file_name().as_ref();
        if !digest_valid {
            vfs::remove_file(download_path).await?;
            return Err(Error::FileChecksumMismatch(
                file_info.file_name().to_string(),
                hex::encode(digest.as_slice()),
            ));
        }

        if status == http::StatusCode::OK
            && vfs::try_exists(&download_path).await?
        {
            vfs::rename(download_path, path).await?;
        }

        Ok(status)
    }

    /// Delete a file on the server with a signed `DELETE` (the URL path
    /// is signed).
    ///
    /// Success statuses and `404 Not Found` are returned as-is; any other
    /// status is passed to `error_json`.
    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip(self)))]
    async fn delete_file(
        &self,
        file_info: &ExternalFile,
    ) -> Result<http::StatusCode> {
        let url_path = format!("api/v1/sync/file/{}", file_info);
        let url = self.build_url(&url_path)?;
        let sign_url = url.path().to_owned();

        tracing::debug!(url = %url, "http::delete_file");

        let request = self.client.delete(url);
        let request =
            self.request_headers(request, sign_url.as_bytes()).await?;
        let response = request.send().await?;
        let status = response.status();
        tracing::debug!(status = %status, "http::delete_file");
        if !status.is_success() && status != http::StatusCode::NOT_FOUND {
            self.error_json(response).await?;
        }
        Ok(status)
    }

    /// Move a file on the server with a signed `POST` to the source
    /// file's route.
    ///
    /// The destination is passed as the `vault_id`, `secret_id` and
    /// `name` query parameters; only the URL path is signed.
    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip(self)))]
    async fn move_file(
        &self,
        from: &ExternalFile,
        to: &ExternalFile,
    ) -> Result<http::StatusCode> {
        let url_path = format!("api/v1/sync/file/{}", from);
        let mut url = self.build_url(&url_path)?;

        url.query_pairs_mut()
            .append_pair("vault_id", &to.vault_id().to_string())
            .append_pair("secret_id", &to.secret_id().to_string())
            .append_pair("name", &to.file_name().to_string());

        tracing::debug!(from = %from, to = %to, url = %url, "http::move_file");

        let sign_url = url.path().to_owned();
        let request = self.client.post(url);
        let request =
            self.request_headers(request, sign_url.as_bytes()).await?;
        let response = request.send().await?;
        let status = response.status();
        tracing::debug!(status = %status, "http::move_file");
        self.error_json(response).await?;
        Ok(status)
    }

    /// Send the local [`FileSet`] with a signed `POST` and decode the
    /// [`FileTransfersSet`] returned by the server.
    ///
    /// Unlike the protobuf requests in the sync client, the signature
    /// here covers the URL path rather than the request body.
    #[cfg_attr(not(target_arch = "wasm32"), instrument(skip_all))]
    async fn compare_files(
        &self,
        local_files: FileSet,
    ) -> Result<FileTransfersSet> {
        let url = self.build_url("api/v1/sync/files")?;
        let sign_url = url.path().to_owned();
        let body = local_files.encode().await?;

        tracing::debug!(url = %url, "http::compare_files");

        let request = self
            .client
            .post(url)
            .header(CONTENT_TYPE, MIME_TYPE_PROTOBUF);
        let request =
            self.request_headers(request, sign_url.as_bytes()).await?;
        let response = request.body(body).send().await?;
        let status = response.status();
        tracing::debug!(status = %status, "http::compare_files");
        let response = self.check_response(response).await?;
        let buffer = response.bytes().await?;
        Ok(FileTransfersSet::decode(buffer).await?)
    }
}