Skip to main content

tfserver/codec/
spake2_encrypted.rs

1use std::io;
2use crate::codec::codec_trait::TfCodec;
3use crate::structures::temp_transport::TempTransport;
4use crate::structures::transport::{AsyncReadWrite, Transport};
5use aes_gcm::{
6    Aes256Gcm, Key, Nonce,
7    aead::{ KeyInit},
8};
9use async_trait::async_trait;
10use bytes::{Buf, Bytes, BytesMut};
11use futures_util::{SinkExt, StreamExt};
12use hkdf::Hkdf;
13use sha2::Sha256;
14use spake2::{Ed25519Group, Identity, Password, Spake2};
15use std::sync::Arc;
16use std::sync::atomic::{AtomicU64, Ordering};
17use aead::AeadInPlace;
18use tokio_util::codec::{Decoder, Encoder, Framed, LengthDelimitedCodec};
19
20pub struct Spake2Encrypted {
21    server_provider: Option<Arc<dyn ServerCredentialProvider>>,
22    client_provider: Option<Arc<dyn ClientCredentialProvider>>,
23    is_server: bool,
24    server_id: Vec<u8>,
25    length_codec: LengthDelimitedCodec,
26    keys: Option<SessionKeys>,
27}
28
29impl Spake2Encrypted {
30    pub fn create_server(
31        server_provider: Arc<dyn ServerCredentialProvider>,
32        server_id: String,
33        codec: LengthDelimitedCodec,
34    ) -> Self {
35        Self {
36            server_provider: Some(server_provider),
37            client_provider: None,
38            is_server: true,
39            server_id: server_id.as_bytes().to_vec(),
40            length_codec: codec,
41            keys: None,
42        }
43    }
44
45    pub fn create_client(
46        client_provider: Arc<dyn ClientCredentialProvider>,
47        server_id: String,
48        codec: LengthDelimitedCodec,
49    ) -> Self {
50        Self {
51            server_provider: None,
52            client_provider: Some(client_provider),
53            is_server: false,
54            server_id: server_id.as_bytes().to_vec(),
55            length_codec: codec,
56            keys: None,
57        }
58    }
59}
60impl Decoder for Spake2Encrypted {
61    type Item = BytesMut;
62    type Error = io::Error;
63
64    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
65        let mut frame = match self.length_codec.decode(src)? {
66            Some(f) => f,
67            None => return Ok(None),
68        };
69        if let Some(keys) = &self.keys {
70            keys.open_in_place(&mut frame)
71                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "decryption failed"))?;
72        } else {
73            return Err(io::Error::new(io::ErrorKind::Other, "decryption failed"));
74        }
75        Ok(Some(frame))
76    }
77}
78
79impl Encoder<Bytes> for Spake2Encrypted {
80    type Error = io::Error;
81
82    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
83        if let Some(keys) = &self.keys {
84            let mut buf = BytesMut::from(item);
85            keys.seal_in_place(&mut buf)
86                .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "encryption failed"))?;
87            self.length_codec.encode(buf.freeze(), dst)
88        } else {
89            return Err(io::Error::new(io::ErrorKind::Other, "encryption failed"));
90        }
91    }
92}
93impl Clone for Spake2Encrypted {
94    fn clone(&self) -> Self {
95        Self{
96            server_provider: self.server_provider.clone(),
97            client_provider: self.client_provider.clone(),
98            is_server: self.is_server.clone(),
99            server_id: self.server_id.clone(),
100            length_codec: self.length_codec.clone(),
101            keys: None
102        }
103    }
104}
105
106#[async_trait]
107impl TfCodec for Spake2Encrypted {
108    async fn initial_setup(&mut self, tr: &mut Transport) -> bool {
109        //Safe limitation to prevent dos
110        let length_codec = LengthDelimitedCodec::builder().max_frame_length(2048).new_codec();
111        let mut framed = Framed::new(TempTransport::new(tr), length_codec);
112        if self.is_server{
113            let res = server_handshake(&mut framed, self.server_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
114            if let Some(keys) = res {
115                self.keys = Some(keys);
116                return true;
117            } else {
118                return false;
119            }
120        } else {
121            let res = client_handshake(&mut framed, self.client_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
122            if let Some(keys) = res {
123                self.keys = Some(keys);
124                return true;
125            }
126            return false;
127        }
128    }
129}
130
131
132#[async_trait]
133pub trait ServerCredentialProvider: Send+Sync+'static  {
134    async fn get_client_password(&self, client_identity: &str) -> Option<Vec<u8>>;
135}
136
137#[async_trait]
138pub trait ClientCredentialProvider: Send+Sync+'static {
139    ///Return 0 - client identity, 1 - client password
140    async fn get_client_credentials(&self) -> Option<(Vec<u8>, Vec<u8>)>;
141}
142
143pub struct SessionKeys {
144    pub send: Aes256Gcm,
145    pub recv: Aes256Gcm,
146
147    /// Local outbound packet counter
148    send_counter: AtomicU64,
149
150    /// Highest accepted inbound counter
151    recv_counter: AtomicU64,
152}
153
154
/// Size of the big-endian `u64` counter prefix on every sealed frame.
const COUNTER_LEN: usize = std::mem::size_of::<u64>();
/// Size of the AES-GCM authentication tag appended to every sealed frame.
const TAG_LEN: usize = 16;
157
158
159struct OffsetBuffer<'a> {
160    buf: &'a mut BytesMut,
161    offset: usize,
162}
163
164impl AsRef<[u8]> for OffsetBuffer<'_> {
165    fn as_ref(&self) -> &[u8] { &self.buf[self.offset..] }
166}
167
168impl AsMut<[u8]> for OffsetBuffer<'_> {
169    fn as_mut(&mut self) -> &mut [u8] { &mut self.buf[self.offset..] }
170}
171
172impl aead::Buffer for OffsetBuffer<'_> {
173    fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> {
174        self.buf.extend_from_slice(other);
175        Ok(())
176    }
177
178    fn truncate(&mut self, len: usize) {
179        self.buf.truncate(self.offset + len);
180    }
181}
182impl SessionKeys {
183
184        fn derive_session_keys(shared: &[u8], is_server: bool) -> Option<Self> {
185            let hk = Hkdf::<Sha256>::new(None, shared);
186
187            let mut key_a = [0u8; 32];
188            let mut key_b = [0u8; 32];
189
190            hk.expand(b"aes-tunnel-key-a", &mut key_a).ok()?;
191            hk.expand(b"aes-tunnel-key-b", &mut key_b).ok()?;
192
193            let (send_key, recv_key) = if is_server {
194                (key_b, key_a)
195            } else {
196                (key_a, key_b)
197            };
198
199            Some(Self {
200                send: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&send_key)),
201                recv: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&recv_key)),
202                send_counter: AtomicU64::new(1),
203                recv_counter: AtomicU64::new(0),
204            })
205        }
206
207        #[inline]
208        fn nonce_from_counter(counter: u64) -> [u8; 12] {
209            let mut nonce = [0u8; 12];
210            nonce[4..].copy_from_slice(&counter.to_be_bytes());
211            nonce
212        }
213
214        pub fn seal_in_place(&self, buf: &mut BytesMut) -> Option<()> {
215            let counter = self.send_counter.fetch_add(1, Ordering::Relaxed);
216
217            if counter == u64::MAX {
218                return None;
219            }
220
221            let counter_bytes = counter.to_be_bytes();
222            let nonce_bytes = Self::nonce_from_counter(counter);
223            let nonce = Nonce::from_slice(&nonce_bytes);
224
225            // Reframe the plaintext as [counter | plaintext] in a buffer sized for
226            // the tag too, so neither the prefix nor the tag append reallocates.
227            // The plaintext moves exactly once (the `unsplit`); after `split`, `buf`
228            // is empty with no spare capacity, so the `reserve` never copies.
229            let plaintext = buf.split();
230            buf.reserve(COUNTER_LEN + plaintext.len() + TAG_LEN);
231            buf.extend_from_slice(&counter_bytes);
232            buf.unsplit(plaintext);
233
234            // Encrypt only the bytes after the counter prefix in place; the tag
235            // lands in the reserved headroom. counter is included in AAD so any
236            // wire tampering fails tag verification.
237            let mut framed = OffsetBuffer { buf: &mut *buf, offset: COUNTER_LEN };
238            self.send
239                .encrypt_in_place(nonce, &counter_bytes, &mut framed)
240                .ok()?;
241
242            Some(())
243        }
244
245        pub fn open_in_place(&self, buf: &mut BytesMut) -> Option<()> {
246            if buf.len() < COUNTER_LEN {
247                return None;
248            }
249
250            let counter = u64::from_be_bytes(buf[..COUNTER_LEN].try_into().ok()?);
251
252            if counter == u64::MAX {
253                return None;
254            }
255
256            // compare-exchange loop — prevents TOCTOU race if called concurrently
257            let mut last = self.recv_counter.load(Ordering::Acquire);
258            loop {
259                if counter <= last {
260                    return None; // replay or reorder
261                }
262                match self.recv_counter.compare_exchange_weak(
263                    last,
264                    counter,
265                    Ordering::AcqRel,
266                    Ordering::Acquire,
267                ) {
268                    Ok(_) => break,
269                    Err(current) => last = current, // another thread advanced it, retry
270                }
271            }
272
273            let counter_bytes = counter.to_be_bytes();
274            let nonce_bytes = Self::nonce_from_counter(counter);
275            let nonce = Nonce::from_slice(&nonce_bytes);
276
277            // Decrypt the ciphertext after the counter prefix in place; the tag is
278            // truncated in place. AAD must match what seal used, otherwise tag fails.
279            let mut framed = OffsetBuffer { buf: &mut *buf, offset: COUNTER_LEN };
280            self.recv
281                .decrypt_in_place(nonce, &counter_bytes, &mut framed)
282                .ok()?;
283
284            // Drop the counter prefix with a zero-copy advance so `buf` holds
285            // exactly the recovered plaintext.
286            buf.advance(COUNTER_LEN);
287
288            Some(())
289        }
290    }
291
292pub async fn client_handshake<'a, IO: AsyncReadWrite>(
293    io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
294    cred: Arc<dyn ClientCredentialProvider>,
295    server_id: &[u8],
296) -> Option<SessionKeys> {
297    let creds = cred.get_client_credentials().await?;
298    let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_a(
299        &Password::new(creds.1.as_slice()),
300        &Identity::new(creds.0.as_slice()),
301        &Identity::new(server_id),
302    );
303    io.send(Bytes::from(creds.0.clone())).await.ok()?;
304    io.send(Bytes::from(outbound_msg)).await.ok()?;
305
306    let peer_msg = io.next().await?.ok()?;
307
308    let shared = spake.finish(&peer_msg).ok()?;
309
310    SessionKeys::derive_session_keys(&shared, false)
311}
312
313pub async fn server_handshake<'a, IO: AsyncReadWrite>(
314    io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
315    cred_provider: Arc<dyn ServerCredentialProvider>,
316    server_id: &[u8],
317) -> Option<SessionKeys>
318where
319    IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
320{
321    let client_identity = io.next().await?.ok()?;
322    let client_identity = String::from_utf8_lossy(client_identity.as_ref());
323    let password = cred_provider.get_client_password(&client_identity).await?;
324    let client_identity = client_identity.as_bytes();
325    let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_b(
326        &Password::new(password),
327        &Identity::new(client_identity),
328        &Identity::new(server_id),
329    );
330    let peer_msg = io.next().await?.ok()?;
331
332    io.send(Bytes::from(outbound_msg)).await.ok()?;
333
334    let shared = spake.finish(&peer_msg).ok()?;
335
336    SessionKeys::derive_session_keys(&shared, true)
337}
338
339
340
#[cfg(test)]
mod seal_open_tests {
    use super::*;

    /// Server and client halves of one session, derived from the same shared secret.
    fn pair() -> (SessionKeys, SessionKeys) {
        (
            SessionKeys::derive_session_keys(b"shared-secret", true).unwrap(),
            SessionKeys::derive_session_keys(b"shared-secret", false).unwrap(),
        )
    }

    /// Seals `plaintext` on the server side, checks the wire size, and opens it on the
    /// client side.
    fn roundtrip(plaintext: &[u8]) {
        let (server, client) = pair();

        let mut buf = BytesMut::from(plaintext);
        server.seal_in_place(&mut buf).expect("seal");
        // wire layout = counter(8) | ciphertext | tag(16)
        assert_eq!(buf.len(), COUNTER_LEN + plaintext.len() + TAG_LEN);
        client.open_in_place(&mut buf).expect("open");
        assert_eq!(&buf[..], plaintext);
    }

    #[test]
    fn roundtrip_various_sizes() {
        roundtrip(b"");
        roundtrip(b"hello");
        roundtrip(&[0x41u8; 4096]);
    }

    #[test]
    fn client_to_server_roundtrip() {
        let (server, client) = pair();
        let mut buf = BytesMut::from(&b"ping"[..]);
        client.seal_in_place(&mut buf).expect("seal");
        server.open_in_place(&mut buf).expect("open");
        assert_eq!(&buf[..], &b"ping"[..]);
    }

    #[test]
    fn replay_is_rejected() {
        let (server, client) = pair();
        let mut buf = BytesMut::from(&b"first"[..]);
        server.seal_in_place(&mut buf).unwrap();
        let mut replay = buf.clone();
        client.open_in_place(&mut buf).unwrap();
        assert!(client.open_in_place(&mut replay).is_none(), "replay must be rejected");
    }

    #[test]
    fn out_of_order_is_rejected() {
        let (server, client) = pair();
        let mut older = BytesMut::from(&b"one"[..]);
        let mut newer = BytesMut::from(&b"two"[..]);
        server.seal_in_place(&mut older).unwrap();
        server.seal_in_place(&mut newer).unwrap();
        client.open_in_place(&mut newer).unwrap();
        assert!(
            client.open_in_place(&mut older).is_none(),
            "a frame older than the last accepted one must be rejected"
        );
    }

    #[test]
    fn short_frames_are_rejected() {
        let (_, client) = pair();
        // Shorter than the counter prefix.
        let mut tiny = BytesMut::from(&[0u8; COUNTER_LEN - 1][..]);
        assert!(client.open_in_place(&mut tiny).is_none());
        // Counter present, but no ciphertext or tag.
        let mut no_tag = BytesMut::from(&1u64.to_be_bytes()[..]);
        assert!(client.open_in_place(&mut no_tag).is_none());
    }

    #[test]
    fn tamper_is_rejected() {
        let (server, client) = pair();
        let mut buf = BytesMut::from(&b"payload"[..]);
        server.seal_in_place(&mut buf).unwrap();
        let idx = buf.len() - 1;
        buf[idx] ^= 0xFF; // flip a tag byte
        assert!(client.open_in_place(&mut buf).is_none(), "tamper must fail tag check");
    }
}