1use std::collections::BTreeMap;
6
7use anyhow::{Result, bail};
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RpcHashMap {
13 pub salt: String,
14 pub batch: String,
15 pub procedures: BTreeMap<String, String>,
16}
17
18pub fn generate_random_salt() -> String {
20 let bytes: [u8; 8] = rand::random();
21 hex::encode(bytes)
22}
23
24fn hash_name(name: &str, salt: &str, hash_length: usize, prefix: &str) -> String {
26 let mut hasher = Sha256::new();
27 hasher.update(name.as_bytes());
28 hasher.update(salt.as_bytes());
29 let result = hasher.finalize();
30 let bytes_needed = hash_length.div_ceil(2);
31 let hex = hex::encode(&result[..bytes_needed]);
32 format!("{}{}", prefix, &hex[..hash_length])
33}
34
35pub fn generate_rpc_hash_map(
40 names: &[&str],
41 salt: &str,
42 hash_length: usize,
43 type_hint: bool,
44) -> Result<RpcHashMap> {
45 let prefix = if type_hint { "rpc-" } else { "" };
46
47 for attempt in 0..100u32 {
48 let effective_salt = if attempt == 0 { salt.to_string() } else { format!("{salt}{attempt}") };
49
50 let mut procedures = BTreeMap::new();
51 let mut seen = BTreeMap::new();
52 let mut collision = false;
53
54 let batch_hash = hash_name("_batch", &effective_salt, hash_length, prefix);
56 seen.insert(batch_hash.clone(), "_batch".to_string());
57
58 for &name in names {
59 let hash = hash_name(name, &effective_salt, hash_length, prefix);
60 if let Some(existing) = seen.get(&hash)
61 && existing != name
62 {
63 collision = true;
64 break;
65 }
66 seen.insert(hash.clone(), name.to_string());
67 procedures.insert(name.to_string(), hash);
68 }
69
70 if !collision {
71 return Ok(RpcHashMap { salt: effective_salt, batch: batch_hash, procedures });
72 }
73 }
74
75 bail!("failed to generate collision-free RPC hash map after 100 attempts")
76}
77
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// True when every character of `s` is an ASCII hex digit (vacuously true for "").
    fn all_hex(s: &str) -> bool {
        s.chars().all(|c| c.is_ascii_hexdigit())
    }

    #[test]
    fn deterministic_with_same_salt() {
        let salt = generate_random_salt();
        let procs = ["getUser", "getSession"];
        let first = generate_rpc_hash_map(&procs, &salt, 12, true).unwrap();
        let second = generate_rpc_hash_map(&procs, &salt, 12, true).unwrap();
        assert_eq!(first.procedures, second.procedures);
        assert_eq!(first.batch, second.batch);
    }

    #[test]
    fn different_salt_different_hashes() {
        let salt_a = generate_random_salt();
        // Draw once more if the two random salts happen to coincide.
        let salt_b = match generate_random_salt() {
            s if s == salt_a => generate_random_salt(),
            s => s,
        };
        let map_a = generate_rpc_hash_map(&["getUser"], &salt_a, 12, true).unwrap();
        let map_b = generate_rpc_hash_map(&["getUser"], &salt_b, 12, true).unwrap();
        assert_ne!(map_a.procedures["getUser"], map_b.procedures["getUser"]);
    }

    #[test]
    fn no_collision_on_typical_set() {
        let procs = [
            "getUser",
            "getSession",
            "listPosts",
            "createPost",
            "updatePost",
            "deletePost",
            "getComments",
            "addComment",
        ];
        let salt = generate_random_salt();
        let map = generate_rpc_hash_map(&procs, &salt, 12, true).unwrap();
        assert_eq!(map.procedures.len(), procs.len());
        let distinct: HashSet<&String> = map.procedures.values().collect();
        assert_eq!(distinct.len(), procs.len());
    }

    #[test]
    fn hash_length_type_hint_true() {
        let salt = generate_random_salt();
        let map = generate_rpc_hash_map(&["test"], &salt, 12, true).unwrap();

        let hash = &map.procedures["test"];
        assert_eq!(hash.len(), 16);
        let body = hash.strip_prefix("rpc-").expect("procedure hash should start with rpc-");
        assert!(all_hex(body));

        assert_eq!(map.batch.len(), 16);
        assert!(map.batch.starts_with("rpc-"));
    }

    #[test]
    fn hash_length_type_hint_false() {
        let salt = generate_random_salt();
        let map = generate_rpc_hash_map(&["test"], &salt, 12, false).unwrap();

        let hash = &map.procedures["test"];
        assert_eq!(hash.len(), 12);
        assert!(all_hex(hash));
        assert!(!hash.starts_with("rpc-"));

        assert_eq!(map.batch.len(), 12);
    }

    #[test]
    fn hash_length_custom() {
        let salt = generate_random_salt();
        // (requested hex length, type_hint)
        let cases: [(usize, bool); 3] = [(8, true), (20, false), (7, false)];
        for (len, hint) in cases {
            let map = generate_rpc_hash_map(&["test"], &salt, len, hint).unwrap();
            let hash = &map.procedures["test"];
            let body = if hint {
                hash.strip_prefix("rpc-").expect("procedure hash should start with rpc-")
            } else {
                hash.as_str()
            };
            assert_eq!(body.len(), len);
            assert!(all_hex(body));
        }
    }

    #[test]
    fn serialization_roundtrip() {
        let salt = generate_random_salt();
        let original = generate_rpc_hash_map(&["a", "b"], &salt, 12, true).unwrap();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RpcHashMap = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original.salt, decoded.salt);
        assert_eq!(original.batch, decoded.batch);
        assert_eq!(original.procedures, decoded.procedures);
    }

    #[test]
    fn random_salt_is_16_hex_chars() {
        let salt = generate_random_salt();
        assert_eq!(salt.len(), 16);
        assert!(all_hex(&salt));
    }

    #[test]
    fn empty_procedures() {
        let salt = generate_random_salt();
        let map = generate_rpc_hash_map(&[], &salt, 12, true).unwrap();
        assert!(map.procedures.is_empty());
        assert!(!map.batch.is_empty());
    }
}