1#![allow(clippy::len_zero)]
2
3#[macro_use]
4extern crate lazy_static;
5
6extern crate anyhow;
7
8pub mod authenticator;
9mod der;
10mod jwt;
11mod pki;
12pub mod spiffe;
13mod spire;
14mod verifier;
15mod workload;
16
17pub use authenticator::SpiffeIdAuthorizer;
18pub use jwt::{JwtBundle, JwtKey};
19pub use spiffe::{SpiffeID, SpiffeIDMatcher};
20
21use crate::der::parse_der_cert_chain;
22use anyhow::*;
23use arc_swap::ArcSwap;
24use rustls::ClientConfig;
25use rustls::{sign::CertifiedKey, PrivateKey};
26use rustls::{Certificate, RootCertStore};
27use std::collections::{BTreeMap, BTreeSet};
28use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
29use std::sync::Arc;
30use tokio::sync::watch::{channel, Receiver, Sender};
31use verifier::DynamicLoadedCertResolverVerifier;
32
33pub struct Identity {
34 pub cert_key: Arc<CertifiedKey>,
35 pub raw_key: Vec<u8>,
36 pub raw_bundle: Vec<Vec<u8>>,
37 pub bundle: Arc<RootCertStore>,
38}
39
40impl Identity {
41 pub fn from_raw(bundle: &[u8], certs: &[u8], key: &[u8]) -> Result<Identity> {
42 let certs = parse_der_cert_chain(certs)?;
43 let key = rustls::PrivateKey(key.to_vec());
44 let bundle = parse_der_cert_chain(bundle)?;
45 Self::from_rustls(bundle, certs, key)
46 }
47
48 pub fn from_rustls(
49 bundle: Vec<Certificate>,
50 certs: Vec<Certificate>,
51 key: PrivateKey,
52 ) -> Result<Identity> {
53 let cert_key = CertifiedKey::new(
54 certs,
55 rustls::sign::any_supported_type(&key)
56 .map_err(|_| anyhow!("unsupported private key type"))?,
57 );
58 let mut root_store = RootCertStore { roots: vec![] };
59 for bundle_cert in bundle.iter() {
60 root_store.add(bundle_cert)?;
61 }
62 Ok(Identity {
63 cert_key: Arc::new(cert_key),
64 raw_key: key.0,
65 raw_bundle: bundle.into_iter().map(|x| x.0).collect(),
66 bundle: Arc::new(root_store),
67 })
68 }
69}
70
71#[derive(Eq, Clone)]
72pub struct CrlEntry(pub Certificate);
73
74impl std::hash::Hash for CrlEntry {
75 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
76 (self.0).0.hash(state);
77 }
78}
79
80impl PartialEq for CrlEntry {
81 fn eq(&self, other: &CrlEntry) -> bool {
82 self.0 == other.0
83 }
84}
85
86impl PartialOrd for CrlEntry {
87 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
88 self.0 .0.partial_cmp(&other.0 .0)
89 }
90}
91
92impl Ord for CrlEntry {
93 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
94 self.0 .0.cmp(&other.0 .0)
95 }
96}
97
/// Version number of the currently loaded identities.
///
/// `wait_for_identity_update` reads this when no explicit version is supplied.
pub static CURRENT_IDENTITY_VERSION: AtomicU64 = AtomicU64::new(0);
/// Byte value of the character used as the SPIFFE ID separator.
///
/// Defaults to `:`. `set_spiffe_separator` changes it.
pub static SPIFFEID_SEPARATOR: AtomicU8 = AtomicU8::new(b':');
100
101lazy_static! {
102 pub(crate) static ref IDENTITY_UPDATE_WATCHER: (Sender<u64>, Receiver<u64>) = channel(0);
103 pub static ref IDENTITIES: ArcSwap<BTreeMap<SpiffeID, Arc<Identity>>> = ArcSwap::new(Arc::new(BTreeMap::new()));
104 pub static ref JWT_BUNDLES: ArcSwap<BTreeMap<String, Arc<JwtBundle>>> = ArcSwap::new(Arc::new(BTreeMap::new()));
105 pub static ref CERTIFICATE_REVOKATION_LIST: ArcSwap<BTreeSet<CrlEntry>> = ArcSwap::new(Arc::new(BTreeSet::new()));
107
108 pub static ref VALID_SPIFFEID_SEPARATORS: Vec<char> = vec![':', '-', '.', '_'];
109}
110
111pub async fn wait_for_identity_update(current_version: Option<u64>) -> Option<u64> {
112 let current_version =
113 current_version.unwrap_or_else(|| CURRENT_IDENTITY_VERSION.load(Ordering::SeqCst));
114 let mut receiver = IDENTITY_UPDATE_WATCHER.1.clone();
115 loop {
116 receiver.changed().await.ok()?;
117 let latest_version = *receiver.borrow();
118 if latest_version <= current_version {
119 continue;
120 }
121 return Some(latest_version);
122 }
123}
124
125pub fn init() {
126 tokio::spawn(spire::spire_manager());
127}
128
129pub fn init_mock(identities: BTreeMap<SpiffeID, Arc<Identity>>, crl: Vec<Certificate>) {
130 IDENTITIES.store(Arc::new(identities));
131 CERTIFICATE_REVOKATION_LIST.store(Arc::new(crl.into_iter().map(CrlEntry).collect()));
132}
133
134pub fn make_client_config(
135 identity: Option<SpiffeID>,
136 protocols: &[Vec<u8>],
137 authorizer: Box<dyn SpiffeIdAuthorizer>,
138 require_server_auth: bool,
139) -> rustls::ClientConfig {
140 let dyn_resolver_verifier = Arc::new(DynamicLoadedCertResolverVerifier {
141 identity,
142 authorizer,
143 require_client_auth: require_server_auth,
144 });
145
146 let mut config = ClientConfig::builder()
148 .with_cipher_suites(rustls::ALL_CIPHER_SUITES)
149 .with_safe_default_kx_groups()
150 .with_safe_default_protocol_versions()
151 .expect("create client config fail")
152 .with_custom_certificate_verifier(dyn_resolver_verifier.clone())
153 .with_no_client_auth();
154
155 config.alpn_protocols = protocols.to_vec();
156 config.key_log = Arc::new(rustls::KeyLogFile::new());
157 config.client_auth_cert_resolver = dyn_resolver_verifier;
158
159 config
160}
161
162pub fn make_server_config(
163 identity: Option<SpiffeID>,
164 protocols: &[Vec<u8>],
165 authorizer: Box<dyn SpiffeIdAuthorizer>,
166 require_client_auth: bool,
167) -> rustls::ServerConfig {
168 let dyn_resolver_verifier = Arc::new(DynamicLoadedCertResolverVerifier {
169 identity,
170 authorizer,
171 require_client_auth,
172 });
173
174 let mut config = rustls::ServerConfig::builder()
176 .with_safe_default_cipher_suites()
177 .with_safe_default_kx_groups()
178 .with_safe_default_protocol_versions()
179 .expect("create server config failed")
180 .with_client_cert_verifier(dyn_resolver_verifier.clone())
181 .with_cert_resolver(dyn_resolver_verifier.clone());
182
183 config.key_log = Arc::new(rustls::KeyLogFile::new());
184
185 config.cert_resolver = dyn_resolver_verifier;
186
187 config.alpn_protocols = Vec::from(protocols);
188
189 config
190}
191
192pub fn set_spiffe_separator(
193 separator: &str
194) -> Result<()>{
195
196 if separator.len() != 1 {
197 return Err(anyhow!("invalid spiffe separator length: {}", separator.len()));
198 }
199
200 match &separator.chars().next() {
201 None => return Err(anyhow!("empty spiffe separator")),
202 Some(c) => {
203 if !VALID_SPIFFEID_SEPARATORS.contains(c){
204 return Err(anyhow!("invalid spiffe separator char: {}", separator));
205 }
206 SPIFFEID_SEPARATOR.store(*c as u8, Ordering::SeqCst);
207
208 },
209 };
210
211 Ok(())
212}