svc_authn/
account.rs

1use std::fmt;
2use std::str::FromStr;
3
4use crate::Authenticable;
5use crate::Error;
6
7////////////////////////////////////////////////////////////////////////////////
8
/// Identifier of an account: a `label` paired with the `audience` it belongs to.
///
/// The textual form is `<label>.<audience>` (see the `Display` and `FromStr`
/// implementations in this file).
///
/// With the `diesel` feature enabled the type is mapped to the SQL type
/// `sql::Account_id`; with the `sqlx` feature it is mapped to the database
/// type named `account_id`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "diesel",
    derive(FromSqlRow, AsExpression),
    sql_type = "sql::Account_id"
)]
#[cfg_attr(
    feature = "sqlx",
    derive(sqlx::Type),
    sqlx(type_name = "account_id")
)]
pub struct AccountId {
    // NOTE(review): the derived sqlx mapping presumably relies on this field
    // order (label first, audience second) -- confirm before reordering.
    /// Name of the account within its audience.
    label: String,
    /// Audience the account belongs to.
    audience: String,
}
18
19impl AccountId {
20    pub fn new(label: &str, audience: &str) -> Self {
21        Self {
22            label: label.to_owned(),
23            audience: audience.to_owned(),
24        }
25    }
26
27    pub fn label(&self) -> &str {
28        &self.label
29    }
30
31    pub fn audience(&self) -> &str {
32        &self.audience
33    }
34}
35
36impl fmt::Display for AccountId {
37    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
38        write!(fmt, "{}.{}", self.label, self.audience)
39    }
40}
41
42impl FromStr for AccountId {
43    type Err = Error;
44
45    fn from_str(val: &str) -> Result<Self, Self::Err> {
46        let parts: Vec<&str> = val.splitn(2, '.').collect();
47        match parts[..] {
48            [label, audience] => Ok(Self::new(label, audience)),
49            _ => Err(Error::new(&format!(
50                "invalid value for the application name: {}",
51                val
52            ))),
53        }
54    }
55}
56
57impl Authenticable for AccountId {
58    fn as_account_id(&self) -> &Self {
59        self
60    }
61}
62
63////////////////////////////////////////////////////////////////////////////////
64
#[cfg(feature = "jose")]
pub mod jose {
    use super::AccountId;
    use crate::jose::Claims;

    /// Builds an account identifier out of token claims: the claims' subject
    /// becomes the label and the claims' audience becomes the audience.
    impl From<Claims<String>> for AccountId {
        fn from(value: Claims<String>) -> Self {
            let label = value.subject();
            let audience = value.audience();

            AccountId::new(label, audience)
        }
    }
}
76
#[cfg(feature = "diesel")]
pub mod sql {
    use diesel::deserialize::{self, FromSql};
    use diesel::pg::Pg;
    use diesel::serialize::{self, Output, ToSql, WriteTuple};
    use diesel::sql_types::{Record, Text};
    use std::io::Write;

    use super::AccountId;

    /// Diesel marker type for the Postgres type named `account_id`.
    #[derive(SqlType, QueryId)]
    #[postgres(type_name = "account_id")]
    // NOTE(review): the non-camel-case name presumably mirrors the SQL type
    // name; it is referenced as `sql::Account_id` from the `AccountId` struct.
    #[allow(non_camel_case_types)]
    pub struct Account_id;

    /// Serializes an `AccountId` as a `(Text, Text)` tuple: label first,
    /// audience second.
    impl ToSql<Account_id, Pg> for AccountId {
        fn to_sql<W>(&self, out: &mut Output<W, Pg>) -> serialize::Result
        where
            W: Write,
        {
            let fields = (&self.label, &self.audience);

            WriteTuple::<(Text, Text)>::write_tuple(&fields, out)
        }
    }

    /// Deserializes an `AccountId` from a `(Text, Text)` record: label first,
    /// audience second.
    impl FromSql<Account_id, Pg> for AccountId {
        fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> {
            let (label, audience) =
                <(String, String) as FromSql<Record<(Text, Text)>, Pg>>::from_sql(bytes)?;

            Ok(Self::new(&label, &audience))
        }
    }
}
106
107////////////////////////////////////////////////////////////////////////////////
108
109mod serde {
110    use serde::{de, ser};
111    use std::fmt;
112
113    use super::AccountId;
114
115    impl ser::Serialize for AccountId {
116        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
117        where
118            S: ser::Serializer,
119        {
120            serializer.serialize_str(&self.to_string())
121        }
122    }
123
124    impl<'de> de::Deserialize<'de> for AccountId {
125        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
126        where
127            D: de::Deserializer<'de>,
128        {
129            struct AccountIdVisitor;
130
131            impl<'de> de::Visitor<'de> for AccountIdVisitor {
132                type Value = AccountId;
133
134                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
135                    formatter.write_str("struct AccountId")
136                }
137
138                fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
139                where
140                    E: de::Error,
141                {
142                    use std::str::FromStr;
143
144                    AccountId::from_str(v)
145                        .map_err(|_| de::Error::invalid_value(de::Unexpected::Str(v), &self))
146                }
147            }
148
149            deserializer.deserialize_str(AccountIdVisitor)
150        }
151    }
152}