//! Multipart upload implementation (s3m/stream/upload_multipart.rs).
use crate::{
    cli::{globals::GlobalArgs, progressbar::Bar},
    s3::{S3, actions, checksum::Checksum},
    stream::{db::Db, iterator::PartIterator, part::Part},
};
use anyhow::{Result, anyhow};
use futures::stream::{FuturesUnordered, StreamExt};
use rkyv::{from_bytes, rancor::Error as RkyvError, to_bytes};
use sled::transaction::{TransactionError, Transactional};
use std::{collections::BTreeMap, path::Path};
use tokio::time::{Duration, sleep};

/// Parameters for a multipart upload of a single local file to S3.
pub struct MultipartUploadRequest<'a> {
    /// Configured S3 endpoint/credentials used for every request.
    pub s3: &'a S3,
    /// Object key under which the file will be stored.
    pub key: &'a str,
    /// Path of the local file to upload.
    pub file: &'a Path,
    /// Total size of the file in bytes.
    pub file_size: u64,
    /// Size of each upload part in bytes.
    pub chunk_size: u64,
    /// sled-backed database tracking pending/uploaded parts so an
    /// interrupted upload can be resumed.
    pub sdb: &'a Db,
    /// Optional ACL to apply when creating the multipart upload.
    pub acl: Option<String>,
    /// Optional object metadata to attach when creating the upload.
    pub meta: Option<BTreeMap<String, String>>,
    /// When true, the progress bar defaults to a bar-less `Bar`.
    pub quiet: bool,
    /// Optional additional checksum carried through part creation,
    /// part upload, and the final CompleteMultipartUpload.
    pub additional_checksum: Option<Checksum>,
    /// Maximum number of concurrent part-upload requests.
    pub max_requests: u8,
    /// Global CLI options (e.g. `retries`, used for per-part backoff).
    pub globals: GlobalArgs,
}

/// Internal per-part request handed to `try_upload_part`.
struct UploadPartRequest<'a> {
    s3: &'a S3,
    key: &'a str,
    file: &'a Path,
    // part number within the multipart upload
    part_number: u16,
    // upload ID obtained from CreateMultipartUpload
    uid: &'a str,
    // offset into `file` where this part starts
    seek: u64,
    // number of bytes in this part
    chunk: u64,
    // mutable so the UploadPart action can update the checksum in place
    additional_checksum: &'a mut Option<Checksum>,
    globals: GlobalArgs,
}

// https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingRESTAPImpUpload.html
// * Initiate Multipart Upload
// * Upload Part
// * Complete Multipart Upload
/// Upload `request.file` to S3 as a multipart upload, resuming a previous
/// unfinished upload when the local sled database already holds its state.
/// Returns a string of the form `"ETag: <etag>"` on success.
///
/// # Errors
/// Will return an error if the upload fails
pub async fn upload_multipart(request: MultipartUploadRequest<'_>) -> Result<String> {
    log::debug!(
        "Starting multi part upload:
        key: {}
        file: {}
        file_size: {}
        part size: {}
        acl: {:#?}
        meta: {:#?}
        additional checksum: {:#?}",
        request.key,
        request.file.display(),
        request.file_size,
        request.chunk_size,
        request.acl.as_ref(),
        request.meta.as_ref(),
        request.additional_checksum.as_ref()
    );

    // trees for keeping track of parts to upload
    let db_parts = request.sdb.db_parts()?;
    let db_uploaded = request.sdb.db_uploaded()?;

    // reuse a previously saved upload_id (resume) or start a new upload
    let upload_id = if let Some(uid) = request.sdb.upload_id()? {
        uid
    } else {
        // Initiate Multipart Upload - request an Upload ID
        let action = actions::CreateMultipartUpload::new(
            request.key,
            request.acl,
            request.meta,
            request.additional_checksum.clone(),
        );

        let response = action.request(request.s3).await?;

        db_parts.clear()?;
        // save the upload_id to resume if required
        request.sdb.save_upload_id(&response.upload_id)?;
        response.upload_id
    };

    log::debug!("upload_id: {}", &upload_id);

    // if db_parts is not empty it means that a previous upload did not finish successfully.
    // skip creating the parts again and try to re-upload the pending ones
    if db_parts.is_empty() {
        for (number, seek, chunk) in PartIterator::new(request.file_size, request.chunk_size) {
            request
                .sdb
                .create_part(number, seek, chunk, request.additional_checksum.clone())?;
        }
        db_parts.flush()?;
    }

    // Upload parts progress bar
    let pb = if request.quiet {
        Bar::default()
    } else {
        Bar::new(request.file_size)
    };

    // account for parts already uploaded in a previous run
    // NOTE(review): this assumes every uploaded part is exactly chunk_size
    // bytes; the last part may be smaller, so the bar can slightly overshoot.
    increment_progress_bar(&pb, db_uploaded.len() as u64 * request.chunk_size, None);

    let mut tasks = FuturesUnordered::new();

    log::info!("Max concurrent requests: {}", request.max_requests);

    // deserialize each pending part from sled and queue an upload task for it
    for part in db_parts
        .iter()
        .values()
        .filter_map(Result::ok)
        .map(|p| from_bytes::<Part, RkyvError>(&p).map_err(anyhow::Error::from))
    {
        let part: Part = part?;

        log::info!("Task push part: {}", part.get_number());

        // spawn task (upload part)
        tasks.push(upload_part(
            request.s3,
            request.key,
            request.file,
            &upload_id,
            request.sdb,
            part,
            &request.globals,
        ));

        // throttle: wait here while max_requests tasks are in flight
        await_tasks(
            &mut tasks,
            &pb,
            request.chunk_size,
            request.max_requests.into(),
        )
        .await?;
    }

    // wait for the remaining tasks
    await_remaining_tasks(&mut tasks, &pb, request.chunk_size).await?;

    // finish progress bar
    increment_progress_bar(&pb, 0, Some(true));

    // each successful upload_part removes its entry from db_parts, so
    // anything left here means at least one part failed all its retries
    if !db_parts.is_empty() {
        return Err(anyhow!("could not upload all parts"));
    }

    // Complete Multipart Upload
    let uploaded = request.sdb.uploaded_parts()?;
    let action = actions::CompleteMultipartUpload::new(
        request.key,
        &upload_id,
        uploaded,
        request.additional_checksum,
    );
    let rs = action.request(request.s3).await?;

    // cleanup uploads tree
    db_uploaded.clear()?;

    // save the returned Etag
    request.sdb.save_etag(&rs.e_tag)?;

    // upload finished
    log::info!("Upload finished, ETag: {}", rs.e_tag);

    Ok(format!("ETag: {}", rs.e_tag))
}

176// throttling tasks
177async fn await_tasks<T>(
178    tasks: &mut FuturesUnordered<T>,
179    pb: &Bar,
180    chunk_size: u64,
181    max_requests: usize,
182) -> Result<()>
183where
184    T: std::future::Future<Output = Result<usize>> + Send,
185{
186    log::debug!("Running tasks: {}", tasks.len());
187
188    // limit to num_cpus - 2 or 1
189    while tasks.len() >= max_requests {
190        if let Some(r) = tasks.next().await {
191            r.map_err(|e| anyhow!("{e}"))?;
192            increment_progress_bar(pb, chunk_size, None);
193        }
194    }
195
196    Ok(())
197}
198
199// consume remaining tasks
200async fn await_remaining_tasks<T>(
201    tasks: &mut FuturesUnordered<T>,
202    pb: &Bar,
203    chunk_size: u64,
204) -> Result<()>
205where
206    T: std::future::Future<Output = Result<usize>> + Send,
207{
208    log::debug!("Remaining tasks: {}", tasks.len());
209
210    while let Some(r) = tasks.next().await {
211        r.map_err(|e| anyhow!("{e}"))?;
212        increment_progress_bar(pb, chunk_size, None);
213    }
214    Ok(())
215}
216
217fn increment_progress_bar(pb: &Bar, chunk_size: u64, finish: Option<bool>) {
218    if let Some(pb) = pb.progress.as_ref() {
219        pb.inc(chunk_size);
220
221        if finish == Some(true) {
222            pb.finish();
223        }
224    }
225}
226
/// Upload one part (with retry and exponential backoff) and, on success,
/// atomically move its record from the pending-parts tree to the
/// uploaded-parts tree.
async fn upload_part(
    s3: &S3,
    key: &str,
    file: &Path,
    uid: &str,
    db: &Db,
    part: Part,
    globals: &GlobalArgs,
) -> Result<usize> {
    // pending parts and successfully uploaded parts trees
    let unprocessed = db.db_parts()?;
    let processed = db.db_uploaded()?;

    // passed by &mut below so the upload can update the checksum in place
    let mut additional_checksum = part.get_checksum();

    // do request to get the ETag and update the checksum if required
    let part_number = part.get_number();

    let mut etag: String = String::new();

    // Retry with exponential backoff: retry k sleeps 2^k seconds (2s, 4s, ...)
    // NOTE(review): if `globals.retries` is 0 this loop never runs and the
    // part would be recorded below with an empty ETag — confirm the CLI
    // guarantees retries >= 1.
    for attempt in 1..=globals.retries {
        let backoff_time = 2u64.pow(attempt - 1);
        if attempt > 1 {
            log::warn!("Error uploading part: {part_number}, retrying in {backoff_time} seconds");

            sleep(Duration::from_secs(backoff_time)).await;
        }

        match try_upload_part(UploadPartRequest {
            s3,
            key,
            file,
            part_number,
            uid,
            seek: part.get_seek(),
            chunk: part.get_chunk(),
            additional_checksum: &mut additional_checksum,
            globals: globals.clone(),
        })
        .await
        {
            Ok(e) => {
                etag = e;

                log::info!(
                    "Uploaded part: {}, etag: {}{}",
                    part_number,
                    etag,
                    additional_checksum
                        .as_ref()
                        .map(|c| format!(" additional_checksum: {}", c.checksum))
                        .unwrap_or_default()
                );

                break;
            }

            Err(e) => {
                log::error!(
                    "Error uploading part: {}, attempt {}/{} failed: {}",
                    part.get_number(),
                    attempt,
                    globals.retries,
                    e
                );

                // Increment attempt after an error
                if attempt == globals.retries {
                    // If it's the last attempt, return the error without incrementing attempt
                    return Err(e);
                }
            }
        }
    }

    // update part with the etag and checksum if any
    let part = part.set_etag(etag).set_checksum(additional_checksum);

    // serialize the part with rkyv for storage in sled
    let rkyv_part = to_bytes::<RkyvError>(&part)?;

    // move part to uploaded: the remove and insert succeed or fail together
    (&unprocessed, &processed)
        .transaction(|(unprocessed, processed)| {
            unprocessed.remove(&part_number.to_be_bytes())?;
            processed.insert(&part_number.to_be_bytes(), rkyv_part.as_slice())?;
            Ok(())
        })
        .map_err(|err| match err {
            TransactionError::Abort(err) | TransactionError::Storage(err) => err,
        })?;

    db.flush()
}

321async fn try_upload_part(request: UploadPartRequest<'_>) -> Result<String> {
322    let action = actions::UploadPart::new(
323        request.key,
324        request.file,
325        request.part_number,
326        request.uid,
327        request.seek,
328        request.chunk,
329        request.additional_checksum.as_mut(),
330    );
331
332    log::debug!("Uploading part: {}", request.part_number);
333
334    action.request(request.s3, &request.globals).await
335}