1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
mod dispatch;
mod split;

use bytes::Bytes;
use futures::{FutureExt, Stream, TryFutureExt, TryStreamExt};
use rusoto_core::{ByteStream, RusotoError};
use rusoto_s3::{
    AbortMultipartUploadRequest, CompleteMultipartUploadError, CompleteMultipartUploadOutput,
    CompleteMultipartUploadRequest, CompletedMultipartUpload, CompletedPart,
    CreateMultipartUploadError, CreateMultipartUploadRequest, UploadPartError, UploadPartRequest,
    S3,
};
use std::num::NonZeroUsize;
use std::ops::RangeInclusive;

/// Range of part sizes, in bytes, spanning `5 << 20` (5 MiB) through `5 << 30` (5 GiB)
/// inclusive.
///
/// Limits taken from the Amazon S3 multipart upload quick facts:
/// https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html
pub const PART_SIZE: RangeInclusive<usize> = RangeInclusive::new(5 << 20, 5 << 30);

/// Input for [`multipart_upload`]: the data to upload and the object it is stored as.
///
/// `B` is the type of the body; [`multipart_upload`] requires it to be a stream of
/// `Result<Bytes, E>` chunks.
#[derive(Debug)]
pub struct MultipartUploadRequest<B> {
    /// Source of the bytes to upload.
    pub body: B,
    /// Bucket name sent with every S3 request of the upload.
    pub bucket: String,
    /// Object key sent with every S3 request of the upload.
    pub key: String,
}

/// Value returned by [`multipart_upload`] on success: the response of the final
/// `CompleteMultipartUpload` request.
pub type MultipartUploadOutput = CompleteMultipartUploadOutput;

/// Uploads `input.body` to `input.bucket` / `input.key` as an S3 multipart upload.
///
/// The steps are:
///
/// 1. `CreateMultipartUpload` is called to obtain an upload ID.
/// 2. `body` and `part_size` are handed to `split::split`; every part it yields is turned into
///    an `UploadPart` request carrying the part's length and base64-encoded MD5.
/// 3. The resulting stream of requests is driven by `dispatch::dispatch_concurrent` together
///    with `concurrency_limit`.
/// 4. The completed parts are sorted by part number and `CompleteMultipartUpload` is called.
///
/// On success the output of `CompleteMultipartUpload` is returned.
///
/// # Errors
///
/// An error from `CreateMultipartUpload` is returned directly. Any error raised afterwards
/// (while uploading the parts or completing the upload) is returned after an
/// `AbortMultipartUpload` request has been issued for the upload; the outcome of that abort
/// request is ignored.
///
/// # Panics
///
/// Panics if the `CreateMultipartUpload` response does not contain an upload ID.
pub async fn multipart_upload<C, B, E>(
    client: &C,
    input: MultipartUploadRequest<B>,
    part_size: RangeInclusive<usize>,
    concurrency_limit: Option<NonZeroUsize>,
) -> Result<MultipartUploadOutput, E>
where
    C: S3,
    B: Stream<Item = Result<Bytes, E>>,
    E: From<RusotoError<CreateMultipartUploadError>>
        + From<RusotoError<UploadPartError>>
        + From<RusotoError<CompleteMultipartUploadError>>,
{
    let MultipartUploadRequest { body, bucket, key } = input;

    let output = client
        .create_multipart_upload(CreateMultipartUploadRequest {
            bucket: bucket.clone(),
            key: key.clone(),
            ..CreateMultipartUploadRequest::default()
        })
        .await?;
    // `E` offers no conversion for a missing upload ID, so this stays a panic, but with the
    // violated invariant spelled out instead of a bare `unwrap`.
    let upload_id = output
        .upload_id
        .as_ref()
        .expect("CreateMultipartUpload succeeded but the response contains no upload ID");

    // Lazily maps every part to the future of its `UploadPart` request; nothing is sent until
    // the stream is polled below.
    let stream = split::split(body, part_size).map_ok(|part| {
        client
            .upload_part(UploadPartRequest {
                body: Some(ByteStream::new(futures::stream::iter(
                    part.body.into_iter().map(Ok),
                ))),
                bucket: bucket.clone(),
                content_length: Some(part.content_length as _),
                content_md5: Some(base64::encode(part.content_md5)),
                key: key.clone(),
                part_number: part.part_number as _,
                upload_id: upload_id.clone(),
                ..UploadPartRequest::default()
            })
            .map_ok({
                // Captured up front because `part` is consumed by the request above.
                let part_number = part.part_number;
                move |output| CompletedPart {
                    e_tag: output.e_tag,
                    part_number: Some(part_number as _),
                }
            })
            .err_into()
    });

    (async {
        let mut completed_parts = dispatch::dispatch_concurrent(stream, concurrency_limit).await?;
        // The parts are put in part-number order before being listed in the completion request.
        completed_parts.sort_by_key(|completed_part| completed_part.part_number);

        let output = client
            .complete_multipart_upload(CompleteMultipartUploadRequest {
                bucket: bucket.clone(),
                key: key.clone(),
                multipart_upload: Some(CompletedMultipartUpload {
                    parts: Some(completed_parts),
                }),
                upload_id: upload_id.clone(),
                ..CompleteMultipartUploadRequest::default()
            })
            .await?;

        Ok(output)
    })
    .or_else(|e| {
        // Best-effort cleanup: abort the upload, discard the abort's own result, and hand the
        // original error back to the caller.
        client
            .abort_multipart_upload(AbortMultipartUploadRequest {
                bucket: bucket.clone(),
                key: key.clone(),
                upload_id: upload_id.clone(),
                ..AbortMultipartUploadRequest::default()
            })
            .map(|_| Err(e))
    })
    .await
}

#[cfg(test)]
mod tests;