1use std::collections::HashMap;
2
3use keccak_hash::keccak;
4
5use crate::{
6 bytes_from_hex, hashing::TypedDataHashError, StructName, TypeDefinition, TypeRef, Types, Value,
7};
8
/// A 32-byte all-zero word, used as padding and as the encoding of absent struct members.
static EMPTY_32: [u8; 32] = [0u8; 32];
10
11impl Value {
12 pub fn as_bytes(&self) -> Result<Option<Vec<u8>>, TypedDataHashError> {
13 let bytes = match self {
14 Value::Bytes(bytes) => bytes.to_vec(),
15 Value::Integer(int) => int.to_be_bytes().to_vec(),
16 Value::String(string) => {
17 bytes_from_hex(string).ok_or(TypedDataHashError::ExpectedHex)?
18 }
19 _ => {
20 return Err(TypedDataHashError::ExpectedBytes);
21 }
22 };
23 Ok(Some(bytes))
24 }
25
26 pub fn encode(&self, type_: &TypeRef, types: &Types) -> Result<Vec<u8>, TypedDataHashError> {
35 let bytes = match type_ {
36 TypeRef::Bytes => {
37 let bytes_opt;
38 let bytes = match self {
39 Value::Bytes(bytes) => Some(bytes),
40 Value::String(string) => {
41 bytes_opt = bytes_from_hex(string);
42 bytes_opt.as_ref()
43 }
44 _ => None,
45 }
46 .ok_or(TypedDataHashError::ExpectedBytes)?;
47 keccak(bytes).to_fixed_bytes().to_vec()
48 }
49 TypeRef::String => {
50 let string = match self {
51 Value::String(string) => string,
52 _ => {
53 return Err(TypedDataHashError::ExpectedString);
54 }
55 };
56 keccak(string.as_bytes()).to_fixed_bytes().to_vec()
57 }
58 TypeRef::BytesN(n) => {
59 let n = *n;
60 if !(1..=32).contains(&n) {
61 return Err(TypedDataHashError::BytesLength(n));
62 }
63 let mut bytes = match self {
64 Value::Bytes(bytes) => Some(bytes.to_vec()),
65 Value::String(string) => bytes_from_hex(string),
66 _ => None,
67 }
68 .ok_or(TypedDataHashError::ExpectedBytes)?;
69 let len = bytes.len();
70 if len != n {
71 return Err(TypedDataHashError::ExpectedBytesLength(n, len));
72 }
73 if len < 32 {
74 bytes.resize(32, 0);
75 }
76 bytes
77 }
78 TypeRef::UintN(n) => {
79 let n = *n;
80 if n % 8 != 0 {
81 return Err(TypedDataHashError::TypeNotByteAligned("uint", n));
82 }
83 if !(8..=256).contains(&n) {
84 return Err(TypedDataHashError::IntegerLength(n));
85 }
86 let int = self
87 .as_bytes()?
88 .ok_or(TypedDataHashError::ExpectedInteger)?;
89 let len = int.len();
90 if len > 32 {
91 return Err(TypedDataHashError::IntegerTooLong(len));
92 }
93 if len == 32 {
94 return Ok(int);
95 }
96 [EMPTY_32[0..(32 - len)].to_vec(), int].concat()
98 }
99 TypeRef::IntN(n) => {
100 let n = *n;
101 if n % 8 != 0 {
102 return Err(TypedDataHashError::TypeNotByteAligned("int", n));
103 }
104 if !(8..=256).contains(&n) {
105 return Err(TypedDataHashError::IntegerLength(n));
106 }
107 let int = self
108 .as_bytes()?
109 .ok_or(TypedDataHashError::ExpectedInteger)?;
110 let len = int.len();
111 if len > 32 {
112 return Err(TypedDataHashError::IntegerTooLong(len));
113 }
114 if len == 32 {
115 return Ok(int);
116 }
117 let negative = int[0] & 0x80 == 0x80;
119 static PADDING_POS: [u8; 32] = [0; 32];
120 static PADDING_NEG: [u8; 32] = [0xff; 32];
121 let padding = if negative { PADDING_NEG } else { PADDING_POS };
122 [padding[0..(32 - len)].to_vec(), int].concat()
123 }
124 TypeRef::Bool => {
125 let b = self.as_bool().ok_or(TypedDataHashError::ExpectedBoolean)?;
126 let mut bytes: [u8; 32] = [0; 32];
127 if b {
128 bytes[31] = 1;
129 }
130 bytes.to_vec()
131 }
132 TypeRef::Address => {
133 let bytes = self.as_bytes()?.ok_or(TypedDataHashError::ExpectedBytes)?;
134 if bytes.len() != 20 {
135 return Err(TypedDataHashError::ExpectedAddressLength(bytes.len()));
136 }
137 static PADDING: [u8; 12] = [0; 12];
138 [PADDING.to_vec(), bytes].concat()
139 }
140 TypeRef::Array(member_type) => {
141 let array = match self {
145 Value::Array(array) => array,
146 _ => {
147 return Err(TypedDataHashError::ExpectedArray(
148 member_type.to_string(),
149 self.kind(),
150 ));
151 }
152 };
153 let mut enc = Vec::with_capacity(32 * array.len());
154 for member in array {
155 let mut member_enc = encode_field(member, member_type, types)?;
156 enc.append(&mut member_enc);
157 }
158 enc
159 }
160 TypeRef::ArrayN(member_type, n) => {
161 let array = match self {
162 Value::Array(array) => array,
163 _ => {
164 return Err(TypedDataHashError::ExpectedArray(
165 member_type.to_string(),
166 self.kind(),
167 ));
168 }
169 };
170 let n = *n;
171 let len = array.len();
172 if len != n {
173 return Err(TypedDataHashError::ExpectedArrayLength(n, len));
174 }
175 let mut enc = Vec::with_capacity(32 * n);
176 for member in array {
177 let mut member_enc = encode_field(member, member_type, types)?;
178 enc.append(&mut member_enc);
179 }
180 enc
181 }
182 TypeRef::Struct(struct_name) => {
183 let struct_type = types.get(struct_name).ok_or_else(|| {
184 TypedDataHashError::MissingReferencedType(struct_name.to_string())
185 })?;
186 let hash_map = match self {
187 Value::Struct(hash_map) => hash_map,
188 _ => {
189 return Err(TypedDataHashError::ExpectedObject(
190 struct_name.to_string(),
191 self.kind(),
192 ));
193 }
194 };
195 let mut enc = Vec::with_capacity(32 * (struct_type.member_variables().len() + 1));
196 let type_hash = struct_type.hash(struct_name, types)?;
197 enc.append(&mut type_hash.to_vec());
198 let mut keys: std::collections::HashSet<String> =
199 hash_map.keys().map(|k| k.to_owned()).collect();
200 for member in struct_type.member_variables() {
201 let mut member_enc = match hash_map.get(&member.name) {
202 Some(value) => encode_field(value, &member.type_, types)?,
203 None => EMPTY_32.to_vec(),
205 };
206 keys.remove(&member.name);
207 enc.append(&mut member_enc);
208 }
209 if !keys.is_empty() {
210 let names: Vec<String> = keys.into_iter().collect();
212 return Err(TypedDataHashError::UntypedProperties(names));
213 }
214 enc
215 }
216 };
217 Ok(bytes)
218 }
219}
220
221fn encode_field(
222 data: &Value,
223 type_: &TypeRef,
224 types: &Types,
225) -> Result<Vec<u8>, TypedDataHashError> {
226 let is_struct_or_array = matches!(
227 type_,
228 TypeRef::Struct(_) | TypeRef::Array(_) | TypeRef::ArrayN(_, _)
229 );
230 let encoded = data.encode(type_, types)?;
231 if is_struct_or_array {
232 let hash = keccak(&encoded).to_fixed_bytes().to_vec();
233 Ok(hash)
234 } else {
235 Ok(encoded)
236 }
237}
238
239impl TypeDefinition {
240 #[allow(clippy::ptr_arg)]
245 pub fn encode(
246 &self,
247 struct_name: &StructName,
248 types: &Types,
249 ) -> Result<Vec<u8>, TypedDataHashError> {
250 let mut string = String::new();
251 encode_type_single(struct_name, self, &mut string);
252 let mut referenced_types = HashMap::new();
253 gather_referenced_struct_types(self, types, &mut referenced_types)?;
254 let mut types: Vec<(&String, &TypeDefinition)> = referenced_types.into_iter().collect();
255 types.sort_by(|(name1, _), (name2, _)| name1.cmp(name2));
256 for (name, type_) in types {
257 encode_type_single(name, type_, &mut string);
258 }
259 Ok(string.into_bytes())
260 }
261}
262
263fn gather_referenced_struct_types<'a>(
264 type_: &'a TypeDefinition,
265 types: &'a Types,
266 memo: &mut HashMap<&'a String, &'a TypeDefinition>,
267) -> Result<(), TypedDataHashError> {
268 for member in type_.member_variables() {
269 if let Some(struct_name) = member.type_.as_struct_name() {
270 use std::collections::hash_map::Entry;
271 let entry = memo.entry(struct_name);
272 if let Entry::Vacant(o) = entry {
273 let referenced_struct = types.get(struct_name).ok_or_else(|| {
274 TypedDataHashError::MissingReferencedType(struct_name.to_string())
275 })?;
276 o.insert(referenced_struct);
277 gather_referenced_struct_types(referenced_struct, types, memo)?;
278 }
279 }
280 }
281 Ok(())
282}
283
284#[allow(clippy::ptr_arg)]
285fn encode_type_single(type_name: &StructName, type_: &TypeDefinition, string: &mut String) {
286 string.push_str(type_name);
287 string.push('(');
288 let mut first = true;
289 for member in type_.member_variables() {
290 if first {
291 first = false;
292 } else {
293 string.push(',');
294 }
295 string.push_str(&String::from(member.type_.clone()));
296 string.push(' ');
297 string.push_str(&member.name);
298 }
299 string.push(')');
300}