tauri_plugin_upload/
lib.rs1#![doc(
10 html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
11 html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
12)]
13
14mod transfer_stats;
15use transfer_stats::TransferStats;
16
17use futures_util::TryStreamExt;
18use serde::{ser::Serializer, Serialize};
19use tauri::{
20 command,
21 ipc::Channel,
22 plugin::{Builder as PluginBuilder, TauriPlugin},
23 Runtime,
24};
25use tokio::{
26 fs::File,
27 io::{AsyncWriteExt, BufWriter},
28};
29use tokio_util::codec::{BytesCodec, FramedRead};
30
31use read_progress_stream::ReadProgressStream;
32
33use std::collections::HashMap;
34
35type Result<T> = std::result::Result<T, Error>;
36
37#[derive(Debug, thiserror::Error)]
38pub enum Error {
39 #[error(transparent)]
40 Io(#[from] std::io::Error),
41 #[error(transparent)]
42 Request(#[from] reqwest::Error),
43 #[error("{0}")]
44 ContentLength(String),
45 #[error("request failed with status code {0}: {1}")]
46 HttpErrorCode(u16, String),
47}
48
49impl Serialize for Error {
50 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
51 where
52 S: Serializer,
53 {
54 serializer.serialize_str(self.to_string().as_ref())
55 }
56}
57
58#[derive(Clone, Serialize)]
59#[serde(rename_all = "camelCase")]
60struct ProgressPayload {
61 progress: u64,
62 progress_total: u64,
63 total: u64,
64 transfer_speed: u64,
65}
66
67#[command]
68async fn download(
69 url: &str,
70 file_path: &str,
71 headers: HashMap<String, String>,
72 body: Option<String>,
73 on_progress: Channel<ProgressPayload>,
74) -> Result<()> {
75 let client = reqwest::Client::new();
76 let mut request = if let Some(body) = body {
77 client.post(url).body(body)
78 } else {
79 client.get(url)
80 };
81 for (key, value) in headers {
84 request = request.header(&key, value);
85 }
86
87 let response = request.send().await?;
88 if !response.status().is_success() {
89 return Err(Error::HttpErrorCode(
90 response.status().as_u16(),
91 response.text().await.unwrap_or_default(),
92 ));
93 }
94 let total = response.content_length().unwrap_or(0);
95
96 let mut file = BufWriter::new(File::create(file_path).await?);
97 let mut stream = response.bytes_stream();
98
99 let mut stats = TransferStats::default();
100 while let Some(chunk) = stream.try_next().await? {
101 file.write_all(&chunk).await?;
102 stats.record_chunk_transfer(chunk.len());
103 let _ = on_progress.send(ProgressPayload {
104 progress: chunk.len() as u64,
105 progress_total: stats.total_transferred,
106 total,
107 transfer_speed: stats.transfer_speed,
108 });
109 }
110 file.flush().await?;
111
112 Ok(())
113}
114
115#[command]
116async fn upload(
117 url: &str,
118 file_path: &str,
119 headers: HashMap<String, String>,
120 on_progress: Channel<ProgressPayload>,
121) -> Result<String> {
122 let file = File::open(file_path).await?;
124 let file_len = file.metadata().await.unwrap().len();
125
126 let client = reqwest::Client::new();
128 let mut request = client
129 .post(url)
130 .header(reqwest::header::CONTENT_LENGTH, file_len)
131 .body(file_to_body(on_progress, file));
132
133 for (key, value) in headers {
136 request = request.header(&key, value);
137 }
138
139 let response = request.send().await?;
140 if response.status().is_success() {
141 response.text().await.map_err(Into::into)
142 } else {
143 Err(Error::HttpErrorCode(
144 response.status().as_u16(),
145 response.text().await.unwrap_or_default(),
146 ))
147 }
148}
149
150fn file_to_body(channel: Channel<ProgressPayload>, file: File) -> reqwest::Body {
151 let stream = FramedRead::new(file, BytesCodec::new()).map_ok(|r| r.freeze());
152
153 let mut stats = TransferStats::default();
154 reqwest::Body::wrap_stream(ReadProgressStream::new(
155 stream,
156 Box::new(move |progress, total| {
157 stats.record_chunk_transfer(progress as usize);
158 let _ = channel.send(ProgressPayload {
159 progress,
160 progress_total: stats.total_transferred,
161 total,
162 transfer_speed: stats.transfer_speed,
163 });
164 }),
165 ))
166}
167
168pub fn init<R: Runtime>() -> TauriPlugin<R> {
169 PluginBuilder::new("upload")
170 .invoke_handler(tauri::generate_handler![download, upload])
171 .build()
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{self, Mock, Server, ServerGuard};
    use tauri::ipc::InvokeResponseBody;

    /// Body served by the mocked download endpoint and sent by the upload test.
    const MOCK_BODY: &str = "mocked response body";

    /// Keeps a mock server alive for the duration of a test. Also holds the full URL of
    /// its endpoint and the mock used to assert that the endpoint was hit.
    struct MockedServer {
        _server: ServerGuard,
        url: String,
        mocked_endpoint: Mock,
    }

    #[tokio::test]
    async fn should_error_if_status_not_success() {
        let mocked_server = spawn_server_mocked(400).await;
        let target = temp_target("status_not_success");
        let result = download_file(&mocked_server.url, &target).await;
        mocked_server.mocked_endpoint.assert();
        assert!(
            matches!(result, Err(Error::HttpErrorCode(400, _))),
            "expected an HTTP 400 error, got {:?}",
            result
        );
        // `download` checks the status before creating the file; clean up defensively anyway.
        let _ = tokio::fs::remove_file(&target).await;
    }

    #[tokio::test]
    async fn should_download_file_successfully() {
        let mocked_server = spawn_server_mocked(200).await;
        let target = temp_target("download_success");
        let result = download_file(&mocked_server.url, &target).await;
        mocked_server.mocked_endpoint.assert();
        if let Err(err) = result {
            panic!("failed to download file: {}", err);
        }
        let contents = tokio::fs::read_to_string(&target)
            .await
            .expect("downloaded file should exist");
        let _ = tokio::fs::remove_file(&target).await;
        assert_eq!(contents, MOCK_BODY);
    }

    #[tokio::test]
    async fn should_upload_file_successfully() {
        let mut server = Server::new_async().await;
        let path = "/upload_test";
        let mocked_endpoint = server
            .mock("POST", path)
            .match_body(MOCK_BODY)
            .with_status(200)
            .with_body("upload accepted")
            .create_async()
            .await;

        let source = temp_target("upload_source");
        tokio::fs::write(&source, MOCK_BODY)
            .await
            .expect("failed to write upload fixture");
        let file_path = source.to_str().expect("temp path should be valid UTF-8");
        let url = server.url() + path;
        let result = upload(&url, file_path, HashMap::new(), noop_channel()).await;
        let _ = tokio::fs::remove_file(&source).await;

        mocked_endpoint.assert();
        assert_eq!(result.expect("upload should succeed"), "upload accepted");
    }

    /// Builds a per-test path in the system temp directory. This keeps tests from
    /// overwriting files in the source tree or racing each other on a shared path.
    fn temp_target(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "tauri-plugin-upload-{}-{}.txt",
            name,
            std::process::id()
        ))
    }

    /// IPC channel that discards every progress message.
    fn noop_channel() -> Channel<ProgressPayload> {
        Channel::new(|_msg: InvokeResponseBody| -> tauri::Result<()> { Ok(()) })
    }

    /// Runs `download` against `url` with no headers or body, writing to `target`.
    async fn download_file(url: &str, target: &std::path::Path) -> Result<()> {
        let file_path = target.to_str().expect("temp path should be valid UTF-8");
        download(url, file_path, HashMap::new(), None, noop_channel()).await
    }

    /// Starts a mock server whose `GET /mock_test` responds with `return_status` and `MOCK_BODY`.
    async fn spawn_server_mocked(return_status: usize) -> MockedServer {
        let mut server = Server::new_async().await;
        let path = "/mock_test";
        let mocked_endpoint = server
            .mock("GET", path)
            .with_status(return_status)
            .with_body(MOCK_BODY)
            .create_async()
            .await;
        let url = server.url() + path;
        MockedServer {
            _server: server,
            url,
            mocked_endpoint,
        }
    }
}