rustywallet_electrum/
pinning.rs1use std::collections::HashMap;
8use std::sync::Arc;
9
10use rustls::{
11 client::{ServerCertVerified, ServerCertVerifier},
12 Certificate, ServerName,
13};
14use sha2::{Digest, Sha256};
15
16use crate::error::{ElectrumError, Result};
17
18#[derive(Debug, Clone, PartialEq, Eq, Hash)]
20pub struct CertFingerprint([u8; 32]);
21
22impl CertFingerprint {
23 pub fn from_bytes(bytes: [u8; 32]) -> Self {
25 Self(bytes)
26 }
27
28 pub fn from_hex(hex: &str) -> Result<Self> {
30 let bytes = hex::decode(hex)
31 .map_err(|e| ElectrumError::TlsError(format!("Invalid hex fingerprint: {}", e)))?;
32
33 if bytes.len() != 32 {
34 return Err(ElectrumError::TlsError(format!(
35 "Fingerprint must be 32 bytes, got {}",
36 bytes.len()
37 )));
38 }
39
40 let mut arr = [0u8; 32];
41 arr.copy_from_slice(&bytes);
42 Ok(Self(arr))
43 }
44
45 pub fn from_certificate(cert_der: &[u8]) -> Self {
47 let mut hasher = Sha256::new();
48 hasher.update(cert_der);
49 let result = hasher.finalize();
50 let mut arr = [0u8; 32];
51 arr.copy_from_slice(&result);
52 Self(arr)
53 }
54
55 pub fn as_bytes(&self) -> &[u8; 32] {
57 &self.0
58 }
59
60 pub fn to_hex(&self) -> String {
62 hex::encode(self.0)
63 }
64}
65
66impl std::fmt::Display for CertFingerprint {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 write!(f, "{}", self.to_hex())
69 }
70}
71
72#[derive(Debug, Clone, Default)]
74pub struct CertPinStore {
75 pins: HashMap<String, Vec<CertFingerprint>>,
76}
77
78impl CertPinStore {
79 pub fn new() -> Self {
81 Self::default()
82 }
83
84 pub fn add_pin(&mut self, server: impl Into<String>, fingerprint: CertFingerprint) {
88 self.pins
89 .entry(server.into())
90 .or_default()
91 .push(fingerprint);
92 }
93
94 pub fn add_pin_hex(&mut self, server: impl Into<String>, hex: &str) -> Result<()> {
96 let fingerprint = CertFingerprint::from_hex(hex)?;
97 self.add_pin(server, fingerprint);
98 Ok(())
99 }
100
101 pub fn verify(&self, server: &str, cert_der: &[u8]) -> bool {
103 let fingerprint = CertFingerprint::from_certificate(cert_der);
104
105 if let Some(pins) = self.pins.get(server) {
106 pins.contains(&fingerprint)
107 } else {
108 true
110 }
111 }
112
113 pub fn get_pins(&self, server: &str) -> Option<&[CertFingerprint]> {
115 self.pins.get(server).map(|v| v.as_slice())
116 }
117
118 pub fn has_pins(&self, server: &str) -> bool {
120 self.pins.contains_key(server)
121 }
122
123 pub fn remove_pins(&mut self, server: &str) {
125 self.pins.remove(server);
126 }
127
128 pub fn server_count(&self) -> usize {
130 self.pins.len()
131 }
132}
133
134pub struct PinningVerifier {
136 pin_store: CertPinStore,
137 allow_unpinned: bool,
138}
139
140impl PinningVerifier {
141 pub fn new(pin_store: CertPinStore, allow_unpinned: bool) -> Self {
147 Self {
148 pin_store,
149 allow_unpinned,
150 }
151 }
152
153 pub fn strict(pin_store: CertPinStore) -> Self {
155 Self::new(pin_store, false)
156 }
157
158 pub fn permissive(pin_store: CertPinStore) -> Self {
160 Self::new(pin_store, true)
161 }
162}
163
164impl ServerCertVerifier for PinningVerifier {
165 fn verify_server_cert(
166 &self,
167 end_entity: &Certificate,
168 _intermediates: &[Certificate],
169 server_name: &ServerName,
170 _scts: &mut dyn Iterator<Item = &[u8]>,
171 _ocsp_response: &[u8],
172 _now: std::time::SystemTime,
173 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
174 let server = match server_name {
175 ServerName::DnsName(name) => name.as_ref().to_string(),
176 _ => return Err(rustls::Error::General("Invalid server name".into())),
177 };
178
179 if !self.pin_store.has_pins(&server) {
181 if self.allow_unpinned {
182 return Ok(ServerCertVerified::assertion());
183 } else {
184 return Err(rustls::Error::General(format!(
185 "No certificate pins for server: {}",
186 server
187 )));
188 }
189 }
190
191 if self.pin_store.verify(&server, &end_entity.0) {
193 Ok(ServerCertVerified::assertion())
194 } else {
195 Err(rustls::Error::General(format!(
196 "Certificate fingerprint mismatch for server: {}",
197 server
198 )))
199 }
200 }
201}
202
203pub struct PinningConfigBuilder {
205 pin_store: CertPinStore,
206 allow_unpinned: bool,
207}
208
209impl PinningConfigBuilder {
210 pub fn new() -> Self {
212 Self {
213 pin_store: CertPinStore::new(),
214 allow_unpinned: true,
215 }
216 }
217
218 pub fn pin(mut self, server: impl Into<String>, fingerprint: CertFingerprint) -> Self {
220 self.pin_store.add_pin(server, fingerprint);
221 self
222 }
223
224 pub fn pin_hex(mut self, server: impl Into<String>, hex: &str) -> Result<Self> {
226 self.pin_store.add_pin_hex(server, hex)?;
227 Ok(self)
228 }
229
230 pub fn allow_unpinned(mut self, allow: bool) -> Self {
232 self.allow_unpinned = allow;
233 self
234 }
235
236 pub fn build(self) -> rustls::ClientConfig {
238 let verifier = PinningVerifier::new(self.pin_store, self.allow_unpinned);
239
240 rustls::ClientConfig::builder()
241 .with_safe_defaults()
242 .with_custom_certificate_verifier(Arc::new(verifier))
243 .with_no_client_auth()
244 }
245
246 pub fn pin_store(&self) -> &CertPinStore {
248 &self.pin_store
249 }
250}
251
252impl Default for PinningConfigBuilder {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
258pub mod known_pins {
260 use super::CertFingerprint;
261
262 pub fn blockstream() -> Option<CertFingerprint> {
265 None
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// SHA-256 of the empty input: a convenient well-formed 64-digit fingerprint.
    const EMPTY_SHA256_HEX: &str =
        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

    const PINNED_HOST: &str = "server.example.com";

    #[test]
    fn test_fingerprint_from_hex() {
        let parsed = CertFingerprint::from_hex(EMPTY_SHA256_HEX).expect("valid 64-digit hex");
        assert_eq!(parsed.to_hex(), EMPTY_SHA256_HEX);
    }

    #[test]
    fn test_fingerprint_from_certificate() {
        let digest = CertFingerprint::from_certificate(b"test certificate data");
        assert_eq!(digest.as_bytes().len(), 32);
    }

    #[test]
    fn test_pin_store() {
        let zero_fp = CertFingerprint::from_bytes([0u8; 32]);

        let mut store = CertPinStore::new();
        store.add_pin(PINNED_HOST, zero_fp.clone());

        assert!(store.has_pins(PINNED_HOST));
        assert!(!store.has_pins("other.example.com"));

        match store.get_pins(PINNED_HOST) {
            Some(pins) => assert_eq!(pins, &[zero_fp][..]),
            None => panic!("expected pins for {}", PINNED_HOST),
        }
    }

    #[test]
    fn test_pin_store_verify() {
        let good_cert: &[u8] = b"test certificate";

        let mut store = CertPinStore::new();
        store.add_pin(PINNED_HOST, CertFingerprint::from_certificate(good_cert));

        // (host, certificate bytes, expected verdict)
        let cases: [(&str, &[u8], bool); 3] = [
            (PINNED_HOST, good_cert, true),
            (PINNED_HOST, &b"wrong cert"[..], false),
            ("unpinned.example.com", &b"any cert"[..], true),
        ];

        for &(host, cert, expected) in cases.iter() {
            assert_eq!(store.verify(host, cert), expected, "host = {}", host);
        }
    }
}