tauri_plugin_upload/
lib.rs

1// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
2// SPDX-License-Identifier: Apache-2.0
3// SPDX-License-Identifier: MIT
4
5//! Upload files from disk to a remote server over HTTP.
6//!
7//! Download files from a remote HTTP server to disk.
8
9#![doc(
10    html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
11    html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
12)]
13
14mod transfer_stats;
15use transfer_stats::TransferStats;
16
17use futures_util::TryStreamExt;
18use serde::{ser::Serializer, Serialize};
19use tauri::{
20    command,
21    ipc::Channel,
22    plugin::{Builder as PluginBuilder, TauriPlugin},
23    Runtime,
24};
25use tokio::{
26    fs::File,
27    io::{AsyncWriteExt, BufWriter},
28};
29use tokio_util::codec::{BytesCodec, FramedRead};
30
31use read_progress_stream::ReadProgressStream;
32
33use std::collections::HashMap;
34
35type Result<T> = std::result::Result<T, Error>;
36
37#[derive(Debug, thiserror::Error)]
38pub enum Error {
39    #[error(transparent)]
40    Io(#[from] std::io::Error),
41    #[error(transparent)]
42    Request(#[from] reqwest::Error),
43    #[error("{0}")]
44    ContentLength(String),
45    #[error("request failed with status code {0}: {1}")]
46    HttpErrorCode(u16, String),
47}
48
49impl Serialize for Error {
50    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
51    where
52        S: Serializer,
53    {
54        serializer.serialize_str(self.to_string().as_ref())
55    }
56}
57
58#[derive(Clone, Serialize)]
59#[serde(rename_all = "camelCase")]
60struct ProgressPayload {
61    progress: u64,
62    progress_total: u64,
63    total: u64,
64    transfer_speed: u64,
65}
66
67#[command]
68async fn download(
69    url: String,
70    file_path: String,
71    headers: HashMap<String, String>,
72    body: Option<String>,
73    on_progress: Channel<ProgressPayload>,
74) -> Result<()> {
75    tokio::spawn(async move {
76        let client = reqwest::Client::new();
77        let mut request = if let Some(body) = body {
78            client.post(&url).body(body)
79        } else {
80            client.get(&url)
81        };
82        // Loop trought the headers keys and values
83        // and add them to the request object.
84        for (key, value) in headers {
85            request = request.header(&key, value);
86        }
87
88        let response = request.send().await?;
89        if !response.status().is_success() {
90            return Err(Error::HttpErrorCode(
91                response.status().as_u16(),
92                response.text().await.unwrap_or_default(),
93            ));
94        }
95        let total = response.content_length().unwrap_or(0);
96
97        let mut file = BufWriter::new(File::create(&file_path).await?);
98        let mut stream = response.bytes_stream();
99
100        let mut stats = TransferStats::default();
101        while let Some(chunk) = stream.try_next().await? {
102            file.write_all(&chunk).await?;
103            stats.record_chunk_transfer(chunk.len());
104            let _ = on_progress.send(ProgressPayload {
105                progress: chunk.len() as u64,
106                progress_total: stats.total_transferred,
107                total,
108                transfer_speed: stats.transfer_speed,
109            });
110        }
111        file.flush().await?;
112        Ok(())
113    })
114    .await
115    .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
116}
117
118#[command]
119async fn upload(
120    url: String,
121    file_path: String,
122    headers: HashMap<String, String>,
123    on_progress: Channel<ProgressPayload>,
124) -> Result<String> {
125    tokio::spawn(async move {
126        // Read the file
127        let file = File::open(&file_path).await?;
128        let file_len = file.metadata().await.unwrap().len();
129
130        // Create the request and attach the file to the body
131        let client = reqwest::Client::new();
132        let mut request = client
133            .post(&url)
134            .header(reqwest::header::CONTENT_LENGTH, file_len)
135            .body(file_to_body(on_progress, file, file_len));
136
137        // Loop through the headers keys and values
138        // and add them to the request object.
139        for (key, value) in headers {
140            request = request.header(&key, value);
141        }
142
143        let response = request.send().await?;
144        if response.status().is_success() {
145            response.text().await.map_err(Into::into)
146        } else {
147            Err(Error::HttpErrorCode(
148                response.status().as_u16(),
149                response.text().await.unwrap_or_default(),
150            ))
151        }
152    })
153    .await
154    .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
155}
156
157fn file_to_body(channel: Channel<ProgressPayload>, file: File, file_len: u64) -> reqwest::Body {
158    let stream = FramedRead::new(file, BytesCodec::new()).map_ok(|r| r.freeze());
159
160    let mut stats = TransferStats::default();
161    reqwest::Body::wrap_stream(ReadProgressStream::new(
162        stream,
163        Box::new(move |progress, _total| {
164            stats.record_chunk_transfer(progress as usize);
165            let _ = channel.send(ProgressPayload {
166                progress,
167                progress_total: stats.total_transferred,
168                total: file_len,
169                transfer_speed: stats.transfer_speed,
170            });
171        }),
172    ))
173}
174
175pub fn init<R: Runtime>() -> TauriPlugin<R> {
176    PluginBuilder::new("upload")
177        .invoke_handler(tauri::generate_handler![download, upload])
178        .build()
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{self, Mock, Server, ServerGuard};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tauri::ipc::InvokeResponseBody;

    /// Fixture file sent by the upload tests. The upload mock only matches a
    /// `content-length` of 20, so this file must stay exactly 20 bytes long.
    const UPLOAD_FIXTURE: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/test/test.txt");

    /// Gives every download its own destination file name.
    static NEXT_DOWNLOAD_ID: AtomicUsize = AtomicUsize::new(0);

    /// A running mock server together with the URL and endpoint under test.
    struct MockedServer {
        // Held only to keep the server alive for the duration of the test.
        _server: ServerGuard,
        url: String,
        mocked_endpoint: Mock,
    }

    #[tokio::test]
    async fn should_error_on_download_if_status_not_success() {
        let mocked_server = spawn_server_mocked(400).await;
        let result = download_file(mocked_server.url).await;
        mocked_server.mocked_endpoint.assert();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn should_download_file_successfully() {
        let mocked_server = spawn_server_mocked(200).await;
        let result = download_file(mocked_server.url).await;
        mocked_server.mocked_endpoint.assert();
        assert!(
            result.is_ok(),
            "failed to download file: {}",
            result.unwrap_err()
        );
    }

    #[tokio::test]
    async fn should_error_on_upload_if_status_not_success() {
        let mocked_server = spawn_upload_server_mocked(500).await;
        let result = upload_file(mocked_server.url).await;
        mocked_server.mocked_endpoint.assert();
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::HttpErrorCode(status, _) => assert_eq!(status, 500),
            _ => panic!("Expected HttpErrorCode error"),
        }
    }

    #[tokio::test]
    async fn should_error_on_upload_if_file_not_found() {
        let mocked_server = spawn_upload_server_mocked(200).await;
        let file_path = "/nonexistent/file.txt".to_string();

        let result = upload(mocked_server.url, file_path, HashMap::new(), noop_channel()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::Io(_) => {}
            _ => panic!("Expected IO error for missing file"),
        }
    }

    #[tokio::test]
    async fn should_upload_file_successfully() {
        let mocked_server = spawn_upload_server_mocked(200).await;
        let result = upload_file(mocked_server.url).await;
        mocked_server.mocked_endpoint.assert();
        assert!(
            result.is_ok(),
            "failed to upload file: {}",
            result.unwrap_err()
        );
        let response_body = result.unwrap();
        assert_eq!(response_body, "upload successful");
    }

    /// Creates a progress channel that discards every message it receives.
    fn noop_channel() -> Channel<ProgressPayload> {
        Channel::new(|_msg: InvokeResponseBody| -> tauri::Result<()> { Ok(()) })
    }

    /// Runs `download` against `url`, writing to a unique temporary file.
    ///
    /// The download must not target the upload fixture: tests run in
    /// parallel, and truncating the fixture while an upload test reads it
    /// would change the `content-length` that test sends.
    async fn download_file(url: String) -> Result<()> {
        let id = NEXT_DOWNLOAD_ID.fetch_add(1, Ordering::Relaxed);
        let target = std::env::temp_dir().join(format!(
            "tauri-plugin-upload-test-{}-{id}.txt",
            std::process::id()
        ));
        let result = download(
            url,
            target.to_string_lossy().into_owned(),
            HashMap::new(),
            None,
            noop_channel(),
        )
        .await;
        // Best-effort cleanup; the file does not exist when the download
        // failed before creating it.
        let _ = std::fs::remove_file(&target);
        result
    }

    /// Runs `upload` against `url`, sending the fixture file.
    async fn upload_file(url: String) -> Result<String> {
        upload(
            url,
            UPLOAD_FIXTURE.to_string(),
            HashMap::new(),
            noop_channel(),
        )
        .await
    }

    /// Starts a mock server answering `GET /mock_test` with `return_status`.
    async fn spawn_server_mocked(return_status: usize) -> MockedServer {
        let mut _server = Server::new_async().await;
        let path = "/mock_test";
        let mock = _server
            .mock("GET", path)
            .with_status(return_status)
            .with_body("mocked response body")
            .create_async()
            .await;

        let url = _server.url() + path;
        MockedServer {
            _server,
            url,
            mocked_endpoint: mock,
        }
    }

    /// Starts a mock server answering `POST /upload_test` with
    /// `return_status`, matching only requests whose `content-length` is 20.
    async fn spawn_upload_server_mocked(return_status: usize) -> MockedServer {
        let mut _server = Server::new_async().await;
        let path = "/upload_test";
        let mock = _server
            .mock("POST", path)
            .with_status(return_status)
            .with_body("upload successful")
            .match_header("content-length", "20")
            .create_async()
            .await;

        let url = _server.url() + path;
        MockedServer {
            _server,
            url,
            mocked_endpoint: mock,
        }
    }
}