1use std::collections::BTreeMap;
6
7use anyhow::{Result, bail};
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RpcHashMap {
13 pub salt: String,
14 pub batch: String,
15 pub procedures: BTreeMap<String, String>,
16}
17
18pub fn generate_random_salt() -> String {
20 let bytes: [u8; 8] = rand::random();
21 hex::encode(bytes)
22}
23
24fn hash_name(name: &str, salt: &str, hash_length: usize, prefix: &str) -> String {
26 let mut hasher = Sha256::new();
27 hasher.update(name.as_bytes());
28 hasher.update(salt.as_bytes());
29 let result = hasher.finalize();
30 let bytes_needed = hash_length.div_ceil(2);
31 let hex = hex::encode(&result[..bytes_needed]);
32 format!("{}{}", prefix, &hex[..hash_length])
33}
34
35pub fn generate_rpc_hash_map(
40 names: &[&str],
41 salt: &str,
42 hash_length: usize,
43 type_hint: bool,
44) -> Result<RpcHashMap> {
45 let prefix = if type_hint { "rpc-" } else { "" };
46
47 for attempt in 0..100u32 {
48 let effective_salt = if attempt == 0 { salt.to_string() } else { format!("{salt}{attempt}") };
49
50 let mut procedures = BTreeMap::new();
51 let mut seen = BTreeMap::new();
52 let mut collision = false;
53
54 let batch_hash = hash_name("_batch", &effective_salt, hash_length, prefix);
56 seen.insert(batch_hash.clone(), "_batch".to_string());
57
58 for &name in names {
59 let hash = hash_name(name, &effective_salt, hash_length, prefix);
60 if let Some(existing) = seen.get(&hash)
61 && existing != name
62 {
63 collision = true;
64 break;
65 }
66 seen.insert(hash.clone(), name.to_string());
67 procedures.insert(name.to_string(), hash);
68 }
69
70 if !collision {
71 return Ok(RpcHashMap { salt: effective_salt, batch: batch_hash, procedures });
72 }
73 }
74
75 bail!("failed to generate collision-free RPC hash map after 100 attempts")
76}
77
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Representative procedure names shared by several tests.
    const TYPICAL_NAMES: [&str; 8] = [
        "getUser",
        "getSession",
        "listPosts",
        "createPost",
        "updatePost",
        "deletePost",
        "getComments",
        "addComment",
    ];

    #[test]
    fn deterministic_with_same_salt() {
        let names = ["getUser", "getSession"];
        let map1 = generate_rpc_hash_map(&names, "abcd1234abcd1234", 12, true).unwrap();
        let map2 = generate_rpc_hash_map(&names, "abcd1234abcd1234", 12, true).unwrap();
        assert_eq!(map1.procedures, map2.procedures);
        assert_eq!(map1.batch, map2.batch);
        assert_eq!(map1.salt, map2.salt);
    }

    #[test]
    fn salt_unchanged_without_collision() {
        let map =
            generate_rpc_hash_map(&["getUser", "getSession"], "abcd1234abcd1234", 12, true)
                .unwrap();
        assert_eq!(map.salt, "abcd1234abcd1234");
    }

    #[test]
    fn different_salt_different_hashes() {
        let map1 = generate_rpc_hash_map(&["getUser"], "salt_a_1234567890", 12, true).unwrap();
        let map2 = generate_rpc_hash_map(&["getUser"], "salt_b_1234567890", 12, true).unwrap();
        assert_ne!(map1.procedures["getUser"], map2.procedures["getUser"]);
    }

    #[test]
    fn no_collision_on_typical_set() {
        let map = generate_rpc_hash_map(&TYPICAL_NAMES, "test_salt_12345678", 12, true).unwrap();
        assert_eq!(map.procedures.len(), TYPICAL_NAMES.len());
        let hashes: HashSet<_> = map.procedures.values().collect();
        assert_eq!(hashes.len(), TYPICAL_NAMES.len());
    }

    #[test]
    fn batch_hash_distinct_from_procedures() {
        let map = generate_rpc_hash_map(&TYPICAL_NAMES, "test_salt_12345678", 12, true).unwrap();
        assert!(map.procedures.values().all(|hash| *hash != map.batch));
    }

    #[test]
    fn duplicate_names_share_hash() {
        let map = generate_rpc_hash_map(&["getUser", "getUser"], "dup_salt_12345678", 12, true)
            .unwrap();
        assert_eq!(map.procedures.len(), 1);
        assert!(map.procedures.contains_key("getUser"));
    }

    #[test]
    fn hash_length_type_hint_true() {
        let map = generate_rpc_hash_map(&["test"], "salt_for_testing_1", 12, true).unwrap();
        let hash = &map.procedures["test"];
        assert_eq!(hash.len(), 16);
        assert!(hash.starts_with("rpc-"));
        assert!(hash[4..].chars().all(|c| c.is_ascii_hexdigit()));
        assert_eq!(map.batch.len(), 16);
        assert!(map.batch.starts_with("rpc-"));
    }

    #[test]
    fn hash_length_type_hint_false() {
        let map = generate_rpc_hash_map(&["test"], "salt_for_testing_1", 12, false).unwrap();
        let hash = &map.procedures["test"];
        assert_eq!(hash.len(), 12);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
        assert!(!hash.starts_with("rpc-"));
        assert_eq!(map.batch.len(), 12);
    }

    #[test]
    fn hash_length_custom() {
        let map = generate_rpc_hash_map(&["test"], "salt_for_testing_1", 8, true).unwrap();
        let hash = &map.procedures["test"];
        assert_eq!(hash.len(), 12);
        assert!(hash.starts_with("rpc-"));
        assert!(hash[4..].chars().all(|c| c.is_ascii_hexdigit()));

        let map = generate_rpc_hash_map(&["test"], "salt_for_testing_1", 20, false).unwrap();
        let hash = &map.procedures["test"];
        assert_eq!(hash.len(), 20);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));

        // Odd length: the trailing nibble of the last encoded byte is dropped.
        let map = generate_rpc_hash_map(&["test"], "salt_for_testing_1", 7, false).unwrap();
        let hash = &map.procedures["test"];
        assert_eq!(hash.len(), 7);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn max_hash_length_is_full_digest() {
        let full = generate_rpc_hash_map(&["test"], "salt_for_testing_1", 64, false).unwrap();
        let short = generate_rpc_hash_map(&["test"], "salt_for_testing_1", 12, false).unwrap();
        let full_hash = &full.procedures["test"];
        assert_eq!(full_hash.len(), 64);
        assert!(full_hash.chars().all(|c| c.is_ascii_hexdigit()));
        // Shorter hashes are prefixes of the full digest under the same salt.
        assert!(full_hash.starts_with(short.procedures["test"].as_str()));
    }

    #[test]
    fn serialization_roundtrip() {
        let map = generate_rpc_hash_map(&["a", "b"], "roundtrip_salt_00", 12, true).unwrap();
        let json = serde_json::to_string(&map).unwrap();
        let restored: RpcHashMap = serde_json::from_str(&json).unwrap();
        assert_eq!(map.salt, restored.salt);
        assert_eq!(map.batch, restored.batch);
        assert_eq!(map.procedures, restored.procedures);
    }

    #[test]
    fn random_salt_is_16_hex_chars() {
        let salt = generate_random_salt();
        assert_eq!(salt.len(), 16);
        assert!(salt.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn empty_procedures() {
        let map = generate_rpc_hash_map(&[], "empty_salt_123456", 12, true).unwrap();
        assert!(map.procedures.is_empty());
        assert!(!map.batch.is_empty());
    }
}