1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
mod aws;
mod specified;

pub use {aws::AwsPrincipal, specified::SpecifiedPrincipal};

use {
    crate::display_json,
    log::debug,
    serde::{
        de::{self, value::MapAccessDeserializer, Deserializer, MapAccess, Unexpected, Visitor},
        ser::Serializer,
        Deserialize, Serialize,
    },
    std::fmt::{Formatter, Result as FmtResult},
};

/// A principal: either the wildcard or an explicit specification.
///
/// The serde implementations in this module represent [`Principal::Any`] as the string `"*"`
/// and delegate [`Principal::Specified`] to [`SpecifiedPrincipal`]'s own (de)serialization.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Principal {
    /// The wildcard principal, written as the string `"*"`.
    Any,
    /// An explicit principal specification, (de)serialized via [`SpecifiedPrincipal`].
    Specified(SpecifiedPrincipal),
}

impl From<SpecifiedPrincipal> for Principal {
    fn from(sp: SpecifiedPrincipal) -> Self {
        Self::Specified(sp)
    }
}

/// Stateless serde visitor that produces a [`Principal`] from either a string or a map.
struct PrincipalVisitor {}

impl<'de> Visitor<'de> for PrincipalVisitor {
    type Value = Principal;

    /// Describes the accepted input shapes; serde embeds this text in its error messages.
    fn expecting(&self, f: &mut Formatter) -> FmtResult {
        f.write_str("map of principal types to values or \"*\"")
    }

    /// Handles string input.
    ///
    /// Only the wildcard `"*"` is valid and yields [`Principal::Any`]; any other string is
    /// rejected with an `invalid_value` error that reports the offending string.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        match v {
            "*" => Ok(Principal::Any),
            _ => Err(E::invalid_value(Unexpected::Str(v), &self)),
        }
    }

    /// Handles map input by delegating to [`SpecifiedPrincipal`]'s `Deserialize` impl.
    ///
    /// On failure the error is logged at debug level and then returned unchanged.
    fn visit_map<A>(self, access: A) -> Result<Self::Value, A::Error>
    where
        A: MapAccess<'de>,
    {
        // Re-expose the map access as a deserializer so the derived/own impl of
        // `SpecifiedPrincipal` can consume it directly.
        SpecifiedPrincipal::deserialize(MapAccessDeserializer::new(access))
            .map(Principal::Specified)
            .map_err(|e| {
                debug!("Failed to deserialize principal: {:?}", e);
                e
            })
    }
}

impl<'de> Deserialize<'de> for Principal {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_any(PrincipalVisitor {})
    }
}

impl Serialize for Principal {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            Self::Any => serializer.serialize_str("*"),
            Self::Specified(specified) => specified.serialize(serializer),
        }
    }
}

// `display_json!` is defined elsewhere in the crate. Judging by the test in this file, it
// presumably provides a `Display` impl that renders the value as pretty-printed JSON
// (NOTE(review): confirm against the macro definition).
display_json!(Principal);

#[cfg(test)]
mod tests {
    use {
        crate::{AwsPrincipal, Principal, SpecifiedPrincipal},
        indoc::indoc,
        pretty_assertions::assert_eq,
        std::{str::FromStr, sync::Arc},
    };

    /// Checks the `Display` output for both the wildcard and a specified AWS principal list.
    #[test_log::test]
    fn test_formatting() {
        let account = AwsPrincipal::from_str("123456789012").unwrap();
        let role = AwsPrincipal::from_str("arn:aws:iam::123456789012:role/test").unwrap();
        let aws_members = vec![Arc::new(account), Arc::new(role)];

        let wildcard = Principal::Any;
        let specified = SpecifiedPrincipal::builder().aws(aws_members).build().unwrap();
        let listed = Principal::Specified(specified);

        // The wildcard renders as a bare JSON string.
        assert_eq!(wildcard.to_string(), r#""*""#);

        // A specified principal renders as a pretty-printed JSON object.
        let expected = indoc! { r#"
            {
                "AWS": [
                    "123456789012",
                    "arn:aws:iam::123456789012:role/test"
                ]
            }"#};
        assert_eq!(listed.to_string(), expected);
    }
}