1#![doc(
10 html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
11 html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
12)]
13
14mod transfer_stats;
15use transfer_stats::TransferStats;
16
17use futures_util::TryStreamExt;
18use serde::{ser::Serializer, Serialize};
19use tauri::{
20 command,
21 ipc::Channel,
22 plugin::{Builder as PluginBuilder, TauriPlugin},
23 Runtime,
24};
25use tokio::{
26 fs::File,
27 io::{AsyncWriteExt, BufWriter},
28};
29use tokio_util::codec::{BytesCodec, FramedRead};
30
31use read_progress_stream::ReadProgressStream;
32
33use std::collections::HashMap;
34
/// Result alias used throughout this file; the error type is always [`Error`].
type Result<T> = std::result::Result<T, Error>;
36
/// Errors returned by the `download` and `upload` commands.
///
/// The enum is serialized as its `Display` string (see the `Serialize` impl in this file).
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A filesystem or other I/O failure. Also used in this file to carry the
    /// message of a failed join on the spawned transfer task.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// A failure reported by `reqwest` while sending the request or reading the response.
    #[error(transparent)]
    Request(#[from] reqwest::Error),
    /// A free-form message that is displayed verbatim.
    // NOTE(review): not constructed anywhere in the visible code; presumably
    // reserved for content-length problems — confirm against other callers.
    #[error("{0}")]
    ContentLength(String),
    /// The server answered with a non-success status: `(status code, response body text)`.
    #[error("request failed with status code {0}: {1}")]
    HttpErrorCode(u16, String),
}
48
49impl Serialize for Error {
50 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
51 where
52 S: Serializer,
53 {
54 serializer.serialize_str(self.to_string().as_ref())
55 }
56}
57
/// Progress message sent over the `on_progress` channel while a transfer runs.
///
/// Field names are serialized in camelCase (e.g. `progressTotal`).
#[derive(Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct ProgressPayload {
    /// Size reported for the current step: in `download` this is the length of the
    /// chunk just written; in `file_to_body` it is the first argument of the
    /// progress callback.
    progress: u64,
    /// Running total taken from `TransferStats::total_transferred`.
    progress_total: u64,
    /// Expected overall size: the response content length for downloads (0 when
    /// the server does not report one) or the file length for uploads.
    total: u64,
    /// Value of `TransferStats::transfer_speed` at the time of the message.
    // NOTE(review): the unit is defined by `TransferStats`, which is not visible here.
    transfer_speed: u64,
}
66
67#[command]
68async fn download(
69 url: String,
70 file_path: String,
71 headers: HashMap<String, String>,
72 body: Option<String>,
73 on_progress: Channel<ProgressPayload>,
74) -> Result<()> {
75 tokio::spawn(async move {
76 let client = reqwest::Client::new();
77 let mut request = if let Some(body) = body {
78 client.post(&url).body(body)
79 } else {
80 client.get(&url)
81 };
82 for (key, value) in headers {
85 request = request.header(&key, value);
86 }
87
88 let response = request.send().await?;
89 if !response.status().is_success() {
90 return Err(Error::HttpErrorCode(
91 response.status().as_u16(),
92 response.text().await.unwrap_or_default(),
93 ));
94 }
95 let total = response.content_length().unwrap_or(0);
96
97 let mut file = BufWriter::new(File::create(&file_path).await?);
98 let mut stream = response.bytes_stream();
99
100 let mut stats = TransferStats::default();
101 while let Some(chunk) = stream.try_next().await? {
102 file.write_all(&chunk).await?;
103 stats.record_chunk_transfer(chunk.len());
104 let _ = on_progress.send(ProgressPayload {
105 progress: chunk.len() as u64,
106 progress_total: stats.total_transferred,
107 total,
108 transfer_speed: stats.transfer_speed,
109 });
110 }
111 file.flush().await?;
112 Ok(())
113 })
114 .await
115 .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
116}
117
118#[command]
119async fn upload(
120 url: String,
121 file_path: String,
122 headers: HashMap<String, String>,
123 on_progress: Channel<ProgressPayload>,
124) -> Result<String> {
125 tokio::spawn(async move {
126 let file = File::open(&file_path).await?;
128 let file_len = file.metadata().await.unwrap().len();
129
130 let client = reqwest::Client::new();
132 let mut request = client
133 .post(&url)
134 .header(reqwest::header::CONTENT_LENGTH, file_len)
135 .body(file_to_body(on_progress, file, file_len));
136
137 for (key, value) in headers {
140 request = request.header(&key, value);
141 }
142
143 let response = request.send().await?;
144 if response.status().is_success() {
145 response.text().await.map_err(Into::into)
146 } else {
147 Err(Error::HttpErrorCode(
148 response.status().as_u16(),
149 response.text().await.unwrap_or_default(),
150 ))
151 }
152 })
153 .await
154 .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
155}
156
157fn file_to_body(channel: Channel<ProgressPayload>, file: File, file_len: u64) -> reqwest::Body {
158 let stream = FramedRead::new(file, BytesCodec::new()).map_ok(|r| r.freeze());
159
160 let mut stats = TransferStats::default();
161 reqwest::Body::wrap_stream(ReadProgressStream::new(
162 stream,
163 Box::new(move |progress, _total| {
164 stats.record_chunk_transfer(progress as usize);
165 let _ = channel.send(ProgressPayload {
166 progress,
167 progress_total: stats.total_transferred,
168 total: file_len,
169 transfer_speed: stats.transfer_speed,
170 });
171 }),
172 ))
173}
174
175pub fn init<R: Runtime>() -> TauriPlugin<R> {
176 PluginBuilder::new("upload")
177 .invoke_handler(tauri::generate_handler![download, upload])
178 .build()
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{self, Mock, Server, ServerGuard};
    use tauri::ipc::InvokeResponseBody;

    /// Fixture file used as the download destination and as the upload source.
    // NOTE(review): the download helper writes to the same path the upload helper
    // reads; if the tests run concurrently they may interfere — confirm.
    const FIXTURE_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/test/test.txt");

    /// A running mock server together with its URL and the registered endpoint.
    struct MockedServer {
        // Held only to keep the guard alive for the duration of a test.
        _server: ServerGuard,
        url: String,
        mocked_endpoint: Mock,
    }

    /// Builds a progress channel that discards every message it receives.
    fn silent_channel() -> Channel<ProgressPayload> {
        Channel::new(|_message: InvokeResponseBody| -> tauri::Result<()> { Ok(()) })
    }

    #[tokio::test]
    async fn should_error_on_download_if_status_not_success() {
        let server = spawn_server_mocked(400).await;
        let outcome = download_file(server.url).await;
        server.mocked_endpoint.assert();
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn should_download_file_successfully() {
        let server = spawn_server_mocked(200).await;
        let outcome = download_file(server.url).await;
        server.mocked_endpoint.assert();
        if let Err(err) = &outcome {
            panic!("failed to download file: {}", err);
        }
    }

    #[tokio::test]
    async fn should_error_on_upload_if_status_not_success() {
        let server = spawn_upload_server_mocked(500).await;
        let outcome = upload_file(server.url).await;
        server.mocked_endpoint.assert();
        assert!(outcome.is_err());
        let Err(Error::HttpErrorCode(status, _)) = outcome else {
            panic!("Expected HttpErrorCode error");
        };
        assert_eq!(status, 500);
    }

    #[tokio::test]
    async fn should_error_on_upload_if_file_not_found() {
        let server = spawn_upload_server_mocked(200).await;
        let missing_path = "/nonexistent/file.txt".to_string();

        let outcome = upload(server.url, missing_path, HashMap::new(), silent_channel()).await;
        assert!(outcome.is_err());
        assert!(
            matches!(outcome, Err(Error::Io(_))),
            "Expected IO error for missing file"
        );
    }

    #[tokio::test]
    async fn should_upload_file_successfully() {
        let server = spawn_upload_server_mocked(200).await;
        let outcome = upload_file(server.url).await;
        server.mocked_endpoint.assert();
        match outcome {
            Ok(response_body) => assert_eq!(response_body, "upload successful"),
            Err(err) => panic!("failed to upload file: {}", err),
        }
    }

    /// Runs `download` against `url` with no extra headers and no body, writing to the fixture.
    async fn download_file(url: String) -> Result<()> {
        download(
            url,
            FIXTURE_PATH.to_string(),
            HashMap::new(),
            None,
            silent_channel(),
        )
        .await
    }

    /// Runs `upload` against `url` with no extra headers, sending the fixture file.
    async fn upload_file(url: String) -> Result<String> {
        upload(
            url,
            FIXTURE_PATH.to_string(),
            HashMap::new(),
            silent_channel(),
        )
        .await
    }

    /// Starts a mock server whose `GET /mock_test` endpoint answers with `return_status`.
    async fn spawn_server_mocked(return_status: usize) -> MockedServer {
        let endpoint_path = "/mock_test";
        let mut server = Server::new_async().await;
        let mocked_endpoint = server
            .mock("GET", endpoint_path)
            .with_status(return_status)
            .with_body("mocked response body")
            .create_async()
            .await;

        let url = format!("{}{}", server.url(), endpoint_path);
        MockedServer {
            _server: server,
            url,
            mocked_endpoint,
        }
    }

    /// Starts a mock server whose `POST /upload_test` endpoint answers with
    /// `return_status` and only matches requests with a `content-length` of `20`.
    async fn spawn_upload_server_mocked(return_status: usize) -> MockedServer {
        let endpoint_path = "/upload_test";
        let mut server = Server::new_async().await;
        let mocked_endpoint = server
            .mock("POST", endpoint_path)
            .with_status(return_status)
            .with_body("upload successful")
            .match_header("content-length", "20")
            .create_async()
            .await;

        let url = format!("{}{}", server.url(), endpoint_path);
        MockedServer {
            _server: server,
            url,
            mocked_endpoint,
        }
    }
}