tauri_plugin_upload/
lib.rs

1// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
2// SPDX-License-Identifier: Apache-2.0
3// SPDX-License-Identifier: MIT
4
5//! Upload files from disk to a remote server over HTTP.
6//!
7//! Download files from a remote HTTP server to disk.
8
9#![doc(
10    html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
11    html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
12)]
13
14mod transfer_stats;
15use transfer_stats::TransferStats;
16
17use futures_util::TryStreamExt;
18use serde::{ser::Serializer, Serialize};
19use tauri::{
20    command,
21    ipc::Channel,
22    plugin::{Builder as PluginBuilder, TauriPlugin},
23    Runtime,
24};
25use tokio::{
26    fs::File,
27    io::{AsyncWriteExt, BufWriter},
28};
29use tokio_util::codec::{BytesCodec, FramedRead};
30
31use read_progress_stream::ReadProgressStream;
32
33use std::collections::HashMap;
34
35type Result<T> = std::result::Result<T, Error>;
36
37#[derive(Debug, thiserror::Error)]
38pub enum Error {
39    #[error(transparent)]
40    Io(#[from] std::io::Error),
41    #[error(transparent)]
42    Request(#[from] reqwest::Error),
43    #[error("{0}")]
44    ContentLength(String),
45    #[error("request failed with status code {0}: {1}")]
46    HttpErrorCode(u16, String),
47}
48
49impl Serialize for Error {
50    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
51    where
52        S: Serializer,
53    {
54        serializer.serialize_str(self.to_string().as_ref())
55    }
56}
57
58#[derive(Clone, Serialize)]
59#[serde(rename_all = "camelCase")]
60struct ProgressPayload {
61    progress: u64,
62    progress_total: u64,
63    total: u64,
64    transfer_speed: u64,
65}
66
67#[command]
68async fn download(
69    url: &str,
70    file_path: &str,
71    headers: HashMap<String, String>,
72    body: Option<String>,
73    on_progress: Channel<ProgressPayload>,
74) -> Result<()> {
75    let client = reqwest::Client::new();
76    let mut request = if let Some(body) = body {
77        client.post(url).body(body)
78    } else {
79        client.get(url)
80    };
81    // Loop trought the headers keys and values
82    // and add them to the request object.
83    for (key, value) in headers {
84        request = request.header(&key, value);
85    }
86
87    let response = request.send().await?;
88    if !response.status().is_success() {
89        return Err(Error::HttpErrorCode(
90            response.status().as_u16(),
91            response.text().await.unwrap_or_default(),
92        ));
93    }
94    let total = response.content_length().unwrap_or(0);
95
96    let mut file = BufWriter::new(File::create(file_path).await?);
97    let mut stream = response.bytes_stream();
98
99    let mut stats = TransferStats::default();
100    while let Some(chunk) = stream.try_next().await? {
101        file.write_all(&chunk).await?;
102        stats.record_chunk_transfer(chunk.len());
103        let _ = on_progress.send(ProgressPayload {
104            progress: chunk.len() as u64,
105            progress_total: stats.total_transferred,
106            total,
107            transfer_speed: stats.transfer_speed,
108        });
109    }
110    file.flush().await?;
111
112    Ok(())
113}
114
115#[command]
116async fn upload(
117    url: &str,
118    file_path: &str,
119    headers: HashMap<String, String>,
120    on_progress: Channel<ProgressPayload>,
121) -> Result<String> {
122    // Read the file
123    let file = File::open(file_path).await?;
124    let file_len = file.metadata().await.unwrap().len();
125
126    // Create the request and attach the file to the body
127    let client = reqwest::Client::new();
128    let mut request = client
129        .post(url)
130        .header(reqwest::header::CONTENT_LENGTH, file_len)
131        .body(file_to_body(on_progress, file));
132
133    // Loop through the headers keys and values
134    // and add them to the request object.
135    for (key, value) in headers {
136        request = request.header(&key, value);
137    }
138
139    let response = request.send().await?;
140    if response.status().is_success() {
141        response.text().await.map_err(Into::into)
142    } else {
143        Err(Error::HttpErrorCode(
144            response.status().as_u16(),
145            response.text().await.unwrap_or_default(),
146        ))
147    }
148}
149
150fn file_to_body(channel: Channel<ProgressPayload>, file: File) -> reqwest::Body {
151    let stream = FramedRead::new(file, BytesCodec::new()).map_ok(|r| r.freeze());
152
153    let mut stats = TransferStats::default();
154    reqwest::Body::wrap_stream(ReadProgressStream::new(
155        stream,
156        Box::new(move |progress, total| {
157            stats.record_chunk_transfer(progress as usize);
158            let _ = channel.send(ProgressPayload {
159                progress,
160                progress_total: stats.total_transferred,
161                total,
162                transfer_speed: stats.transfer_speed,
163            });
164        }),
165    ))
166}
167
168pub fn init<R: Runtime>() -> TauriPlugin<R> {
169    PluginBuilder::new("upload")
170        .invoke_handler(tauri::generate_handler![download, upload])
171        .build()
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{self, Mock, Server, ServerGuard};
    use tauri::ipc::InvokeResponseBody;

    /// Body served by the mocked endpoint.
    const MOCK_BODY: &str = "mocked response body";

    /// A running mock server together with the endpoint registered on it.
    struct MockedServer {
        // Held only to keep the server alive for the duration of the test.
        _server: ServerGuard,
        url: String,
        mocked_endpoint: Mock,
    }

    #[tokio::test]
    async fn should_error_if_status_not_success() {
        let mocked_server = spawn_server_mocked(400).await;
        let (result, file_path) =
            download_file(&mocked_server.url, "tauri-plugin-upload-status-error.txt").await;
        mocked_server.mocked_endpoint.assert();
        // Best-effort cleanup in case a file was created.
        let _ = tokio::fs::remove_file(&file_path).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn should_download_file_successfully() {
        let mocked_server = spawn_server_mocked(200).await;
        let (result, file_path) =
            download_file(&mocked_server.url, "tauri-plugin-upload-download-ok.txt").await;
        mocked_server.mocked_endpoint.assert();
        if let Err(err) = result {
            panic!("failed to download file: {err}");
        }
        let contents = tokio::fs::read_to_string(&file_path)
            .await
            .expect("downloaded file should be readable");
        let _ = tokio::fs::remove_file(&file_path).await;
        assert_eq!(contents, MOCK_BODY);
    }

    /// Downloads `url` into `file_name` inside the system temp directory.
    ///
    /// Returns the command's result together with the destination path.
    async fn download_file(url: &str, file_name: &str) -> (Result<()>, std::path::PathBuf) {
        // Write outside the source tree. The test then does not depend on a
        // pre-existing `test/` directory and leaves no artifacts in the repo.
        let file_path = std::env::temp_dir().join(file_name);
        let headers = HashMap::new();
        let sender: Channel<ProgressPayload> =
            Channel::new(|_msg: InvokeResponseBody| -> tauri::Result<()> { Ok(()) });
        let result = download(
            url,
            file_path
                .to_str()
                .expect("temp dir path should be valid UTF-8"),
            headers,
            None,
            sender,
        )
        .await;
        (result, file_path)
    }

    /// Starts a mock server whose `GET /mock_test` answers with `return_status`.
    async fn spawn_server_mocked(return_status: usize) -> MockedServer {
        let mut server = Server::new_async().await;
        let path = "/mock_test";
        let mock = server
            .mock("GET", path)
            .with_status(return_status)
            .with_body(MOCK_BODY)
            .create_async()
            .await;

        let url = server.url() + path;
        MockedServer {
            _server: server,
            url,
            mocked_endpoint: mock,
        }
    }
}