//! Multipart request body extractor built on top of the `http_multipart` crate.

use crate::{
    body::{BodyStream, RequestBody},
    context::WebContext,
    error::{forward_blank_bad_request, Error},
    handler::FromRequest,
};

/// Alias for [`http_multipart::Multipart`] whose body type parameter `B` defaults to
/// [`RequestBody`].
///
/// It can be extracted from a [`WebContext`] through its [`FromRequest`] implementation.
pub type Multipart<B = RequestBody> = http_multipart::Multipart<B>;

impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for Multipart<B>
where
    B: BodyStream + Default,
{
    type Type<'b> = Multipart<B>;
    type Error = Error<C>;

    async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
        let body = ctx.take_body_ref();
        http_multipart::multipart(ctx.req(), body).map_err(Error::from_service)
    }
}

// NOTE(review): judging by the macro name, this presumably lets `MultipartError` be answered
// with a body-less 400 Bad Request response — confirm against the macro definition in
// `crate::error`.
forward_blank_bad_request!(http_multipart::MultipartError);

#[cfg(test)]
mod test {
    use core::pin::pin;

    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        handler::handler_service,
        http::{
            header::{HeaderValue, CONTENT_TYPE, TRANSFER_ENCODING},
            request, Method, RequestExt,
        },
        route::post,
        service::Service,
        test::collect_body,
        App,
    };

    use super::*;

    /// Test handler: reads two fields in order, checks their name and file name, and returns
    /// the bytes of both fields concatenated.
    async fn handler(multipart: Multipart) -> Vec<u8> {
        let mut form = pin!(multipart);

        let mut collected = Vec::new();

        // both fields are named "file"; they only differ in the expected file name.
        for expected_file in ["foo.txt", "bar.txt"] {
            let mut part = form.try_next().await.ok().unwrap().unwrap();

            assert_eq!(part.name().unwrap(), "file");
            assert_eq!(part.file_name().unwrap(), expected_file);

            while let Some(chunk) = part.try_next().await.ok().unwrap() {
                collected.extend_from_slice(chunk.as_ref());
            }
        }

        collected
    }

    /// Sends a two-field multipart body through an `App` routed to `handler` and checks that
    /// the response body is the concatenation of both field payloads.
    #[test]
    fn simple() {
        let payload: &'static [u8] = b"\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"foo.txt\"\r\n\
            Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
            test\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"bar.txt\"\r\n\
            Content-Type: text/plain\r\nContent-Length: 8\r\n\r\n\
            testdata\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n";

        let content_type = HeaderValue::from_static("multipart/mixed; boundary=abbc761f78ff4d7cb7573b5a23f96ef0");
        let transfer_encoding = HeaderValue::from_static("chunked");

        let req = request::Builder::default()
            .method(Method::POST)
            .header(CONTENT_TYPE, content_type)
            .header(TRANSFER_ENCODING, transfer_encoding)
            .body(RequestExt::default().map_body(|_: ()| payload.into()))
            .unwrap();

        // build the service first, then drive a single request through it.
        let service = App::new()
            .at("/", post(handler_service(handler)))
            .finish()
            .call(())
            .now_or_panic()
            .unwrap();

        let res = service.call(req).now_or_panic().unwrap();

        let bytes = collect_body(res.into_body()).now_or_panic().unwrap();

        assert_eq!(bytes, b"testtestdata");
    }
}