1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use bytes::Bytes;
use s2n_tls::error::Error;

/// Encoded key or certificate material.
pub(crate) enum Format {
    /// PEM-encoded (text) bytes.
    Pem(Bytes),
    /// DER-encoded (binary) bytes.
    Der(Bytes),
    /// No material at all.
    #[allow(dead_code)] // Only used if private key offloading supported
    None,
}

impl Format {
    /// Returns the PEM bytes, or `None` when this value is not [`Format::Pem`].
    pub fn as_pem(&self) -> Option<&[u8]> {
        match self {
            Format::Pem(pem) => Some(pem.as_ref()),
            Format::Der(_) | Format::None => None,
        }
    }

    /// Returns the DER bytes, or `None` when this value is not [`Format::Der`].
    #[allow(dead_code)] // remove if s2n-tls ever starts supporting DER certs
    pub fn as_der(&self) -> Option<&[u8]> {
        match self {
            Format::Der(der) => Some(der.as_ref()),
            Format::Pem(_) | Format::None => None,
        }
    }
}

/// Generates a key/certificate newtype plus the conversion trait used to build it.
///
/// `cert_type!(Name, IntoName, into_name)` expands to:
///
/// * `pub struct Name(pub(crate) Format)` holding the encoded material;
/// * `pub trait IntoName { fn into_name(self) -> Result<Name, Error>; }`;
/// * implementations of `IntoName` for:
///   - `Name` itself (identity);
///   - `String`, `&String`, `&str`, stored as `Format::Pem`;
///   - `Vec<u8>`, `&[u8]`, stored as `Format::Der`;
///   - `&Path`, `&PathBuf`, `PathBuf`, where the file is read from disk. A `.der`
///     extension (matched ignoring ASCII case) is read as binary DER. Any other
///     extension, or none, is read as UTF-8 PEM text.
///
/// Borrowed inputs are copied with `Bytes::copy_from_slice`; owned inputs are
/// converted with `Bytes::from`. Filesystem errors are mapped through
/// `Error::io_error`.
macro_rules! cert_type {
    ($name:ident, $trait:ident, $method:ident) => {
        #[doc = concat!(
            "Encoded ", stringify!($name), " material; build one with [`",
            stringify!($trait), "`]."
        )]
        pub struct $name(pub(crate) Format);

        #[doc = concat!("Conversion of a value into a [`", stringify!($name), "`].")]
        ///
        /// Text is treated as PEM and raw bytes as DER. Paths are read from disk:
        /// a `.der` extension (any ASCII case) selects DER, anything else PEM.
        pub trait $trait {
            #[doc = concat!("Converts `self` into a [`", stringify!($name), "`].")]
            ///
            /// # Errors
            ///
            /// Path-based implementations fail if the file cannot be read (or, for
            /// PEM, is not valid UTF-8). The in-memory implementations never fail.
            fn $method(self) -> Result<$name, Error>;
        }

        impl $trait for $name {
            fn $method(self) -> Result<$name, Error> {
                Ok(self)
            }
        }

        impl $trait for String {
            fn $method(self) -> Result<$name, Error> {
                Ok($name(Format::Pem(Bytes::from(self.into_bytes()))))
            }
        }

        impl $trait for &str {
            fn $method(self) -> Result<$name, Error> {
                Ok($name(Format::Pem(Bytes::copy_from_slice(self.as_bytes()))))
            }
        }

        // Generic `T: Trait` bounds do not apply deref coercion, so `&String` needs
        // its own impl; it simply forwards to the `&str` one.
        impl $trait for &String {
            fn $method(self) -> Result<$name, Error> {
                self.as_str().$method()
            }
        }

        impl $trait for Vec<u8> {
            fn $method(self) -> Result<$name, Error> {
                Ok($name(Format::Der(Bytes::from(self))))
            }
        }

        impl $trait for &[u8] {
            fn $method(self) -> Result<$name, Error> {
                Ok($name(Format::Der(Bytes::copy_from_slice(self))))
            }
        }

        impl $trait for &std::path::Path {
            fn $method(self) -> Result<$name, Error> {
                // DER is binary and must be read as raw bytes. Everything else is
                // assumed to be PEM text. Ignoring ASCII case keeps `cert.DER` from
                // being decoded as UTF-8 and failing.
                let is_der = matches!(
                    self.extension(),
                    Some(ext) if ext.eq_ignore_ascii_case("der")
                );
                if is_der {
                    std::fs::read(self).map_err(Error::io_error)?.$method()
                } else {
                    std::fs::read_to_string(self)
                        .map_err(Error::io_error)?
                        .$method()
                }
            }
        }

        impl $trait for &std::path::PathBuf {
            fn $method(self) -> Result<$name, Error> {
                self.as_path().$method()
            }
        }

        impl $trait for std::path::PathBuf {
            fn $method(self) -> Result<$name, Error> {
                self.as_path().$method()
            }
        }
    };
}

cert_type!(Certificate, IntoCertificate, into_certificate);
cert_type!(PrivateKey, IntoPrivateKey, into_private_key);

/// A [`PrivateKey`] that carries no key material (it wraps `Format::None`).
///
/// Only compiled for tests or with the `unstable_private_key` feature.
/// Presumably used to signal that private-key operations are offloaded rather
/// than performed with local key bytes; confirm against callers.
#[cfg(any(test, feature = "unstable_private_key"))]
pub const OFFLOAD_PRIVATE_KEY: PrivateKey = PrivateKey(Format::None);