1use base64::{Engine as _, engine::general_purpose};
2use ed25519_dalek::{VerifyingKey, ed25519};
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use std::collections::HashMap;
6
7#[derive(Debug)]
10pub enum KeyringError {
11 UnsupportedAlgorithm,
13 ParsingError(base64::DecodeError),
16 ConversionError(ed25519::Error),
19 KeyAlreadyExists,
21}
22
/// Raw public-key bytes as stored in a key ring entry.
pub type PublicKey = std::vec::Vec<u8>;
25
/// Signature algorithms that a key ring entry can be associated with.
///
/// The `Display` form of each variant is its lowercase identifier
/// (`ed25519`, `rsa-pss-sha512`, `rsa-v1_5-sha256`, ...).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Algorithm {
    /// `ed25519`
    Ed25519,
    /// `rsa-pss-sha512`
    RsaPssSha512,
    /// `rsa-v1_5-sha256`
    RsaV1_5Sha256,
    /// `hmac-sha256`
    HmacSha256,
    /// `ecdsa-p256-sha256`
    EcdsaP256Sha256,
    /// `ecdsa-p384-sha384`
    EcdsaP384Sha384,
}

impl std::fmt::Display for Algorithm {
    /// Writes the algorithm's lowercase identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Algorithm::Ed25519 => "ed25519",
            Algorithm::RsaPssSha512 => "rsa-pss-sha512",
            // Must not reuse the RSA-PSS name: the two algorithms would become
            // indistinguishable. RFC 9421 registers this one as `rsa-v1_5-sha256`.
            Algorithm::RsaV1_5Sha256 => "rsa-v1_5-sha256",
            Algorithm::HmacSha256 => "hmac-sha256",
            Algorithm::EcdsaP256Sha256 => "ecdsa-p256-sha256",
            Algorithm::EcdsaP384Sha384 => "ecdsa-p384-sha384",
        };
        f.write_str(name)
    }
}
56
57#[derive(Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
60#[serde(tag = "kty")]
61pub enum Thumbprintable {
62 EC {
64 crv: String,
66 x: String,
68 y: String,
70 },
71 OKP {
73 crv: String,
75 x: String,
77 },
78 RSA {
80 e: String,
82 n: String,
84 },
85 #[serde(rename = "oct")]
87 OCT {
88 k: String,
90 },
91}
92
93#[derive(Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
95pub struct JSONWebKeySet {
96 pub keys: Vec<Thumbprintable>,
98}
99
100impl Thumbprintable {
101 pub fn b64_thumbprint(&self) -> String {
103 general_purpose::URL_SAFE_NO_PAD.encode(Sha256::digest(match self {
104 Thumbprintable::EC { crv, x, y } => {
105 format!("{{\"crv\":\"{crv}\",\"kty\":\"EC\",\"x\":\"{x}\",\"y\":\"{y}\"}}")
106 }
107 Thumbprintable::OKP { crv, x } => {
108 format!("{{\"crv\":\"{crv}\",\"kty\":\"OKP\",\"x\":\"{x}\"}}")
109 }
110 Thumbprintable::RSA { e, n } => {
111 format!("{{\"e\":\"{e}\",\"kty\":\"RSA\",\"n\":\"{n}\"}}")
112 }
113 Thumbprintable::OCT { k } => format!("{{\"k\":\"{k}\",\"kty\":\"oct\"}}"),
114 }))
115 }
116
117 pub fn public_key(&self) -> Result<Vec<u8>, KeyringError> {
125 match self {
126 Thumbprintable::OKP { crv, x } => match crv.as_str() {
127 "Ed25519" => {
128 let decoded = general_purpose::URL_SAFE_NO_PAD
129 .decode(x)
130 .map_err(KeyringError::ParsingError)?;
131 VerifyingKey::try_from(decoded.as_slice())
132 .map(|key| key.to_bytes().to_vec())
133 .map_err(KeyringError::ConversionError)
134 }
135 _ => Err(KeyringError::UnsupportedAlgorithm),
136 },
137 _ => Err(KeyringError::UnsupportedAlgorithm),
138 }
139 }
140
141 pub fn algorithm(&self) -> Result<Algorithm, KeyringError> {
147 match self {
148 Thumbprintable::OKP { crv, .. } => match crv.as_str() {
149 "Ed25519" => Ok(Algorithm::Ed25519),
150 _ => Err(KeyringError::UnsupportedAlgorithm),
151 },
152 _ => Err(KeyringError::UnsupportedAlgorithm),
153 }
154 }
155}
156
157#[derive(Default, Debug, Clone)]
160pub struct KeyRing {
161 ring: HashMap<String, (Algorithm, PublicKey)>,
162}
163
164impl FromIterator<(String, (Algorithm, PublicKey))> for KeyRing {
165 fn from_iter<T: IntoIterator<Item = (String, (Algorithm, PublicKey))>>(iter: T) -> KeyRing {
166 KeyRing {
167 ring: HashMap::from_iter(iter),
168 }
169 }
170}
171
172impl KeyRing {
173 pub fn import_raw(
176 &mut self,
177 identifier: String,
178 algorithm: Algorithm,
179 public_key: Vec<u8>,
180 ) -> bool {
181 !self.ring.contains_key(&identifier)
182 && self
183 .ring
184 .insert(identifier, (algorithm, public_key))
185 .is_none()
186 }
187
188 pub fn rename_key(&mut self, old_identifier: String, new_identifier: String) -> bool {
191 match self.ring.remove(&old_identifier) {
192 Some(value) => self.ring.insert(new_identifier, value).is_none(),
193 None => false,
194 }
195 }
196
197 pub fn get(&self, identifier: &String) -> Option<&(Algorithm, Vec<u8>)> {
199 self.ring.get(identifier)
200 }
201
202 pub fn try_import_jwk(&mut self, jwk: &Thumbprintable) -> Result<(), KeyringError> {
209 let thumbprint = jwk.b64_thumbprint();
210 let public_key = jwk.public_key()?;
211 let algorithm = jwk.algorithm()?;
212 if !self.import_raw(thumbprint, algorithm, public_key) {
213 return Err(KeyringError::KeyAlreadyExists);
214 }
215 Ok(())
216 }
217
218 pub fn import_jwks(&mut self, jwks: JSONWebKeySet) -> Vec<Option<KeyringError>> {
221 jwks.keys
222 .iter()
223 .map(|jwk| self.try_import_jwk(jwk).err())
224 .collect::<Vec<_>>()
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Identifier the key in `ED25519_JWKS` is stored under after import.
    const ED25519_THUMBPRINT: &str = "poqkLGiymh_W0uP6PZFw-dvez3QJT5SolqXBCW38r0U";

    /// A one-key JWKS that also carries `kid` and `d` members.
    const ED25519_JWKS: &str = r#"{"keys":[{"kty":"OKP","crv":"Ed25519","kid":"test-key-ed25519","d":"n4Ni-HpISpVObnQMW0wOhCKROaIKqKtW_2ZYb2p9KcU","x":"JrQLj5P_89iXES9-vFgrIy29clF9CC_oPPsw3c5D0bs"}]}"#;

    #[test]
    fn test_importing_ed25519_key_from_jwks() {
        let mut keyring = KeyRing::default();
        let jwks: JSONWebKeySet = serde_json::from_str(ED25519_JWKS).unwrap();

        // Check the length explicitly: looping over the results would pass vacuously
        // if no result were returned at all.
        let results = keyring.import_jwks(jwks.clone());
        assert_eq!(results.len(), 1);
        assert!(results[0].is_none(), "unexpected import error: {:?}", results[0]);
        assert!(keyring.get(&String::from(ED25519_THUMBPRINT)).is_some());

        // Importing the same key again must be reported as a duplicate.
        assert!(matches!(
            keyring.import_jwks(jwks).as_slice(),
            [Some(KeyringError::KeyAlreadyExists)]
        ));

        assert!(keyring.rename_key(
            String::from(ED25519_THUMBPRINT),
            String::from("test-key-ed25519")
        ));
        assert!(keyring.get(&String::from("test-key-ed25519")).is_some());
        assert!(keyring.get(&String::from(ED25519_THUMBPRINT)).is_none());
    }
}
251}