1#![doc(
10 html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
11 html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
12)]
13
14mod transfer_stats;
15use transfer_stats::TransferStats;
16
17use futures_util::TryStreamExt;
18use serde::{ser::Serializer, Deserialize, Serialize};
19use tauri::{
20 command,
21 ipc::Channel,
22 plugin::{Builder as PluginBuilder, TauriPlugin},
23 Runtime,
24};
25use tokio::{
26 fs::File,
27 io::{AsyncWriteExt, BufWriter},
28};
29use tokio_util::codec::{BytesCodec, FramedRead};
30
31use read_progress_stream::ReadProgressStream;
32
33use std::collections::HashMap;
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(rename_all = "UPPERCASE")]
37pub enum HttpMethod {
38 Post,
39 Put,
40 Patch,
41}
42
43type Result<T> = std::result::Result<T, Error>;
44
45#[derive(Debug, thiserror::Error)]
46pub enum Error {
47 #[error(transparent)]
48 Io(#[from] std::io::Error),
49 #[error(transparent)]
50 Request(#[from] reqwest::Error),
51 #[error("{0}")]
52 ContentLength(String),
53 #[error("request failed with status code {0}: {1}")]
54 HttpErrorCode(u16, String),
55}
56
57impl Serialize for Error {
58 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
59 where
60 S: Serializer,
61 {
62 serializer.serialize_str(self.to_string().as_ref())
63 }
64}
65
66#[derive(Clone, Serialize)]
67#[serde(rename_all = "camelCase")]
68struct ProgressPayload {
69 progress: u64,
70 progress_total: u64,
71 total: u64,
72 transfer_speed: u64,
73}
74
75#[command]
76async fn download(
77 url: String,
78 file_path: String,
79 headers: HashMap<String, String>,
80 body: Option<String>,
81 on_progress: Channel<ProgressPayload>,
82) -> Result<()> {
83 tokio::spawn(async move {
84 let client = reqwest::Client::new();
85 let mut request = if let Some(body) = body {
86 client.post(&url).body(body)
87 } else {
88 client.get(&url)
89 };
90 for (key, value) in headers {
93 request = request.header(&key, value);
94 }
95
96 let response = request.send().await?;
97 if !response.status().is_success() {
98 return Err(Error::HttpErrorCode(
99 response.status().as_u16(),
100 response.text().await.unwrap_or_default(),
101 ));
102 }
103 let total = response.content_length().unwrap_or(0);
104
105 let mut file = BufWriter::new(File::create(&file_path).await?);
106 let mut stream = response.bytes_stream();
107
108 let mut stats = TransferStats::default();
109 while let Some(chunk) = stream.try_next().await? {
110 file.write_all(&chunk).await?;
111 stats.record_chunk_transfer(chunk.len());
112 let _ = on_progress.send(ProgressPayload {
113 progress: chunk.len() as u64,
114 progress_total: stats.total_transferred,
115 total,
116 transfer_speed: stats.transfer_speed,
117 });
118 }
119 file.flush().await?;
120 Ok(())
121 })
122 .await
123 .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
124}
125
126#[command]
127async fn upload(
128 url: String,
129 file_path: String,
130 headers: HashMap<String, String>,
131 method: Option<HttpMethod>,
132 on_progress: Channel<ProgressPayload>,
133) -> Result<String> {
134 tokio::spawn(async move {
135 let file = File::open(&file_path).await?;
137 let file_len = file.metadata().await.unwrap().len();
138
139 let http_method = method.unwrap_or(HttpMethod::Post);
141
142 let client = reqwest::Client::new();
144 let mut request = match http_method {
145 HttpMethod::Put => client.put(&url),
146 HttpMethod::Patch => client.patch(&url),
147 HttpMethod::Post => client.post(&url),
148 }
149 .header(reqwest::header::CONTENT_LENGTH, file_len)
150 .body(file_to_body(on_progress, file, file_len));
151
152 for (key, value) in headers {
155 request = request.header(&key, value);
156 }
157
158 let response = request.send().await?;
159 if response.status().is_success() {
160 response.text().await.map_err(Into::into)
161 } else {
162 Err(Error::HttpErrorCode(
163 response.status().as_u16(),
164 response.text().await.unwrap_or_default(),
165 ))
166 }
167 })
168 .await
169 .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
170}
171
172fn file_to_body(channel: Channel<ProgressPayload>, file: File, file_len: u64) -> reqwest::Body {
173 let stream = FramedRead::new(file, BytesCodec::new()).map_ok(|r| r.freeze());
174
175 let mut stats = TransferStats::default();
176 reqwest::Body::wrap_stream(ReadProgressStream::new(
177 stream,
178 Box::new(move |progress, _total| {
179 stats.record_chunk_transfer(progress as usize);
180 let _ = channel.send(ProgressPayload {
181 progress,
182 progress_total: stats.total_transferred,
183 total: file_len,
184 transfer_speed: stats.transfer_speed,
185 });
186 }),
187 ))
188}
189
190pub fn init<R: Runtime>() -> TauriPlugin<R> {
191 PluginBuilder::new("upload")
192 .invoke_handler(tauri::generate_handler![download, upload])
193 .build()
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{self, Mock, Server, ServerGuard};
    use tauri::ipc::InvokeResponseBody;

    /// A mock HTTP server plus the URL and endpoint a test exercises.
    struct MockedServer {
        // Never read; held so the mock server stays alive until the test ends.
        _server: ServerGuard,
        url: String,
        mocked_endpoint: Mock,
    }

    #[tokio::test]
    async fn should_error_on_download_if_status_not_success() {
        let mocked_server = spawn_server_mocked(400).await;
        let result = download_file(mocked_server.url).await;
        mocked_server.mocked_endpoint.assert();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn should_download_file_successfully() {
        let mocked_server = spawn_server_mocked(200).await;
        let result = download_file(mocked_server.url).await;
        mocked_server.mocked_endpoint.assert();
        assert!(
            result.is_ok(),
            "failed to download file: {}",
            result.unwrap_err()
        );
    }

    #[tokio::test]
    async fn should_error_on_upload_if_status_not_success() {
        let mocked_server = spawn_upload_server_mocked(500, "POST").await;
        let result = upload_file(mocked_server.url, None).await;
        mocked_server.mocked_endpoint.assert();
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::HttpErrorCode(status, _) => assert_eq!(status, 500),
            _ => panic!("Expected HttpErrorCode error"),
        }
    }

    #[tokio::test]
    async fn should_error_on_upload_if_file_not_found() {
        let mocked_server = spawn_upload_server_mocked(200, "POST").await;
        let file_path = "/nonexistent/file.txt".to_string();
        let headers = HashMap::new();

        let result = upload(mocked_server.url, file_path, headers, None, noop_channel()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::Io(_) => {}
            _ => panic!("Expected IO error for missing file"),
        }
    }

    #[tokio::test]
    async fn should_upload_file_with_post_method() {
        let mocked_server = spawn_upload_server_mocked(200, "POST").await;
        let result = upload_file(mocked_server.url, Some(HttpMethod::Post)).await;
        mocked_server.mocked_endpoint.assert();
        assert!(
            result.is_ok(),
            "failed to upload file: {}",
            result.unwrap_err()
        );
        let response_body = result.unwrap();
        assert_eq!(response_body, "upload successful");
    }

    #[tokio::test]
    async fn should_upload_file_with_put_method() {
        let mocked_server = spawn_upload_server_mocked(200, "PUT").await;
        let result = upload_file(mocked_server.url, Some(HttpMethod::Put)).await;
        mocked_server.mocked_endpoint.assert();
        assert!(
            result.is_ok(),
            "failed to upload file with PUT: {}",
            result.unwrap_err()
        );
        let response_body = result.unwrap();
        assert_eq!(response_body, "upload successful");
    }

    #[tokio::test]
    async fn should_upload_file_with_patch_method() {
        let mocked_server = spawn_upload_server_mocked(200, "PATCH").await;
        let result = upload_file(mocked_server.url, Some(HttpMethod::Patch)).await;
        mocked_server.mocked_endpoint.assert();
        assert!(
            result.is_ok(),
            "failed to upload file with PATCH: {}",
            result.unwrap_err()
        );
        let response_body = result.unwrap();
        assert_eq!(response_body, "upload successful");
    }

    /// Builds a progress channel that discards every message it receives.
    fn noop_channel() -> Channel<ProgressPayload> {
        Channel::new(|_msg: InvokeResponseBody| -> tauri::Result<()> { Ok(()) })
    }

    /// Downloads `url` into `test/download.txt` under the crate root.
    async fn download_file(url: String) -> Result<()> {
        let file_path = concat!(env!("CARGO_MANIFEST_DIR"), "/test/download.txt").to_string();
        download(url, file_path, HashMap::new(), None, noop_channel()).await
    }

    /// Uploads `test/upload.txt` from the crate root to `url` using `method`.
    async fn upload_file(url: String, method: Option<HttpMethod>) -> Result<String> {
        let file_path = concat!(env!("CARGO_MANIFEST_DIR"), "/test/upload.txt").to_string();
        upload(url, file_path, HashMap::new(), method, noop_channel()).await
    }

    /// Starts a mock server whose `GET /mock_test` endpoint replies with
    /// `return_status` and a fixed body.
    async fn spawn_server_mocked(return_status: usize) -> MockedServer {
        let mut server = Server::new_async().await;
        let path = "/mock_test";
        let mocked_endpoint = server
            .mock("GET", path)
            .with_status(return_status)
            .with_body("mocked response body")
            .create_async()
            .await;

        let url = server.url() + path;
        MockedServer {
            _server: server,
            url,
            mocked_endpoint,
        }
    }

    /// Starts a mock server whose `<method> /upload_test` endpoint requires
    /// `Content-Length: 20` (which must match the size of `test/upload.txt`)
    /// and replies with `return_status` and `"upload successful"`.
    async fn spawn_upload_server_mocked(return_status: usize, method: &str) -> MockedServer {
        let mut server = Server::new_async().await;
        let path = "/upload_test";
        let mocked_endpoint = server
            .mock(method, path)
            .with_status(return_status)
            .with_body("upload successful")
            .match_header("content-length", "20")
            .create_async()
            .await;

        let url = server.url() + path;
        MockedServer {
            _server: server,
            url,
            mocked_endpoint,
        }
    }
}