1use alloc::vec::Vec;
7use serde::{
8 Deserialize, Deserializer, Serialize, Serializer,
9 de::{self, Visitor},
10};
11use zeroize::{Zeroize, Zeroizing};
12
13#[cfg(feature = "std")]
14use secstr::SecVec;
15use subtle::ConstantTimeEq;
16
17#[derive(Clone)]
40pub struct SecBytes {
41 #[cfg(feature = "std")]
42 inner: SecVec<u8>,
43
44 #[cfg(not(feature = "std"))]
45 inner: Zeroizing<Vec<u8>>,
46}
47
48impl SecBytes {
49 pub fn new(data: Vec<u8>) -> Self {
53 #[cfg(feature = "std")]
54 return Self {
55 inner: SecVec::from(data),
56 };
57
58 #[cfg(not(feature = "std"))]
59 return Self {
60 inner: Zeroizing::new(data),
61 };
62 }
63
64 pub fn from_slice(data: &[u8]) -> Self {
68 Self::new(data.to_vec())
69 }
70
71 pub fn from_array<const N: usize>(data: [u8; N]) -> Self {
73 Self::new(data.to_vec())
74 }
75
76 pub fn as_slice(&self) -> &[u8] {
86 #[cfg(feature = "std")]
87 return self.inner.unsecure();
88
89 #[cfg(not(feature = "std"))]
90 return &self.inner;
91 }
92
93 pub fn len(&self) -> usize {
95 #[cfg(feature = "std")]
96 return self.inner.unsecure().len();
97
98 #[cfg(not(feature = "std"))]
99 return self.inner.len();
100 }
101
102 pub fn is_empty(&self) -> bool {
104 self.len() == 0
105 }
106
107 pub fn with_bytes<F, R>(&self, f: F) -> R
122 where
123 F: FnOnce(&[u8]) -> R,
124 {
125 let result = f(self.as_slice());
126
127 core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
130
131 result
132 }
133
134 pub fn to_array<const N: usize>(&self) -> Option<Zeroizing<[u8; N]>> {
143 if self.len() != N {
144 return None;
145 }
146
147 let mut arr = [0u8; N];
148 arr.copy_from_slice(self.as_slice());
149 Some(Zeroizing::new(arr))
150 }
151
152 pub fn to_vec(&self) -> Vec<u8> {
160 self.as_slice().to_vec()
161 }
162}
163
164impl Zeroize for SecBytes {
166 fn zeroize(&mut self) {
167 #[cfg(feature = "std")]
168 {
169 }
171
172 #[cfg(not(feature = "std"))]
173 {
174 self.inner.zeroize();
175 }
176 }
177}
178
179impl Drop for SecBytes {
180 fn drop(&mut self) {
181 self.zeroize();
182 }
183}
184
185impl core::fmt::Debug for SecBytes {
187 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
188 f.debug_struct("SecBytes")
189 .field("len", &self.len())
190 .field("data", &"<redacted>")
191 .finish()
192 }
193}
194
195impl PartialEq for SecBytes {
197 fn eq(&self, other: &Self) -> bool {
198 self.as_slice().ct_eq(other.as_slice()).into()
199 }
200}
201
202impl Eq for SecBytes {}
203
204impl Serialize for SecBytes {
206 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
207 where
208 S: Serializer,
209 {
210 serializer.serialize_bytes(self.as_slice())
211 }
212}
213
214impl<'de> Deserialize<'de> for SecBytes {
215 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
216 where
217 D: Deserializer<'de>,
218 {
219 struct SecBytesVisitor;
220
221 impl<'de> Visitor<'de> for SecBytesVisitor {
222 type Value = SecBytes;
223
224 fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
225 formatter.write_str("bytes")
226 }
227
228 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
229 where
230 E: de::Error,
231 {
232 Ok(SecBytes::new(v.to_vec()))
233 }
234 }
235
236 deserializer.deserialize_bytes(SecBytesVisitor)
237 }
238}
239
240#[derive(Clone, Serialize, Deserialize)]
244#[serde(try_from = "SecBytes", into = "SecBytes")]
245pub struct SecPinHash {
246 inner: SecBytes,
247}
248
249impl SecPinHash {
250 pub fn new(hash: [u8; 32]) -> Self {
252 Self {
253 inner: SecBytes::from_array(hash),
254 }
255 }
256
257 pub fn from_slice(slice: &[u8]) -> Self {
259 assert_eq!(slice.len(), 32, "PIN hash must be 32 bytes");
260 let mut hash = [0u8; 32];
261 hash.copy_from_slice(slice);
262 Self::new(hash)
263 }
264
265 pub fn verify(&self, expected: &[u8]) -> bool {
267 if expected.len() != 32 {
268 return false;
269 }
270 self.inner.as_slice().ct_eq(expected).into()
271 }
272
273 pub fn verify_first_16(&self, expected: &[u8]) -> bool {
275 if expected.len() < 16 {
276 return false;
277 }
278 self.inner.as_slice()[..16].ct_eq(&expected[..16]).into()
279 }
280
281 pub fn as_array(&self) -> [u8; 32] {
283 let mut arr = [0u8; 32];
284 arr.copy_from_slice(self.inner.as_slice());
285 arr
286 }
287
288 pub fn with_bytes<F, R>(&self, f: F) -> R
290 where
291 F: FnOnce(&[u8; 32]) -> R,
292 {
293 let arr = self.as_array();
294 f(&arr)
295 }
296}
297
298impl core::fmt::Debug for SecPinHash {
299 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
300 f.debug_struct("SecPinHash")
301 .field("data", &"<redacted>")
302 .finish()
303 }
304}
305
306impl PartialEq for SecPinHash {
307 fn eq(&self, other: &Self) -> bool {
308 self.inner == other.inner
309 }
310}
311
312impl Eq for SecPinHash {}
313
314impl TryFrom<SecBytes> for SecPinHash {
316 type Error = &'static str;
317
318 fn try_from(bytes: SecBytes) -> Result<Self, Self::Error> {
319 if bytes.len() != 32 {
320 return Err("PIN hash must be exactly 32 bytes");
321 }
322 Ok(Self { inner: bytes })
323 }
324}
325
326impl From<SecPinHash> for SecBytes {
327 fn from(hash: SecPinHash) -> Self {
328 hash.inner
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sec_pin_hash_new() {
        let hash = [0x42u8; 32];
        let sec_hash = SecPinHash::new(hash);
        assert!(sec_hash.verify(&hash));
    }

    #[test]
    fn test_sec_pin_hash_from_slice() {
        let hash = [0x42u8; 32];
        let sec_hash = SecPinHash::from_slice(&hash);
        assert!(sec_hash.verify(&hash));
    }

    #[test]
    #[should_panic(expected = "PIN hash must be 32 bytes")]
    fn test_sec_pin_hash_from_slice_wrong_length() {
        let _ = SecPinHash::from_slice(&[0x42u8; 16]);
    }

    #[test]
    fn test_sec_pin_hash_verify_wrong_length() {
        let hash = [0x42u8; 32];
        let sec_hash = SecPinHash::new(hash);
        assert!(!sec_hash.verify(&hash[..31]));
        assert!(!sec_hash.verify(&[0x42u8; 33]));
        assert!(!sec_hash.verify(&[]));
    }

    #[test]
    fn test_sec_pin_hash_verify_first_16() {
        let hash = [0x42u8; 32];
        let sec_hash = SecPinHash::new(hash);
        assert!(sec_hash.verify_first_16(&hash[..16]));
        assert!(!sec_hash.verify_first_16(&[0x43u8; 16]));
    }

    #[test]
    fn test_sec_pin_hash_verify_first_16_edges() {
        let sec_hash = SecPinHash::new([0x42u8; 32]);
        // Too short: rejected without comparing.
        assert!(!sec_hash.verify_first_16(&[0x42u8; 15]));
        // Bytes past index 15 are ignored.
        let mut longer = [0x42u8; 32];
        longer[20] = 0x00;
        assert!(sec_hash.verify_first_16(&longer));
    }

    #[test]
    fn test_sec_pin_hash_as_array_and_with_bytes() {
        let hash = [0x42u8; 32];
        let sec_hash = SecPinHash::new(hash);
        assert_eq!(sec_hash.as_array(), hash);
        assert!(sec_hash.with_bytes(|bytes| *bytes == hash));
    }

    #[test]
    fn test_sec_pin_hash_try_from() {
        assert!(SecPinHash::try_from(SecBytes::from_slice(&[0x42u8; 32])).is_ok());
        assert_eq!(
            SecPinHash::try_from(SecBytes::from_slice(&[0x42u8; 31])).err(),
            Some("PIN hash must be exactly 32 bytes")
        );
    }

    #[test]
    fn test_sec_pin_hash_into_sec_bytes() {
        let hash = [0x42u8; 32];
        let bytes: SecBytes = SecPinHash::new(hash).into();
        assert_eq!(bytes.as_slice(), &hash);
    }

    #[test]
    fn test_sec_pin_hash_debug() {
        let hash = [0x42u8; 32];
        let sec_hash = SecPinHash::new(hash);
        let debug_str = format!("{:?}", sec_hash);
        assert!(debug_str.contains("redacted"));
        assert!(!debug_str.contains("42"));
    }

    #[test]
    fn test_sec_pin_hash_equality() {
        let hash1 = [0x42u8; 32];
        let hash2 = [0x42u8; 32];
        let hash3 = [0x43u8; 32];
        let sec1 = SecPinHash::new(hash1);
        let sec2 = SecPinHash::new(hash2);
        let sec3 = SecPinHash::new(hash3);
        assert_eq!(sec1, sec2);
        assert_ne!(sec1, sec3);
    }

    #[test]
    fn test_sec_pin_hash_serialization() {
        let hash = [0x42u8; 32];
        let sec_hash = SecPinHash::new(hash);

        let buf = crate::cbor::encode(&sec_hash).unwrap();

        let deserialized: SecPinHash = crate::cbor::decode(&buf).unwrap();

        assert_eq!(sec_hash, deserialized);
        assert!(deserialized.verify(&hash));
    }

    #[test]
    fn test_sec_pin_hash_serialization_wrong_length() {
        let short_bytes = SecBytes::from_slice(&[0x42u8; 16]);
        let buf = crate::cbor::encode(&short_bytes).unwrap();

        let result: Result<SecPinHash, _> = crate::cbor::decode(&buf);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_and_access() {
        let data = vec![1, 2, 3, 4];
        let sec = SecBytes::new(data.clone());
        assert_eq!(sec.as_slice(), &[1, 2, 3, 4]);
        assert_eq!(sec.len(), 4);
        assert!(!sec.is_empty());
    }

    #[test]
    fn test_from_slice() {
        let data = &[1, 2, 3, 4];
        let sec = SecBytes::from_slice(data);
        assert_eq!(sec.as_slice(), data);
    }

    #[test]
    fn test_from_array() {
        let data = [1u8, 2, 3, 4];
        let sec = SecBytes::from_array(data);
        assert_eq!(sec.as_slice(), &data);
    }

    #[test]
    fn test_to_array() {
        let data = vec![1, 2, 3, 4];
        let sec = SecBytes::new(data);
        let arr = sec.to_array::<4>().unwrap();
        assert_eq!(*arr, [1, 2, 3, 4]);
    }

    #[test]
    fn test_to_array_wrong_size() {
        let data = vec![1, 2, 3, 4];
        let sec = SecBytes::new(data);
        assert!(sec.to_array::<5>().is_none());
        assert!(sec.to_array::<3>().is_none());
    }

    #[test]
    fn test_with_bytes() {
        let data = vec![1, 2, 3, 4];
        let sec = SecBytes::new(data);
        let sum = sec.with_bytes(|bytes| bytes.iter().sum::<u8>());
        assert_eq!(sum, 10);
    }

    #[test]
    fn test_clone() {
        let data = vec![1, 2, 3, 4];
        let sec1 = SecBytes::new(data);
        let sec2 = sec1.clone();
        assert_eq!(sec1.as_slice(), sec2.as_slice());
    }

    #[test]
    fn test_equality() {
        let sec1 = SecBytes::new(vec![1, 2, 3, 4]);
        let sec2 = SecBytes::new(vec![1, 2, 3, 4]);
        let sec3 = SecBytes::new(vec![1, 2, 3, 5]);
        assert_eq!(sec1, sec2);
        assert_ne!(sec1, sec3);
    }

    #[test]
    fn test_equality_different_lengths() {
        let short = SecBytes::new(vec![1, 2, 3]);
        let long = SecBytes::new(vec![1, 2, 3, 4]);
        assert_ne!(short, long);
    }

    #[test]
    fn test_debug() {
        let sec = SecBytes::new(vec![1, 2, 3, 4]);
        let debug_str = format!("{:?}", sec);
        assert!(debug_str.contains("SecBytes"));
        assert!(debug_str.contains("len"));
        assert!(debug_str.contains("redacted"));
        assert!(!debug_str.contains("1"));
    }

    #[test]
    fn test_empty() {
        let sec = SecBytes::new(vec![]);
        assert!(sec.is_empty());
        assert_eq!(sec.len(), 0);
    }
}