1use super::utils::hash_bytes;
6use crate::error::{Error, MetadataError, StorageAddressError};
7use alloc::vec;
8use alloc::vec::Vec;
9use scale_decode::{DecodeAsType, visitor::IgnoreVisitor};
10use scale_encode::EncodeAsType;
11use scale_info::{PortableRegistry, TypeDef};
12use scale_value::Value;
13use subxt_metadata::{StorageEntryType, StorageHasher};
14
/// The `(hasher, key type id)` pairs for each key of a storage entry, in key order.
///
/// Built from storage entry metadata by [`StorageHashers::new`] and walked with
/// [`StorageHashers::iter`].
#[derive(Debug)]
pub struct StorageHashers {
    // One `(hasher, type id)` pair per key; empty for plain (non-map) entries.
    hashers_and_ty_ids: Vec<(StorageHasher, u32)>,
}
21
22impl StorageHashers {
23 pub fn new(storage_entry: &StorageEntryType, types: &PortableRegistry) -> Result<Self, Error> {
26 let mut hashers_and_ty_ids = vec![];
27 if let StorageEntryType::Map {
28 hashers, key_ty, ..
29 } = storage_entry
30 {
31 let ty = types
32 .resolve(*key_ty)
33 .ok_or(MetadataError::TypeNotFound(*key_ty))?;
34
35 if hashers.len() == 1 {
36 hashers_and_ty_ids = vec![(hashers[0], *key_ty)];
40 } else {
41 let hasher_count = hashers.len();
45 let tuple = match &ty.type_def {
46 TypeDef::Tuple(tuple) => tuple,
47 _ => {
48 return Err(StorageAddressError::WrongNumberOfHashers {
49 hashers: hasher_count,
50 fields: 1,
51 }
52 .into());
53 }
54 };
55
56 let key_count = tuple.fields.len();
58 if hasher_count != key_count {
59 return Err(StorageAddressError::WrongNumberOfHashers {
60 hashers: hasher_count,
61 fields: key_count,
62 }
63 .into());
64 }
65
66 hashers_and_ty_ids = tuple
68 .fields
69 .iter()
70 .zip(hashers)
71 .map(|(field, hasher)| (*hasher, field.id))
72 .collect();
73 }
74 }
75
76 Ok(Self { hashers_and_ty_ids })
77 }
78
79 pub fn iter(&self) -> StorageHashersIter<'_> {
81 StorageHashersIter {
82 hashers: self,
83 idx: 0,
84 }
85 }
86}
87
/// A cursor over the `(hasher, type id)` pairs of a [`StorageHashers`].
///
/// Key encoders/decoders share one of these so that consecutive keys pick up
/// consecutive hashers.
#[derive(Debug)]
pub struct StorageHashersIter<'a> {
    // The pairs being iterated over.
    hashers: &'a StorageHashers,
    // Index of the next pair to hand out.
    idx: usize,
}
95
96impl StorageHashersIter<'_> {
97 fn next_or_err(&mut self) -> Result<(StorageHasher, u32), Error> {
98 self.next().ok_or_else(|| {
99 StorageAddressError::TooManyKeys {
100 expected: self.hashers.hashers_and_ty_ids.len(),
101 }
102 .into()
103 })
104 }
105}
106
107impl Iterator for StorageHashersIter<'_> {
108 type Item = (StorageHasher, u32);
109
110 fn next(&mut self) -> Option<Self::Item> {
111 let item = self.hashers.hashers_and_ty_ids.get(self.idx).copied()?;
112 self.idx += 1;
113 Some(item)
114 }
115}
116
117impl ExactSizeIterator for StorageHashersIter<'_> {
118 fn len(&self) -> usize {
119 self.hashers.hashers_and_ty_ids.len() - self.idx
120 }
121}
122
/// A value that can be written into, and read back from, the key part of a storage address.
///
/// Implementations take the `(hasher, type id)` pair for each key they handle from a shared
/// [`StorageHashersIter`]. This lets several keys (e.g. the elements of a tuple) be encoded
/// or decoded one after another.
pub trait StorageKey {
    /// Encode this key into `bytes`, drawing `(hasher, type id)` pairs from `hashers`
    /// for the keys it writes.
    ///
    /// Implementations may fail when `hashers` runs out or a value cannot be encoded
    /// against its type id.
    fn encode_storage_key(
        &self,
        bytes: &mut Vec<u8>,
        hashers: &mut StorageHashersIter,
        types: &PortableRegistry,
    ) -> Result<(), Error>;

    /// Decode a key from the front of `bytes`, advancing `bytes` past the hash/key bytes
    /// it consumes and drawing `(hasher, type id)` pairs from `hashers`.
    fn decode_storage_key(
        bytes: &mut &[u8],
        hashers: &mut StorageHashersIter,
        types: &PortableRegistry,
    ) -> Result<Self, Error>
    where
        Self: Sized + 'static;
}
145
146impl StorageKey for () {
149 fn encode_storage_key(
150 &self,
151 _bytes: &mut Vec<u8>,
152 hashers: &mut StorageHashersIter,
153 _types: &PortableRegistry,
154 ) -> Result<(), Error> {
155 _ = hashers.next_or_err();
156 Ok(())
157 }
158
159 fn decode_storage_key(
160 bytes: &mut &[u8],
161 hashers: &mut StorageHashersIter,
162 types: &PortableRegistry,
163 ) -> Result<Self, Error> {
164 let (hasher, ty_id) = match hashers.next_or_err() {
165 Ok((hasher, ty_id)) => (hasher, ty_id),
166 Err(_) if bytes.is_empty() => return Ok(()),
167 Err(err) => return Err(err),
168 };
169 consume_hash_returning_key_bytes(bytes, hasher, ty_id, types)?;
170 Ok(())
171 }
172}
173
/// A storage key holding a statically-typed value `K`.
///
/// When encoded, the value is encoded against the type id supplied by the hashers and then
/// hashed. When decoded, it can only be recovered if the hasher keeps the raw key bytes.
#[derive(Clone, Debug, PartialOrd, PartialEq, Eq)]
pub struct StaticStorageKey<K> {
    key: K,
}

impl<K> StaticStorageKey<K> {
    /// Wrap `key` so it can be used as (part of) a storage address.
    pub fn new(key: K) -> Self {
        StaticStorageKey { key }
    }

    /// Consume the wrapper and return the key value.
    ///
    /// The key is moved out rather than copied, so no `K: Clone` bound is required.
    pub fn into_key(self) -> K {
        self.key
    }
}
193
194impl<K: EncodeAsType + DecodeAsType> StorageKey for StaticStorageKey<K> {
195 fn encode_storage_key(
196 &self,
197 bytes: &mut Vec<u8>,
198 hashers: &mut StorageHashersIter,
199 types: &PortableRegistry,
200 ) -> Result<(), Error> {
201 let (hasher, ty_id) = hashers.next_or_err()?;
202 let encoded_value = self.key.encode_as_type(ty_id, types)?;
203 hash_bytes(&encoded_value, hasher, bytes);
204 Ok(())
205 }
206
207 fn decode_storage_key(
208 bytes: &mut &[u8],
209 hashers: &mut StorageHashersIter,
210 types: &PortableRegistry,
211 ) -> Result<Self, Error>
212 where
213 Self: Sized + 'static,
214 {
215 let (hasher, ty_id) = hashers.next_or_err()?;
216 let key_bytes = consume_hash_returning_key_bytes(bytes, hasher, ty_id, types)?;
217
218 let Some(key_bytes) = key_bytes else {
220 return Err(StorageAddressError::HasherCannotReconstructKey { ty_id, hasher }.into());
221 };
222
223 let key = K::decode_as_type(&mut &*key_bytes, ty_id, types)?;
225 let key = StaticStorageKey { key };
226 Ok(key)
227 }
228}
229
230impl StorageKey for Vec<scale_value::Value> {
231 fn encode_storage_key(
232 &self,
233 bytes: &mut Vec<u8>,
234 hashers: &mut StorageHashersIter,
235 types: &PortableRegistry,
236 ) -> Result<(), Error> {
237 for value in self.iter() {
238 let (hasher, ty_id) = hashers.next_or_err()?;
239 let encoded_value = value.encode_as_type(ty_id, types)?;
240 hash_bytes(&encoded_value, hasher, bytes);
241 }
242 Ok(())
243 }
244
245 fn decode_storage_key(
246 bytes: &mut &[u8],
247 hashers: &mut StorageHashersIter,
248 types: &PortableRegistry,
249 ) -> Result<Self, Error>
250 where
251 Self: Sized + 'static,
252 {
253 let mut result: Vec<scale_value::Value> = vec![];
254 for (hasher, ty_id) in hashers.by_ref() {
255 match consume_hash_returning_key_bytes(bytes, hasher, ty_id, types)? {
256 Some(value_bytes) => {
257 let value =
258 scale_value::scale::decode_as_type(&mut &*value_bytes, ty_id, types)?;
259
260 result.push(value.remove_context());
261 }
262 None => {
263 result.push(Value::unnamed_composite([]));
264 }
265 }
266 }
267
268 if !bytes.is_empty() {
270 return Err(StorageAddressError::TooManyBytes.into());
271 }
272
273 Ok(result)
274 }
275}
276
277fn consume_hash_returning_key_bytes<'a>(
280 bytes: &mut &'a [u8],
281 hasher: StorageHasher,
282 ty_id: u32,
283 types: &PortableRegistry,
284) -> Result<Option<&'a [u8]>, Error> {
285 let bytes_to_strip = hasher.len_excluding_key();
287 if bytes.len() < bytes_to_strip {
288 return Err(StorageAddressError::NotEnoughBytes.into());
289 }
290 *bytes = &bytes[bytes_to_strip..];
291
292 let before_key = *bytes;
294 if hasher.ends_with_key() {
295 scale_decode::visitor::decode_with_visitor(
296 bytes,
297 ty_id,
298 types,
299 IgnoreVisitor::<PortableRegistry>::new(),
300 )
301 .map_err(|err| Error::Decode(err.into()))?;
302 let key_bytes = &before_key[..before_key.len() - bytes.len()];
304
305 Ok(Some(key_bytes))
306 } else {
307 Ok(None)
309 }
310}
311
// Implements `StorageKey` for a tuple of `StorageKey` types.
//
// Each argument pairs a type parameter name with its tuple index (e.g. `A 0, B 1`).
// Elements are encoded and decoded in order, all drawing from the same hasher iterator,
// so each element consumes whatever hashers it needs before the next one runs.
macro_rules! impl_tuples {
    ($($ty:ident $n:tt),+) => {{
        impl<$($ty: StorageKey),+> StorageKey for ($( $ty ),+) {
            fn encode_storage_key(
                &self,
                bytes: &mut Vec<u8>,
                hashers: &mut StorageHashersIter,
                types: &PortableRegistry,
            ) -> Result<(), Error> {
                // Encode each element in tuple order, stopping at the first error.
                $( self.$n.encode_storage_key(bytes, hashers, types)?; )+
                Ok(())
            }

            fn decode_storage_key(
                bytes: &mut &[u8],
                hashers: &mut StorageHashersIter,
                types: &PortableRegistry,
            ) -> Result<Self, Error>
            where
                Self: Sized + 'static,
            {
                // Decode each element in tuple order from the shared byte cursor.
                Ok( ( $( $ty::decode_storage_key(bytes, hashers, types)?, )+ ) )
            }
        }
    }};
}
339
// `StorageKey` implementations for tuples of 2 through 8 keys.
#[rustfmt::skip]
const _: () = {
    impl_tuples!(A 0, B 1);
    impl_tuples!(A 0, B 1, C 2);
    impl_tuples!(A 0, B 1, C 2, D 3);
    impl_tuples!(A 0, B 1, C 2, D 3, E 4);
    impl_tuples!(A 0, B 1, C 2, D 3, E 4, F 5);
    impl_tuples!(A 0, B 1, C 2, D 3, E 4, F 5, G 6);
    impl_tuples!(A 0, B 1, C 2, D 3, E 4, F 5, G 6, H 7);
};
350
#[cfg(test)]
mod tests {

    use codec::Encode;
    use scale_info::{PortableRegistry, Registry, TypeInfo, meta_type};
    use subxt_metadata::StorageHasher;

    use crate::utils::Era;

    use alloc::string::String;
    use alloc::vec;
    use alloc::vec::Vec;

    use super::{StaticStorageKey, StorageKey};

    /// Assembles a synthetic storage key.
    ///
    /// For each added value it registers the value's type, records the
    /// `(hasher, type id)` pair, appends zero bytes in place of the hash, then appends
    /// the SCALE-encoded value.
    struct KeyBuilder {
        registry: Registry,
        bytes: Vec<u8>,
        hashers_and_ty_ids: Vec<(StorageHasher, u32)>,
    }

    impl KeyBuilder {
        fn new() -> Self {
            Self {
                registry: Registry::new(),
                bytes: vec![],
                hashers_and_ty_ids: vec![],
            }
        }

        fn add<T: TypeInfo + Encode + 'static>(mut self, value: T, hasher: StorageHasher) -> Self {
            let type_id = self.registry.register_type(&meta_type::<T>()).id;
            self.hashers_and_ty_ids.push((hasher, type_id));

            // Decoding only skips over the hash prefix, so zero padding suffices.
            let padded_len = self.bytes.len() + hasher.len_excluding_key();
            self.bytes.resize(padded_len, 0);

            value.encode_to(&mut self.bytes);
            self
        }

        fn build(self) -> (PortableRegistry, Vec<u8>, Vec<(StorageHasher, u32)>) {
            let Self {
                registry,
                bytes,
                hashers_and_ty_ids,
            } = self;
            (registry.into(), bytes, hashers_and_ty_ids)
        }
    }

    #[test]
    fn storage_key_decoding_fuzz() {
        // Any hasher works for the `()` position, whose key is skipped rather than decoded.
        let any_hashers = [
            StorageHasher::Blake2_128,
            StorageHasher::Blake2_128Concat,
            StorageHasher::Blake2_256,
            StorageHasher::Identity,
            StorageHasher::Twox128,
            StorageHasher::Twox256,
            StorageHasher::Twox64Concat,
        ];

        // `StaticStorageKey` positions need hashers that keep the raw key bytes.
        let concat_hashers = [
            StorageHasher::Blake2_128Concat,
            StorageHasher::Identity,
            StorageHasher::Twox64Concat,
        ];

        // The same four keys, grouped into tuples in three different shapes.
        type T4A = (
            (),
            StaticStorageKey<u32>,
            StaticStorageKey<String>,
            StaticStorageKey<Era>,
        );
        type T4B = (
            (),
            (StaticStorageKey<u32>, StaticStorageKey<String>),
            StaticStorageKey<Era>,
        );
        type T4C = (
            ((), StaticStorageKey<u32>),
            (StaticStorageKey<String>, StaticStorageKey<Era>),
        );

        let era = Era::Immortal;
        for first in any_hashers {
            for second in concat_hashers {
                for third in concat_hashers {
                    for fourth in concat_hashers {
                        let (types, bytes, hashers_and_ty_ids) = KeyBuilder::new()
                            .add((), first)
                            .add(13u32, second)
                            .add("Hello", third)
                            .add(era, fourth)
                            .build();

                        let storage_hashers = super::StorageHashers { hashers_and_ty_ids };

                        let flat = T4A::decode_storage_key(
                            &mut &bytes[..],
                            &mut storage_hashers.iter(),
                            &types,
                        )
                        .unwrap();

                        let middle_nested = T4B::decode_storage_key(
                            &mut &bytes[..],
                            &mut storage_hashers.iter(),
                            &types,
                        )
                        .unwrap();

                        let fully_nested = T4C::decode_storage_key(
                            &mut &bytes[..],
                            &mut storage_hashers.iter(),
                            &types,
                        )
                        .unwrap();

                        // u32 key
                        assert_eq!(flat.1.into_key(), 13);
                        assert_eq!(middle_nested.1.0.into_key(), 13);
                        assert_eq!(fully_nested.0.1.into_key(), 13);

                        // String key
                        assert_eq!(flat.2.into_key(), "Hello");
                        assert_eq!(middle_nested.1.1.into_key(), "Hello");
                        assert_eq!(fully_nested.1.0.into_key(), "Hello");

                        // Era key
                        assert_eq!(flat.3.into_key(), era);
                        assert_eq!(middle_nested.2.into_key(), era);
                        assert_eq!(fully_nested.1.1.into_key(), era);
                    }
                }
            }
        }
    }
}