// ssi_eip712/encode.rs

1use std::collections::HashMap;
2
3use keccak_hash::keccak;
4
5use crate::{
6    bytes_from_hex, hashing::TypedDataHashError, StructName, TypeDefinition, TypeRef, Types, Value,
7};
8
/// Thirty-two zero bytes: the source of left-padding for `uint` values and the
/// word emitted for struct members that are absent from the data.
static EMPTY_32: [u8; 32] = [0u8; 32];
10
11impl Value {
12    pub fn as_bytes(&self) -> Result<Option<Vec<u8>>, TypedDataHashError> {
13        let bytes = match self {
14            Value::Bytes(bytes) => bytes.to_vec(),
15            Value::Integer(int) => int.to_be_bytes().to_vec(),
16            Value::String(string) => {
17                bytes_from_hex(string).ok_or(TypedDataHashError::ExpectedHex)?
18            }
19            _ => {
20                return Err(TypedDataHashError::ExpectedBytes);
21            }
22        };
23        Ok(Some(bytes))
24    }
25
26    /// Encode the value into a byte string according to the [EIP-712
27    /// `encodeData` function][1].
28    ///
29    /// Note: this implementation follows eth-sig-util
30    /// which [diverges from EIP-712 when encoding arrays][2].
31    ///
32    /// [1]: <https://eips.ethereum.org/EIPS/eip-712#definition-of-encodedata>
33    /// [2]: <https://github.com/MetaMask/eth-sig-util/issues/106>
34    pub fn encode(&self, type_: &TypeRef, types: &Types) -> Result<Vec<u8>, TypedDataHashError> {
35        let bytes = match type_ {
36            TypeRef::Bytes => {
37                let bytes_opt;
38                let bytes = match self {
39                    Value::Bytes(bytes) => Some(bytes),
40                    Value::String(string) => {
41                        bytes_opt = bytes_from_hex(string);
42                        bytes_opt.as_ref()
43                    }
44                    _ => None,
45                }
46                .ok_or(TypedDataHashError::ExpectedBytes)?;
47                keccak(bytes).to_fixed_bytes().to_vec()
48            }
49            TypeRef::String => {
50                let string = match self {
51                    Value::String(string) => string,
52                    _ => {
53                        return Err(TypedDataHashError::ExpectedString);
54                    }
55                };
56                keccak(string.as_bytes()).to_fixed_bytes().to_vec()
57            }
58            TypeRef::BytesN(n) => {
59                let n = *n;
60                if !(1..=32).contains(&n) {
61                    return Err(TypedDataHashError::BytesLength(n));
62                }
63                let mut bytes = match self {
64                    Value::Bytes(bytes) => Some(bytes.to_vec()),
65                    Value::String(string) => bytes_from_hex(string),
66                    _ => None,
67                }
68                .ok_or(TypedDataHashError::ExpectedBytes)?;
69                let len = bytes.len();
70                if len != n {
71                    return Err(TypedDataHashError::ExpectedBytesLength(n, len));
72                }
73                if len < 32 {
74                    bytes.resize(32, 0);
75                }
76                bytes
77            }
78            TypeRef::UintN(n) => {
79                let n = *n;
80                if n % 8 != 0 {
81                    return Err(TypedDataHashError::TypeNotByteAligned("uint", n));
82                }
83                if !(8..=256).contains(&n) {
84                    return Err(TypedDataHashError::IntegerLength(n));
85                }
86                let int = self
87                    .as_bytes()?
88                    .ok_or(TypedDataHashError::ExpectedInteger)?;
89                let len = int.len();
90                if len > 32 {
91                    return Err(TypedDataHashError::IntegerTooLong(len));
92                }
93                if len == 32 {
94                    return Ok(int);
95                }
96                // Left-pad to 256 bits
97                [EMPTY_32[0..(32 - len)].to_vec(), int].concat()
98            }
99            TypeRef::IntN(n) => {
100                let n = *n;
101                if n % 8 != 0 {
102                    return Err(TypedDataHashError::TypeNotByteAligned("int", n));
103                }
104                if !(8..=256).contains(&n) {
105                    return Err(TypedDataHashError::IntegerLength(n));
106                }
107                let int = self
108                    .as_bytes()?
109                    .ok_or(TypedDataHashError::ExpectedInteger)?;
110                let len = int.len();
111                if len > 32 {
112                    return Err(TypedDataHashError::IntegerTooLong(len));
113                }
114                if len == 32 {
115                    return Ok(int);
116                }
117                // Left-pad to 256 bits, with sign extension.
118                let negative = int[0] & 0x80 == 0x80;
119                static PADDING_POS: [u8; 32] = [0; 32];
120                static PADDING_NEG: [u8; 32] = [0xff; 32];
121                let padding = if negative { PADDING_NEG } else { PADDING_POS };
122                [padding[0..(32 - len)].to_vec(), int].concat()
123            }
124            TypeRef::Bool => {
125                let b = self.as_bool().ok_or(TypedDataHashError::ExpectedBoolean)?;
126                let mut bytes: [u8; 32] = [0; 32];
127                if b {
128                    bytes[31] = 1;
129                }
130                bytes.to_vec()
131            }
132            TypeRef::Address => {
133                let bytes = self.as_bytes()?.ok_or(TypedDataHashError::ExpectedBytes)?;
134                if bytes.len() != 20 {
135                    return Err(TypedDataHashError::ExpectedAddressLength(bytes.len()));
136                }
137                static PADDING: [u8; 12] = [0; 12];
138                [PADDING.to_vec(), bytes].concat()
139            }
140            TypeRef::Array(member_type) => {
141                // Note: this implementation follows eth-sig-util
142                // which diverges from EIP-712 when encoding arrays.
143                // Ref: https://github.com/MetaMask/eth-sig-util/issues/106
144                let array = match self {
145                    Value::Array(array) => array,
146                    _ => {
147                        return Err(TypedDataHashError::ExpectedArray(
148                            member_type.to_string(),
149                            self.kind(),
150                        ));
151                    }
152                };
153                let mut enc = Vec::with_capacity(32 * array.len());
154                for member in array {
155                    let mut member_enc = encode_field(member, member_type, types)?;
156                    enc.append(&mut member_enc);
157                }
158                enc
159            }
160            TypeRef::ArrayN(member_type, n) => {
161                let array = match self {
162                    Value::Array(array) => array,
163                    _ => {
164                        return Err(TypedDataHashError::ExpectedArray(
165                            member_type.to_string(),
166                            self.kind(),
167                        ));
168                    }
169                };
170                let n = *n;
171                let len = array.len();
172                if len != n {
173                    return Err(TypedDataHashError::ExpectedArrayLength(n, len));
174                }
175                let mut enc = Vec::with_capacity(32 * n);
176                for member in array {
177                    let mut member_enc = encode_field(member, member_type, types)?;
178                    enc.append(&mut member_enc);
179                }
180                enc
181            }
182            TypeRef::Struct(struct_name) => {
183                let struct_type = types.get(struct_name).ok_or_else(|| {
184                    TypedDataHashError::MissingReferencedType(struct_name.to_string())
185                })?;
186                let hash_map = match self {
187                    Value::Struct(hash_map) => hash_map,
188                    _ => {
189                        return Err(TypedDataHashError::ExpectedObject(
190                            struct_name.to_string(),
191                            self.kind(),
192                        ));
193                    }
194                };
195                let mut enc = Vec::with_capacity(32 * (struct_type.member_variables().len() + 1));
196                let type_hash = struct_type.hash(struct_name, types)?;
197                enc.append(&mut type_hash.to_vec());
198                let mut keys: std::collections::HashSet<String> =
199                    hash_map.keys().map(|k| k.to_owned()).collect();
200                for member in struct_type.member_variables() {
201                    let mut member_enc = match hash_map.get(&member.name) {
202                        Some(value) => encode_field(value, &member.type_, types)?,
203                        // Allow missing member structs
204                        None => EMPTY_32.to_vec(),
205                    };
206                    keys.remove(&member.name);
207                    enc.append(&mut member_enc);
208                }
209                if !keys.is_empty() {
210                    // A key was remaining in the data that does not have a type in the struct.
211                    let names: Vec<String> = keys.into_iter().collect();
212                    return Err(TypedDataHashError::UntypedProperties(names));
213                }
214                enc
215            }
216        };
217        Ok(bytes)
218    }
219}
220
221fn encode_field(
222    data: &Value,
223    type_: &TypeRef,
224    types: &Types,
225) -> Result<Vec<u8>, TypedDataHashError> {
226    let is_struct_or_array = matches!(
227        type_,
228        TypeRef::Struct(_) | TypeRef::Array(_) | TypeRef::ArrayN(_, _)
229    );
230    let encoded = data.encode(type_, types)?;
231    if is_struct_or_array {
232        let hash = keccak(&encoded).to_fixed_bytes().to_vec();
233        Ok(hash)
234    } else {
235        Ok(encoded)
236    }
237}
238
239impl TypeDefinition {
240    /// Encode the type into a byte string using the [EIP-712 `encodeType`
241    /// function][1].
242    ///
243    /// [1]: <https://eips.ethereum.org/EIPS/eip-712#definition-of-encodetype>
244    #[allow(clippy::ptr_arg)]
245    pub fn encode(
246        &self,
247        struct_name: &StructName,
248        types: &Types,
249    ) -> Result<Vec<u8>, TypedDataHashError> {
250        let mut string = String::new();
251        encode_type_single(struct_name, self, &mut string);
252        let mut referenced_types = HashMap::new();
253        gather_referenced_struct_types(self, types, &mut referenced_types)?;
254        let mut types: Vec<(&String, &TypeDefinition)> = referenced_types.into_iter().collect();
255        types.sort_by(|(name1, _), (name2, _)| name1.cmp(name2));
256        for (name, type_) in types {
257            encode_type_single(name, type_, &mut string);
258        }
259        Ok(string.into_bytes())
260    }
261}
262
263fn gather_referenced_struct_types<'a>(
264    type_: &'a TypeDefinition,
265    types: &'a Types,
266    memo: &mut HashMap<&'a String, &'a TypeDefinition>,
267) -> Result<(), TypedDataHashError> {
268    for member in type_.member_variables() {
269        if let Some(struct_name) = member.type_.as_struct_name() {
270            use std::collections::hash_map::Entry;
271            let entry = memo.entry(struct_name);
272            if let Entry::Vacant(o) = entry {
273                let referenced_struct = types.get(struct_name).ok_or_else(|| {
274                    TypedDataHashError::MissingReferencedType(struct_name.to_string())
275                })?;
276                o.insert(referenced_struct);
277                gather_referenced_struct_types(referenced_struct, types, memo)?;
278            }
279        }
280    }
281    Ok(())
282}
283
284#[allow(clippy::ptr_arg)]
285fn encode_type_single(type_name: &StructName, type_: &TypeDefinition, string: &mut String) {
286    string.push_str(type_name);
287    string.push('(');
288    let mut first = true;
289    for member in type_.member_variables() {
290        if first {
291            first = false;
292        } else {
293            string.push(',');
294        }
295        string.push_str(&String::from(member.type_.clone()));
296        string.push(' ');
297        string.push_str(&member.name);
298    }
299    string.push(')');
300}