use crate::{
    cli::{globals::GlobalArgs, progressbar::Bar},
    s3::{S3, actions, checksum::Checksum},
    stream::{db::Db, iterator::PartIterator, part::Part},
};
use anyhow::{Result, anyhow};
use futures::stream::{FuturesUnordered, StreamExt};
use rkyv::{from_bytes, rancor::Error as RkyvError, to_bytes};
use sled::transaction::{TransactionError, Transactional};
use std::{collections::BTreeMap, path::Path};
use tokio::time::{Duration, sleep};

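/// Parameters for a multipart upload: the target key, the source file and its
/// part layout, the sled-backed state database, and per-request options.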
pub struct MultipartUploadRequest<'a> {
    pub s3: &'a S3,
    pub key: &'a str,
    pub file: &'a Path,
    pub file_size: u64,
    pub chunk_size: u64,
    pub sdb: &'a Db,
    pub acl: Option<String>,
    pub meta: Option<BTreeMap<String, String>>,
    pub quiet: bool,
    pub additional_checksum: Option<Checksum>,
    pub max_requests: u8,
    pub globals: GlobalArgs,
}

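/// Parameters for uploading a single part: its number, byte offset (`seek`),
/// length (`chunk`), and optional additional checksum.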
struct UploadPartRequest<'a> {
    s3: &'a S3,
    key: &'a str,
    file: &'a Path,
    part_number: u16,
    uid: &'a str,
    seek: u64,
    chunk: u64,
    additional_checksum: &'a mut Option<Checksum>,
    globals: GlobalArgs,
}

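/// Uploads a file in parts, resuming from the state stored in the sled
/// database: creates (or reuses) the multipart upload, uploads pending parts
/// with up to `max_requests` concurrent tasks, then completes the upload and
/// returns the final ETag.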
pub async fn upload_multipart(request: MultipartUploadRequest<'_>) -> Result<String> {
    log::debug!(
        "Starting multipart upload:
    key: {}
    file: {}
    file_size: {}
    part size: {}
    acl: {:#?}
    meta: {:#?}
    additional checksum: {:#?}",
        request.key,
        request.file.display(),
        request.file_size,
        request.chunk_size,
        request.acl.as_ref(),
        request.meta.as_ref(),
        request.additional_checksum.as_ref()
    );

    let db_parts = request.sdb.db_parts()?;
    let db_uploaded = request.sdb.db_uploaded()?;

    // Reuse the saved upload ID when resuming; otherwise create a new
    // multipart upload and persist its ID so the transfer can be resumed.
    let upload_id = if let Some(uid) = request.sdb.upload_id()? {
        uid
    } else {
        let action = actions::CreateMultipartUpload::new(
            request.key,
            request.acl,
            request.meta,
            request.additional_checksum.clone(),
        );

        let response = action.request(request.s3).await?;

        db_parts.clear()?;
        request.sdb.save_upload_id(&response.upload_id)?;

        response.upload_id
    };

    log::debug!("upload_id: {}", &upload_id);

    // Seed the pending-parts tree on a fresh upload; a resumed upload already
    // has its remaining parts recorded.
    if db_parts.is_empty() {
        for (number, seek, chunk) in PartIterator::new(request.file_size, request.chunk_size) {
            request
                .sdb
                .create_part(number, seek, chunk, request.additional_checksum.clone())?;
        }
        db_parts.flush()?;
    }

    let pb = if request.quiet {
        Bar::default()
    } else {
        Bar::new(request.file_size)
    };

    // Account for parts already uploaded in a previous run.
    increment_progress_bar(&pb, db_uploaded.len() as u64 * request.chunk_size, None);

    let mut tasks = FuturesUnordered::new();

    log::info!("Max concurrent requests: {}", request.max_requests);

    for part in db_parts
        .iter()
        .values()
        .filter_map(Result::ok)
        .map(|p| from_bytes::<Part, RkyvError>(&p).map_err(anyhow::Error::from))
    {
        let part: Part = part?;

        log::info!("Task push part: {}", part.get_number());

        tasks.push(upload_part(
            request.s3,
            request.key,
            request.file,
            &upload_id,
            request.sdb,
            part,
            &request.globals,
        ));

        // Keep at most `max_requests` uploads in flight.
        await_tasks(
            &mut tasks,
            &pb,
            request.chunk_size,
            request.max_requests.into(),
        )
        .await?;
    }

    await_remaining_tasks(&mut tasks, &pb, request.chunk_size).await?;

    increment_progress_bar(&pb, 0, Some(true));

    if !db_parts.is_empty() {
        return Err(anyhow!("could not upload all parts"));
    }

    let uploaded = request.sdb.uploaded_parts()?;

    let action = actions::CompleteMultipartUpload::new(
        request.key,
        &upload_id,
        uploaded,
        request.additional_checksum,
    );
    let rs = action.request(request.s3).await?;

    db_uploaded.clear()?;

    request.sdb.save_etag(&rs.e_tag)?;

    log::info!("Upload finished, ETag: {}", rs.e_tag);

    Ok(format!("ETag: {}", rs.e_tag))
}

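/// Drains completed tasks until fewer than `max_requests` remain in flight,
/// advancing the progress bar once per finished part.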
async fn await_tasks<T>(
    tasks: &mut FuturesUnordered<T>,
    pb: &Bar,
    chunk_size: u64,
    max_requests: usize,
) -> Result<()>
where
    T: std::future::Future<Output = Result<usize>> + Send,
{
    log::debug!("Running tasks: {}", tasks.len());

    while tasks.len() >= max_requests {
        if let Some(r) = tasks.next().await {
            r.map_err(|e| anyhow!("{e}"))?;
            increment_progress_bar(pb, chunk_size, None);
        }
    }

    Ok(())
}

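/// Waits for every remaining upload task to finish, advancing the progress
/// bar once per completed part.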
async fn await_remaining_tasks<T>(
    tasks: &mut FuturesUnordered<T>,
    pb: &Bar,
    chunk_size: u64,
) -> Result<()>
where
    T: std::future::Future<Output = Result<usize>> + Send,
{
    log::debug!("Remaining tasks: {}", tasks.len());

    while let Some(r) = tasks.next().await {
        r.map_err(|e| anyhow!("{e}"))?;
        increment_progress_bar(pb, chunk_size, None);
    }

    Ok(())
}

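/// Advances the progress bar by `chunk_size` bytes; `Some(true)` also marks
/// the bar as finished. A no-op when the bar is disabled (quiet mode).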
fn increment_progress_bar(pb: &Bar, chunk_size: u64, finish: Option<bool>) {
    if let Some(pb) = pb.progress.as_ref() {
        pb.inc(chunk_size);

        if finish == Some(true) {
            pb.finish();
        }
    }
}

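/// Uploads a single part, retrying with exponential backoff, then atomically
/// moves the part from the pending tree to the uploaded tree and flushes the
/// database.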
async fn upload_part(
    s3: &S3,
    key: &str,
    file: &Path,
    uid: &str,
    db: &Db,
    part: Part,
    globals: &GlobalArgs,
) -> Result<usize> {
    let unprocessed = db.db_parts()?;
    let processed = db.db_uploaded()?;

    let mut additional_checksum = part.get_checksum();

    let part_number = part.get_number();

    let mut etag = String::new();

    // Retry with exponential backoff: 1s, 2s, 4s, ... between attempts.
    for attempt in 1..=globals.retries {
        let backoff_time = 2u64.pow(attempt - 1);
        if attempt > 1 {
            log::warn!("Error uploading part: {part_number}, retrying in {backoff_time} seconds");

            sleep(Duration::from_secs(backoff_time)).await;
        }

        match try_upload_part(UploadPartRequest {
            s3,
            key,
            file,
            part_number,
            uid,
            seek: part.get_seek(),
            chunk: part.get_chunk(),
            additional_checksum: &mut additional_checksum,
            globals: globals.clone(),
        })
        .await
        {
            Ok(e) => {
                etag = e;

                log::info!(
                    "Uploaded part: {}, etag: {}{}",
                    part_number,
                    etag,
                    additional_checksum
                        .as_ref()
                        .map(|c| format!(" additional_checksum: {}", c.checksum))
                        .unwrap_or_default()
                );

                break;
            }

            Err(e) => {
                log::error!(
                    "Error uploading part: {}, attempt {}/{} failed: {}",
                    part_number,
                    attempt,
                    globals.retries,
                    e
                );

                if attempt == globals.retries {
                    return Err(e);
                }
            }
        }
    }

    let part = part.set_etag(etag).set_checksum(additional_checksum);

    let rkyv_part = to_bytes::<RkyvError>(&part)?;

    // Atomically move the part from the pending tree to the uploaded tree, so
    // a crash never leaves it counted as both pending and done.
    (&unprocessed, &processed)
        .transaction(|(unprocessed, processed)| {
            unprocessed.remove(&part_number.to_be_bytes())?;
            processed.insert(&part_number.to_be_bytes(), rkyv_part.as_slice())?;
            Ok(())
        })
        .map_err(|err| match err {
            TransactionError::Abort(err) | TransactionError::Storage(err) => err,
        })?;

    db.flush()
}

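/// Performs one UploadPart request for the part's byte range of the file.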
async fn try_upload_part(request: UploadPartRequest<'_>) -> Result<String> {
    let action = actions::UploadPart::new(
        request.key,
        request.file,
        request.part_number,
        request.uid,
        request.seek,
        request.chunk,
        request.additional_checksum.as_mut(),
    );

    log::debug!("Uploading part: {}", request.part_number);

    action.request(request.s3, &request.globals).await
}