rustls_mbedcrypto_provider/
hash.rs1use crate::log::error;
9use alloc::boxed::Box;
10use alloc::vec;
11use alloc::vec::Vec;
12use rustls::crypto::hash::{self, HashAlgorithm};
13use std::sync::Mutex;
14
15pub static SHA256: Hash = Hash(&MBED_SHA_256);
17pub static SHA384: Hash = Hash(&MBED_SHA_384);
19
20pub struct Hash(&'static Algorithm);
22
23#[derive(Clone, Debug, PartialEq)]
25pub struct Algorithm {
26 pub(crate) hash_algorithm: HashAlgorithm,
27 pub(crate) hash_type: mbedtls::hash::Type,
28 pub(crate) output_len: usize,
29}
30
31pub static MBED_SHA_256: Algorithm = Algorithm {
35 hash_algorithm: HashAlgorithm::SHA256,
36 hash_type: mbedtls::hash::Type::Sha256,
37 output_len: 256 / 8,
38};
39
40pub static MBED_SHA_384: Algorithm = Algorithm {
44 hash_algorithm: HashAlgorithm::SHA384,
45 hash_type: mbedtls::hash::Type::Sha384,
46 output_len: 384 / 8,
47};
48
49pub static MBED_SHA_512: Algorithm = Algorithm {
53 hash_algorithm: HashAlgorithm::SHA512,
54 hash_type: mbedtls::hash::Type::Sha512,
55 output_len: 512 / 8,
56};
57
58impl hash::Hash for Hash {
59 fn start(&self) -> Box<dyn hash::Context> {
60 Box::new(HashContext(MbedHashContext::new(self.0)))
61 }
62
63 fn hash(&self, data: &[u8]) -> hash::Output {
64 hash::Output::new(&hash(self.0, data))
65 }
66
67 fn algorithm(&self) -> HashAlgorithm {
68 self.0.hash_algorithm
69 }
70
71 fn output_len(&self) -> usize {
72 self.0.output_len
73 }
74}
75
76struct HashContext(MbedHashContext);
77
78impl hash::Context for HashContext {
79 fn fork_finish(&self) -> hash::Output {
80 hash::Output::new(&self.0.clone().finalize())
81 }
82
83 fn fork(&self) -> Box<dyn hash::Context> {
84 Box::new(Self(self.0.clone()))
85 }
86
87 fn finish(self: Box<Self>) -> hash::Output {
88 hash::Output::new(&self.0.finalize())
89 }
90
91 fn update(&mut self, data: &[u8]) {
92 self.0.update(data)
93 }
94}
95
96pub(crate) struct MbedHashContext {
97 pub(crate) state: Mutex<mbedtls::hash::Md>,
98 pub(crate) hash_algo: &'static Algorithm,
99}
100
101impl Clone for MbedHashContext {
102 fn clone(&self) -> Self {
103 let state = self.state.lock().unwrap();
104 Self { state: Mutex::new(state.clone()), hash_algo: self.hash_algo }
105 }
106}
107
108impl MbedHashContext {
109 pub(crate) fn new(hash_algo: &'static Algorithm) -> Self {
110 Self {
111 hash_algo,
112 state: Mutex::new(mbedtls::hash::Md::new(hash_algo.hash_type).expect("input is validated")),
113 }
114 }
115
116 pub(crate) fn finalize(self) -> Vec<u8> {
118 match self.state.into_inner() {
119 Ok(ctx) => {
120 let mut out = vec![0u8; self.hash_algo.output_len];
121 match ctx.finish(&mut out) {
122 Ok(_) => out,
123 Err(_err) => {
124 error!("Failed to finalize hash, mbedtls error: {_err:?}");
125 vec![]
126 }
127 }
128 }
129 Err(_err) => {
130 error!("Failed to get lock, error: {_err:?}");
131 vec![]
132 }
133 }
134 }
135
136 pub(crate) fn update(&mut self, data: &[u8]) {
137 match self.state.lock().as_mut() {
138 Ok(ctx) => match ctx.update(data) {
139 Ok(_) => {}
140 Err(_err) => {
141 error!("Failed to update hash, mbedtls error: {_err:?}");
142 }
143 },
144 Err(_err) => {
145 error!("Failed to get lock, error: {_err:?}");
146 }
147 }
148 }
149}
150
151pub(crate) fn hash(hash_algo: &'static Algorithm, data: &[u8]) -> Vec<u8> {
152 let mut out = vec![0u8; hash_algo.output_len];
153 match mbedtls::hash::Md::hash(hash_algo.hash_type, data, &mut out) {
154 Ok(_) => out,
155 Err(_err) => {
156 error!("Failed to do hash, mbedtls error: {_err:?}");
157 vec![]
158 }
159 }
160}
161
#[cfg(bench)]
mod benchmarks {
    #[bench]
    fn bench_sha_256_hash(b: &mut test::Bencher) {
        bench_hash(b, &super::SHA256);
    }

    #[bench]
    fn bench_sha_384_hash(b: &mut test::Bencher) {
        bench_hash(b, &super::SHA384);
    }

    #[bench]
    fn bench_sha_256_hash_multi_parts(b: &mut test::Bencher) {
        bench_hash_multi_parts(b, &super::SHA256);
    }

    #[bench]
    fn bench_sha_384_hash_multi_parts(b: &mut test::Bencher) {
        bench_hash_multi_parts(b, &super::SHA384);
    }

    /// Benchmarks one-shot hashing of a 16 KiB buffer.
    fn bench_hash(b: &mut test::Bencher, hash: &super::Hash) {
        use super::hash::Hash;
        let input = [123u8; 1024 * 16];
        b.iter(|| {
            // Black-box both the input and the output so the constant buffer cannot be
            // specialized on and the result cannot be discarded.
            test::black_box(hash.hash(test::black_box(&input[..])));
        });
    }

    /// Benchmarks incremental hashing of the same 16 KiB buffer fed as sixteen 1 KiB parts.
    fn bench_hash_multi_parts(b: &mut test::Bencher, hash: &super::Hash) {
        use super::hash::Hash;
        let input = [123u8; 1024 * 16];
        b.iter(|| {
            let mut ctx = hash.start();
            for part in input.chunks_exact(1024) {
                // Black-box the input chunk; black-boxing `update`'s `()` return value
                // gives the optimizer nothing to preserve.
                ctx.update(test::black_box(part));
            }
            test::black_box(ctx.finish())
        });
    }
}