1use std::io;
2use crate::codec::codec_trait::TfCodec;
3use crate::structures::temp_transport::TempTransport;
4use crate::structures::transport::{AsyncReadWrite, Transport};
5use aes_gcm::{
6 Aes256Gcm, Key, Nonce,
7 aead::{ KeyInit},
8};
9use async_trait::async_trait;
10use bytes::{Buf, Bytes, BytesMut};
11use futures_util::{SinkExt, StreamExt};
12use hkdf::Hkdf;
13use sha2::Sha256;
14use spake2::{Ed25519Group, Identity, Password, Spake2};
15use std::sync::Arc;
16use std::sync::atomic::{AtomicU64, Ordering};
17use aead::AeadInPlace;
18use tokio_util::codec::{Decoder, Encoder, Framed, LengthDelimitedCodec};
19
20pub struct Spake2Encrypted {
21 server_provider: Option<Arc<dyn ServerCredentialProvider>>,
22 client_provider: Option<Arc<dyn ClientCredentialProvider>>,
23 is_server: bool,
24 server_id: Vec<u8>,
25 length_codec: LengthDelimitedCodec,
26 keys: Option<SessionKeys>,
27}
28
29impl Spake2Encrypted {
30 pub fn create_server(
31 server_provider: Arc<dyn ServerCredentialProvider>,
32 server_id: String,
33 codec: LengthDelimitedCodec,
34 ) -> Self {
35 Self {
36 server_provider: Some(server_provider),
37 client_provider: None,
38 is_server: true,
39 server_id: server_id.as_bytes().to_vec(),
40 length_codec: codec,
41 keys: None,
42 }
43 }
44
45 pub fn create_client(
46 client_provider: Arc<dyn ClientCredentialProvider>,
47 server_id: String,
48 codec: LengthDelimitedCodec,
49 ) -> Self {
50 Self {
51 server_provider: None,
52 client_provider: Some(client_provider),
53 is_server: false,
54 server_id: server_id.as_bytes().to_vec(),
55 length_codec: codec,
56 keys: None,
57 }
58 }
59}
60impl Decoder for Spake2Encrypted {
61 type Item = BytesMut;
62 type Error = io::Error;
63
64 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
65 let mut frame = match self.length_codec.decode(src)? {
66 Some(f) => f,
67 None => return Ok(None),
68 };
69 if let Some(keys) = &self.keys {
70 keys.open_in_place(&mut frame)
71 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "decryption failed"))?;
72 } else {
73 return Err(io::Error::new(io::ErrorKind::Other, "decryption failed"));
74 }
75 Ok(Some(frame))
76 }
77}
78
79impl Encoder<Bytes> for Spake2Encrypted {
80 type Error = io::Error;
81
82 fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
83 if let Some(keys) = &self.keys {
84 let mut buf = BytesMut::from(item);
85 keys.seal_in_place(&mut buf)
86 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "encryption failed"))?;
87 self.length_codec.encode(buf.freeze(), dst)
88 } else {
89 return Err(io::Error::new(io::ErrorKind::Other, "encryption failed"));
90 }
91 }
92}
93impl Clone for Spake2Encrypted {
94 fn clone(&self) -> Self {
95 Self{
96 server_provider: self.server_provider.clone(),
97 client_provider: self.client_provider.clone(),
98 is_server: self.is_server.clone(),
99 server_id: self.server_id.clone(),
100 length_codec: self.length_codec.clone(),
101 keys: None
102 }
103 }
104}
105
106#[async_trait]
107impl TfCodec for Spake2Encrypted {
108 async fn initial_setup(&mut self, tr: &mut Transport) -> bool {
109 let length_codec = LengthDelimitedCodec::builder().max_frame_length(2048).new_codec();
111 let mut framed = Framed::new(TempTransport::new(tr), length_codec);
112 if self.is_server{
113 let res = server_handshake(&mut framed, self.server_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
114 if let Some(keys) = res {
115 self.keys = Some(keys);
116 return true;
117 } else {
118 return false;
119 }
120 } else {
121 let res = client_handshake(&mut framed, self.client_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
122 if let Some(keys) = res {
123 self.keys = Some(keys);
124 return true;
125 }
126 return false;
127 }
128 }
129}
130
131
132#[async_trait]
133pub trait ServerCredentialProvider: Send+Sync+'static {
134 async fn get_client_password(&self, client_identity: &str) -> Option<Vec<u8>>;
135}
136
137#[async_trait]
138pub trait ClientCredentialProvider: Send+Sync+'static {
139 async fn get_client_credentials(&self) -> Option<(Vec<u8>, Vec<u8>)>;
141}
142
143pub struct SessionKeys {
144 pub send: Aes256Gcm,
145 pub recv: Aes256Gcm,
146
147 send_counter: AtomicU64,
149
150 recv_counter: AtomicU64,
152}
153
154
/// Width of the big-endian `u64` frame counter prefixed to every sealed frame.
const COUNTER_LEN: usize = std::mem::size_of::<u64>();
/// Width of the AES-GCM authentication tag appended to every sealed frame.
const TAG_LEN: usize = 16;
157
158
159struct OffsetBuffer<'a> {
160 buf: &'a mut BytesMut,
161 offset: usize,
162}
163
164impl AsRef<[u8]> for OffsetBuffer<'_> {
165 fn as_ref(&self) -> &[u8] { &self.buf[self.offset..] }
166}
167
168impl AsMut<[u8]> for OffsetBuffer<'_> {
169 fn as_mut(&mut self) -> &mut [u8] { &mut self.buf[self.offset..] }
170}
171
172impl aead::Buffer for OffsetBuffer<'_> {
173 fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> {
174 self.buf.extend_from_slice(other);
175 Ok(())
176 }
177
178 fn truncate(&mut self, len: usize) {
179 self.buf.truncate(self.offset + len);
180 }
181}
182impl SessionKeys {
183
184 fn derive_session_keys(shared: &[u8], is_server: bool) -> Option<Self> {
185 let hk = Hkdf::<Sha256>::new(None, shared);
186
187 let mut key_a = [0u8; 32];
188 let mut key_b = [0u8; 32];
189
190 hk.expand(b"aes-tunnel-key-a", &mut key_a).ok()?;
191 hk.expand(b"aes-tunnel-key-b", &mut key_b).ok()?;
192
193 let (send_key, recv_key) = if is_server {
194 (key_b, key_a)
195 } else {
196 (key_a, key_b)
197 };
198
199 Some(Self {
200 send: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&send_key)),
201 recv: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&recv_key)),
202 send_counter: AtomicU64::new(1),
203 recv_counter: AtomicU64::new(0),
204 })
205 }
206
207 #[inline]
208 fn nonce_from_counter(counter: u64) -> [u8; 12] {
209 let mut nonce = [0u8; 12];
210 nonce[4..].copy_from_slice(&counter.to_be_bytes());
211 nonce
212 }
213
214 pub fn seal_in_place(&self, buf: &mut BytesMut) -> Option<()> {
215 let counter = self.send_counter.fetch_add(1, Ordering::Relaxed);
216
217 if counter == u64::MAX {
218 return None;
219 }
220
221 let counter_bytes = counter.to_be_bytes();
222 let nonce_bytes = Self::nonce_from_counter(counter);
223 let nonce = Nonce::from_slice(&nonce_bytes);
224
225 let plaintext = buf.split();
230 buf.reserve(COUNTER_LEN + plaintext.len() + TAG_LEN);
231 buf.extend_from_slice(&counter_bytes);
232 buf.unsplit(plaintext);
233
234 let mut framed = OffsetBuffer { buf: &mut *buf, offset: COUNTER_LEN };
238 self.send
239 .encrypt_in_place(nonce, &counter_bytes, &mut framed)
240 .ok()?;
241
242 Some(())
243 }
244
245 pub fn open_in_place(&self, buf: &mut BytesMut) -> Option<()> {
246 if buf.len() < COUNTER_LEN {
247 return None;
248 }
249
250 let counter = u64::from_be_bytes(buf[..COUNTER_LEN].try_into().ok()?);
251
252 if counter == u64::MAX {
253 return None;
254 }
255
256 let mut last = self.recv_counter.load(Ordering::Acquire);
258 loop {
259 if counter <= last {
260 return None; }
262 match self.recv_counter.compare_exchange_weak(
263 last,
264 counter,
265 Ordering::AcqRel,
266 Ordering::Acquire,
267 ) {
268 Ok(_) => break,
269 Err(current) => last = current, }
271 }
272
273 let counter_bytes = counter.to_be_bytes();
274 let nonce_bytes = Self::nonce_from_counter(counter);
275 let nonce = Nonce::from_slice(&nonce_bytes);
276
277 let mut framed = OffsetBuffer { buf: &mut *buf, offset: COUNTER_LEN };
280 self.recv
281 .decrypt_in_place(nonce, &counter_bytes, &mut framed)
282 .ok()?;
283
284 buf.advance(COUNTER_LEN);
287
288 Some(())
289 }
290 }
291
292pub async fn client_handshake<'a, IO: AsyncReadWrite>(
293 io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
294 cred: Arc<dyn ClientCredentialProvider>,
295 server_id: &[u8],
296) -> Option<SessionKeys> {
297 let creds = cred.get_client_credentials().await?;
298 let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_a(
299 &Password::new(creds.1.as_slice()),
300 &Identity::new(creds.0.as_slice()),
301 &Identity::new(server_id),
302 );
303 io.send(Bytes::from(creds.0.clone())).await.ok()?;
304 io.send(Bytes::from(outbound_msg)).await.ok()?;
305
306 let peer_msg = io.next().await?.ok()?;
307
308 let shared = spake.finish(&peer_msg).ok()?;
309
310 SessionKeys::derive_session_keys(&shared, false)
311}
312
313pub async fn server_handshake<'a, IO: AsyncReadWrite>(
314 io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
315 cred_provider: Arc<dyn ServerCredentialProvider>,
316 server_id: &[u8],
317) -> Option<SessionKeys>
318where
319 IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
320{
321 let client_identity = io.next().await?.ok()?;
322 let client_identity = String::from_utf8_lossy(client_identity.as_ref());
323 let password = cred_provider.get_client_password(&client_identity).await?;
324 let client_identity = client_identity.as_bytes();
325 let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_b(
326 &Password::new(password),
327 &Identity::new(client_identity),
328 &Identity::new(server_id),
329 );
330 let peer_msg = io.next().await?.ok()?;
331
332 io.send(Bytes::from(outbound_msg)).await.ok()?;
333
334 let shared = spake.finish(&peer_msg).ok()?;
335
336 SessionKeys::derive_session_keys(&shared, true)
337}
338
339
340
#[cfg(test)]
mod seal_open_tests {
    use super::*;

    /// Server and client `SessionKeys` derived from the same shared secret.
    fn keyed_pair() -> (SessionKeys, SessionKeys) {
        let secret: &[u8] = b"shared-secret";
        let server = SessionKeys::derive_session_keys(secret, true).unwrap();
        let client = SessionKeys::derive_session_keys(secret, false).unwrap();
        (server, client)
    }

    /// Seals `plaintext` on the server side and checks the client recovers it.
    fn assert_roundtrip(plaintext: &[u8]) {
        let (server, client) = keyed_pair();
        let mut frame = BytesMut::from(plaintext);

        server.seal_in_place(&mut frame).expect("seal");
        // Wire layout: counter prefix + ciphertext (plaintext-sized) + tag.
        assert_eq!(frame.len(), COUNTER_LEN + plaintext.len() + TAG_LEN);

        client.open_in_place(&mut frame).expect("open");
        assert_eq!(&frame[..], plaintext);
    }

    #[test]
    fn roundtrip_various_sizes() {
        let large = [0x41u8; 4096];
        let samples: [&[u8]; 3] = [b"", b"hello", &large];
        samples.iter().for_each(|sample| assert_roundtrip(sample));
    }

    #[test]
    fn replay_is_rejected() {
        let (server, client) = keyed_pair();
        let mut original = BytesMut::from(&b"first"[..]);
        server.seal_in_place(&mut original).unwrap();
        let mut replayed = original.clone();

        client.open_in_place(&mut original).unwrap();
        assert!(client.open_in_place(&mut replayed).is_none(), "replay must be rejected");
    }

    #[test]
    fn tamper_is_rejected() {
        let (server, client) = keyed_pair();
        let mut frame = BytesMut::from(&b"payload"[..]);
        server.seal_in_place(&mut frame).unwrap();

        // Flip every bit of the final byte, which belongs to the authentication tag.
        *frame.last_mut().expect("sealed frame is never empty") ^= 0xFF;
        assert!(client.open_in_place(&mut frame).is_none(), "tamper must fail tag check");
    }
}