secure_layer/
lib.rs

/*!
 * Secure Layer
 * Author: Amit Hendin
 * Created: 22/1/2024
 * Description: This library is a simple mechanism a web server developer can employ to secure communications between client and server.
 * The main idea behind it is a "rolling shared key": a key that constantly changes in an unpredictable way yet remains securely shared between server and client.
 * To achieve this, the client and the server first share a key once. In code this key is called the "init_key_hash", and it is generated by the function start_session.
 * After that, the client encrypts its message using this key, sends the cipher text to the server, and then modifies its own key by concatenating the plain text to the
 * current key and hashing the result. The server receives the cipher text, decrypts it with the current key, and then modifies its own key in the same manner using the
 * decrypted plain text. The result is that both client and server hold the same new key without the key ever being sent over the network. Notice that the server must
 * successfully decrypt the cipher text in order to derive the new key. Also, replicating the process becomes more difficult for a third party over time: to generate the
 * same key, one would have to know the initial key and every plain text subsequently exchanged between the client and the server.
 */
14extern crate rusqlite;
15extern crate aes_gcm;
16extern crate sha3;
17extern crate rand;
18
19use rusqlite::{Connection, params};
20use std::time::{SystemTime, UNIX_EPOCH};
21use aes_gcm::{
22    aead::{AeadCore, AeadInPlace, KeyInit},
23    Aes256Gcm, Nonce, Key // Or `Aes128Gcm`
24};
25use aes_gcm::aead::Aead;
26use rand::Rng;
27use sha3::{Digest, Sha3_256};
28
/**
 * Number of random BYTES (not bits) drawn as source material for a session's initial key.
 * `start_session` XORs the passcode hash into these bytes and hashes the result, so the key
 * actually handed out is a SHA3-256 digest, independent of this size.
 */
const KEY_SIZE: usize = 256;
30
/**
 * Returns the current wall-clock time as milliseconds since the Unix epoch.
 *
 * `Duration::as_millis` yields a 128 bit unsigned integer. It is narrowed to 64 bits because
 * the value is written to SQLite (the `modified` column of `sessions`). A millisecond count
 * only exceeds 64 bits hundreds of millions of years from now.
 *
 * Output: the current time in milliseconds as a 64 bit unsigned int
 * Panics with "Time went backwards" if the system clock is set before the Unix epoch.
 */
fn now_milli() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_millis() as u64)
        .expect("Time went backwards")
}
44
45/**
46* Generates a string of random bytes of a given length.
47*
48* Input: size - An unsigned integer
49* Output: A random string of bytes of length <size>
50*/
51fn rand_bytes(size: usize) -> Vec<u8> {
52    let mut rng = rand::thread_rng();
53    (0..size).map(|_| rng.gen()).collect()
54}
55
56/**
57* The secure hash function.
58*
59* Input: plain_text - The plain text to hash in bytes
60* Output: hash_text - The hash text of the algorithm in use.
61*/
62fn hash(plain_text: &[u8]) -> Vec<u8> {
63    // create a SHA3-256 object
64    let mut hasher = Sha3_256::new();
65    // write input message
66    hasher.update(plain_text);
67    // read hash digest
68    let result = hasher.finalize();
69    return result.to_vec();
70}
71
72/**
73* A secure decryption function with integrated message authentication
74*
75* Input: cipher_text - The cipher text to decrypt
76* key - The key to decrypt with
77* Output: The plain text decrypted from the cipher text using the key, may return error is integrated MAC authentication fails
78*/
79fn decrypt(cipher_text: &[u8], key: &[u8]) -> Result<Vec<u8>, String> {
80    let aes_key = Key::<Aes256Gcm>::from_slice(key);
81    let cipher = Aes256Gcm::new(aes_key);
82    let plain_text = cipher.decrypt(Nonce::from_slice(&[0u8; 12]), cipher_text).unwrap();
83
84    return Ok(plain_text)
85}
86
87/**
88 * A secure encryption function with integrated message authentication
89 *
90 * Input: plain_text - The plain text to encrypt
91 * key - The key to encrypt with
92 * Output: The cipher text encrypted from the given plain text using the key
93 */
94fn encrypt(plain_text: &[u8], key: &[u8]) -> Result<Vec<u8>, String> {
95    let aes_key = Key::<Aes256Gcm>::from_slice(key);
96    let cipher = Aes256Gcm::new(aes_key);
97    let cipher_text = cipher.encrypt(Nonce::from_slice(&[0u8; 12]), plain_text).unwrap();
98
99    return Ok(cipher_text)
100}
101
102/**
103 * Just a struct to hold the sqlite connection
104 */
105pub struct SecureLayer {
106    conn: Connection
107}
108
109impl SecureLayer {
110    pub fn new(conn: Connection) -> SecureLayer {
111        /* Used to store the hashed passcode the secure layer recognises */
112        conn.execute(r#"
113            CREATE TABLE IF NOT EXISTS passcodes (
114                id INTEGER PRIMARY KEY,
115                hash BLOB
116            )
117        "#, ()).unwrap();
118        /* Used to store the open sessions for each passcode */
119        conn.execute(r#"
120            CREATE TABLE IF NOT EXISTS sessions (
121                id INTEGER PRIMARY KEY,
122                hash BLOB,
123                passcode,
124                modified INTEGER,
125                FOREIGN KEY(passcode) REFERENCES passcodes(id)
126            )
127        "#, ()).unwrap();
128
129        return SecureLayer {
130            conn
131        }
132    }
133
134    /**
135    * Registers a passcode so the secure layer will recognize it
136    *
137    * Input: passcode - The passcode string to store
138    * Output: Can fail on database error.
139    */
140    pub fn register_passcode(&self, passcode: &str) -> Result<(), String> {
141        /* Hash the passcode */
142        let passcode_hash = hash(passcode.as_bytes());
143
144        /* Save the hashed passcode */
145        self.conn.execute(r#"
146            INSERT INTO passcodes (hash)
147            VALUES (?1)
148        "#, params![passcode_hash]).unwrap();
149
150        return Ok(())
151    }
152
153    /**
154    * Deletes passcode from the secure layer
155    *
156    * Input: passcode - The passcode string to store
157    * Output: Can fail on database error.
158    */
159    pub fn delete_passcode(&self, passcode: &str) -> Result<(), String> {
160        let passcode_hash = hash(passcode.as_bytes());
161
162        /* Delete the passcode from the system which has the same hash as the given passcode */
163        self.conn.execute(r#"
164            DELETE FROM passcodes
165            WHERE hash=?1
166        "#, params![passcode_hash]).unwrap();
167
168        return Ok(())
169    }
170
171    /**
172    * Deletes session from the system
173    *
174    * Input: session_id - The id of the session to delete
175    * Output: Can fail on database error.
176    */
177    pub fn end_session(&self, session_id: u64) -> Result<(), String> {
178        self.conn.execute(r#"
179            DELETE FROM sessions
180            WHERE id=?1
181        "#, params![session_id]).unwrap();
182
183        return Ok(())
184    }
185
186    /**
187    * Creates a session in the system for a given passcode
188    *
189    * Input: passcode - The passcode for which to start a session
190    * Output: session_id - The id of the new session,
191    * init_key_hash - The initial key after it wased hashed with the passcode and some random string of bytes.
192    * Can fail on database error or if the given passcode is not present in the system.
193    */
194    pub fn start_session(&self, passcode_hash: &[u8]) -> Result<(u64, Vec<u8>), String> {
195        /* Find passcode in the system */
196        let passcode_id: u64 = match self.conn.query_row(
197            "SELECT * FROM passcodes WHERE hash=?",
198            params![passcode_hash],
199            |row| row.get(0),
200        ) {
201            Ok(id) => id,
202            Err(e) => return Err(format!("sql exception: {:?}", e))
203        };
204
205        /* Initialize init key to a random string of bytes */
206        let mut init_key_source = rand_bytes(KEY_SIZE);
207        /* Xor the passcode hash to the random string of bytes */
208        init_key_source.iter_mut()
209            .zip(passcode_hash.iter())
210            .for_each(|(x1, x2)| *x1 ^= *x2);
211        /* Hash the modified init key */
212        let init_key_hash = hash(init_key_source.as_slice());
213
214        /* Create a new session and store the hashed init key as the sessions current key */
215        self.conn.execute(r#"
216            INSERT INTO sessions (hash, passcode, modified)
217            VALUES (?1, ?2, ?3)
218        "#, params![init_key_hash.as_slice(), passcode_id, now_milli()]).unwrap();
219
220        let session_id = self.conn.last_insert_rowid() as u64;
221
222        Ok((session_id, init_key_hash)) /* return session id and hash init key to the user */
223    }
224
225    /**
226    * Authenticates the request to the server and undated the shared key.
227    *
228    * Input: session_id - The id of the session in which the client has made the request
229    * cipher_text - The cipher text encrypting the body of the client request
230    * Output: Returns the decrypted plain text using the hashed key stored in the database for the given session key, also updated the hashed key for the session
231    * with the plain text in order to be synchronised with the client. May fail on decryption error or database error.
232    */
233    pub fn authenticate_request(&self, session_id: u64, cipher_text: &[u8]) -> Result<Vec<u8>, String> {
234        /* fetch the hashed key for the given session */
235        let mut session_key: Vec<u8> = match self.conn.query_row(
236            "SELECT * FROM sessions WHERE id=?",
237            params![session_id],
238            |row| row.get(1),
239        ) {
240            Ok(hash) => hash,
241            Err(e) => return Err(format!("sql exception: {:?}", e))
242        };
243
244        /* decrypt the cipher text using the session key from the database */
245        let plain_text = decrypt(cipher_text, session_key.as_slice()).unwrap();
246
247        /* update the session's hashed key using the decrypted plain text */
248        session_key.extend(plain_text.clone());
249        let new_hash = hash(session_key.as_slice());
250        self.conn.execute(r#"
251            UPDATE sessions
252            SET hash=?1, modified=?2
253            WHERE id=?3
254        "#, params![new_hash, now_milli(), session_id]).unwrap();
255
256        /* return the request's plain text */
257        Ok(plain_text)
258    }
259}
260
261
#[cfg(test)]
mod tests {
    use super::*;
    const PASSCODE: &str = "abc123";

    /* Client-side key roll: new_key = hash(old_key || plain_text), mirroring authenticate_request. */
    fn roll_key(key: &[u8], plain_text: &[u8]) -> Vec<u8> {
        let mut material = key.to_vec();
        material.extend_from_slice(plain_text);
        hash(&material)
    }

    #[test]
    fn test_client_server_communications() {
        let conn = Connection::open_in_memory().unwrap();
        let sl = SecureLayer::new(conn);

        sl.register_passcode(PASSCODE).unwrap();
        let passcode_hash = hash(PASSCODE.as_bytes());

        let (sid, init_hash) = sl.start_session(&passcode_hash).unwrap();
        let mut key = init_hash;

        /* Each message is encrypted with the current key, then both sides roll the key. */
        for plain_text in ["hello world", "hai", "it works!"] {
            let payload = encrypt(plain_text.as_bytes(), &key).unwrap();
            let received = sl.authenticate_request(sid, &payload).unwrap();
            assert_eq!(received, plain_text.as_bytes());
            key = roll_key(&key, plain_text.as_bytes());
        }

        sl.end_session(sid).unwrap();

        /* The session is gone, so even a correctly keyed request must be rejected. */
        let payload = encrypt("This should'nt work".as_bytes(), &key).unwrap();
        assert!(sl.authenticate_request(sid, &payload).is_err());
    }
}