tauri_plugin_upload/
lib.rs

1// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
2// SPDX-License-Identifier: Apache-2.0
3// SPDX-License-Identifier: MIT
4
5//! Upload files from disk to a remote server over HTTP.
6//!
7//! Download files from a remote HTTP server to disk.
8
9#![doc(
10    html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
11    html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
12)]
13
14mod transfer_stats;
15use transfer_stats::TransferStats;
16
17use futures_util::TryStreamExt;
18use serde::{ser::Serializer, Deserialize, Serialize};
19use tauri::{
20    command,
21    ipc::Channel,
22    plugin::{Builder as PluginBuilder, TauriPlugin},
23    Runtime,
24};
25use tokio::{
26    fs::File,
27    io::{AsyncWriteExt, BufWriter},
28};
29use tokio_util::codec::{BytesCodec, FramedRead};
30
31use read_progress_stream::ReadProgressStream;
32
33use std::collections::HashMap;
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(rename_all = "UPPERCASE")]
37pub enum HttpMethod {
38    Post,
39    Put,
40    Patch,
41}
42
43type Result<T> = std::result::Result<T, Error>;
44
45#[derive(Debug, thiserror::Error)]
46pub enum Error {
47    #[error(transparent)]
48    Io(#[from] std::io::Error),
49    #[error(transparent)]
50    Request(#[from] reqwest::Error),
51    #[error("{0}")]
52    ContentLength(String),
53    #[error("request failed with status code {0}: {1}")]
54    HttpErrorCode(u16, String),
55}
56
57impl Serialize for Error {
58    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
59    where
60        S: Serializer,
61    {
62        serializer.serialize_str(self.to_string().as_ref())
63    }
64}
65
66#[derive(Clone, Serialize)]
67#[serde(rename_all = "camelCase")]
68struct ProgressPayload {
69    progress: u64,
70    progress_total: u64,
71    total: u64,
72    transfer_speed: u64,
73}
74
75#[command]
76async fn download(
77    url: String,
78    file_path: String,
79    headers: HashMap<String, String>,
80    body: Option<String>,
81    on_progress: Channel<ProgressPayload>,
82) -> Result<()> {
83    tokio::spawn(async move {
84        let client = reqwest::Client::new();
85        let mut request = if let Some(body) = body {
86            client.post(&url).body(body)
87        } else {
88            client.get(&url)
89        };
90        // Loop trought the headers keys and values
91        // and add them to the request object.
92        for (key, value) in headers {
93            request = request.header(&key, value);
94        }
95
96        let response = request.send().await?;
97        if !response.status().is_success() {
98            return Err(Error::HttpErrorCode(
99                response.status().as_u16(),
100                response.text().await.unwrap_or_default(),
101            ));
102        }
103        let total = response.content_length().unwrap_or(0);
104
105        let mut file = BufWriter::new(File::create(&file_path).await?);
106        let mut stream = response.bytes_stream();
107
108        let mut stats = TransferStats::default();
109        while let Some(chunk) = stream.try_next().await? {
110            file.write_all(&chunk).await?;
111            stats.record_chunk_transfer(chunk.len());
112            let _ = on_progress.send(ProgressPayload {
113                progress: chunk.len() as u64,
114                progress_total: stats.total_transferred,
115                total,
116                transfer_speed: stats.transfer_speed,
117            });
118        }
119        file.flush().await?;
120        Ok(())
121    })
122    .await
123    .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
124}
125
126#[command]
127async fn upload(
128    url: String,
129    file_path: String,
130    headers: HashMap<String, String>,
131    method: Option<HttpMethod>,
132    on_progress: Channel<ProgressPayload>,
133) -> Result<String> {
134    tokio::spawn(async move {
135        // Read the file
136        let file = File::open(&file_path).await?;
137        let file_len = file.metadata().await.unwrap().len();
138
139        // Get HTTP method (defaults to POST)
140        let http_method = method.unwrap_or(HttpMethod::Post);
141
142        // Create the request and attach the file to the body
143        let client = reqwest::Client::new();
144        let mut request = match http_method {
145            HttpMethod::Put => client.put(&url),
146            HttpMethod::Patch => client.patch(&url),
147            HttpMethod::Post => client.post(&url),
148        }
149        .header(reqwest::header::CONTENT_LENGTH, file_len)
150        .body(file_to_body(on_progress, file, file_len));
151
152        // Loop through the headers keys and values
153        // and add them to the request object.
154        for (key, value) in headers {
155            request = request.header(&key, value);
156        }
157
158        let response = request.send().await?;
159        if response.status().is_success() {
160            response.text().await.map_err(Into::into)
161        } else {
162            Err(Error::HttpErrorCode(
163                response.status().as_u16(),
164                response.text().await.unwrap_or_default(),
165            ))
166        }
167    })
168    .await
169    .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
170}
171
172fn file_to_body(channel: Channel<ProgressPayload>, file: File, file_len: u64) -> reqwest::Body {
173    let stream = FramedRead::new(file, BytesCodec::new()).map_ok(|r| r.freeze());
174
175    let mut stats = TransferStats::default();
176    reqwest::Body::wrap_stream(ReadProgressStream::new(
177        stream,
178        Box::new(move |progress, _total| {
179            stats.record_chunk_transfer(progress as usize);
180            let _ = channel.send(ProgressPayload {
181                progress,
182                progress_total: stats.total_transferred,
183                total: file_len,
184                transfer_speed: stats.transfer_speed,
185            });
186        }),
187    ))
188}
189
190pub fn init<R: Runtime>() -> TauriPlugin<R> {
191    PluginBuilder::new("upload")
192        .invoke_handler(tauri::generate_handler![download, upload])
193        .build()
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{self, Mock, Server, ServerGuard};
    use tauri::ipc::InvokeResponseBody;

    /// A mock HTTP server together with the single endpoint registered on it.
    struct MockedServer {
        // Never read by the tests; presumably held so the server stays up
        // for as long as the fixture is alive.
        _server: ServerGuard,
        /// Full URL of the mocked endpoint.
        url: String,
        /// Handle used to assert that the endpoint was actually hit.
        mocked_endpoint: Mock,
    }

    /// Creates a progress channel that accepts and discards every message.
    fn silent_channel() -> Channel<ProgressPayload> {
        Channel::new(|msg: InvokeResponseBody| -> tauri::Result<()> {
            let _ = msg;
            Ok(())
        })
    }

    #[tokio::test]
    async fn should_error_on_download_if_status_not_success() {
        let MockedServer {
            _server,
            url,
            mocked_endpoint,
        } = spawn_server_mocked(400).await;

        let outcome = download_file(url).await;

        mocked_endpoint.assert();
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn should_download_file_successfully() {
        let MockedServer {
            _server,
            url,
            mocked_endpoint,
        } = spawn_server_mocked(200).await;

        let outcome = download_file(url).await;

        mocked_endpoint.assert();
        if let Err(err) = outcome {
            panic!("failed to download file: {}", err);
        }
    }

    #[tokio::test]
    async fn should_error_on_upload_if_status_not_success() {
        let MockedServer {
            _server,
            url,
            mocked_endpoint,
        } = spawn_upload_server_mocked(500, "POST").await;

        let outcome = upload_file(url, None).await;

        mocked_endpoint.assert();
        assert!(outcome.is_err());
        if let Err(Error::HttpErrorCode(status, _)) = outcome {
            assert_eq!(status, 500);
        } else {
            panic!("Expected HttpErrorCode error");
        }
    }

    #[tokio::test]
    async fn should_error_on_upload_if_file_not_found() {
        let server = spawn_upload_server_mocked(200, "POST").await;

        let outcome = upload(
            server.url,
            "/nonexistent/file.txt".to_string(),
            HashMap::new(),
            None,
            silent_channel(),
        )
        .await;

        assert!(outcome.is_err());
        if !matches!(outcome.unwrap_err(), Error::Io(_)) {
            panic!("Expected IO error for missing file");
        }
    }

    #[tokio::test]
    async fn should_upload_file_with_post_method() {
        let MockedServer {
            _server,
            url,
            mocked_endpoint,
        } = spawn_upload_server_mocked(200, "POST").await;

        let outcome = upload_file(url, Some(HttpMethod::Post)).await;

        mocked_endpoint.assert();
        match outcome {
            Ok(response_body) => assert_eq!(response_body, "upload successful"),
            Err(err) => panic!("failed to upload file: {}", err),
        }
    }

    #[tokio::test]
    async fn should_upload_file_with_put_method() {
        let MockedServer {
            _server,
            url,
            mocked_endpoint,
        } = spawn_upload_server_mocked(200, "PUT").await;

        let outcome = upload_file(url, Some(HttpMethod::Put)).await;

        mocked_endpoint.assert();
        match outcome {
            Ok(response_body) => assert_eq!(response_body, "upload successful"),
            Err(err) => panic!("failed to upload file with PUT: {}", err),
        }
    }

    #[tokio::test]
    async fn should_upload_file_with_patch_method() {
        let MockedServer {
            _server,
            url,
            mocked_endpoint,
        } = spawn_upload_server_mocked(200, "PATCH").await;

        let outcome = upload_file(url, Some(HttpMethod::Patch)).await;

        mocked_endpoint.assert();
        match outcome {
            Ok(response_body) => assert_eq!(response_body, "upload successful"),
            Err(err) => panic!("failed to upload file with PATCH: {}", err),
        }
    }

    /// Runs `download` against `url`, writing to the crate's `test/download.txt`.
    async fn download_file(url: String) -> Result<()> {
        let target = concat!(env!("CARGO_MANIFEST_DIR"), "/test/download.txt").to_string();
        download(url, target, HashMap::new(), None, silent_channel()).await
    }

    /// Runs `upload` against `url`, sending the crate's `test/upload.txt`.
    async fn upload_file(url: String, method: Option<HttpMethod>) -> Result<String> {
        let source = concat!(env!("CARGO_MANIFEST_DIR"), "/test/upload.txt").to_string();
        upload(url, source, HashMap::new(), method, silent_channel()).await
    }

    /// Starts a mock server whose `GET /mock_test` endpoint answers with
    /// `return_status`.
    async fn spawn_server_mocked(return_status: usize) -> MockedServer {
        let endpoint_path = "/mock_test";
        let mut server = Server::new_async().await;

        let mocked_endpoint = server
            .mock("GET", endpoint_path)
            .with_status(return_status)
            .with_body("mocked response body")
            .create_async()
            .await;

        let url = format!("{}{}", server.url(), endpoint_path);
        MockedServer {
            _server: server,
            url,
            mocked_endpoint,
        }
    }

    /// Starts a mock server whose `/upload_test` endpoint accepts `method`
    /// requests carrying `content-length: 20` and answers with
    /// `return_status`.
    async fn spawn_upload_server_mocked(return_status: usize, method: &str) -> MockedServer {
        let endpoint_path = "/upload_test";
        let mut server = Server::new_async().await;

        let mocked_endpoint = server
            .mock(method, endpoint_path)
            .match_header("content-length", "20")
            .with_status(return_status)
            .with_body("upload successful")
            .create_async()
            .await;

        let url = format!("{}{}", server.url(), endpoint_path);
        MockedServer {
            _server: server,
            url,
            mocked_endpoint,
        }
    }
}
358}