1use super::utils::hash_bytes;
6use crate::{
7 error::{Error, MetadataError, StorageAddressError},
8 utils::{Encoded, Static},
9};
10use alloc::vec;
11use alloc::vec::Vec;
12use derive_where::derive_where;
13use scale_decode::visitor::IgnoreVisitor;
14use scale_encode::EncodeAsType;
15use scale_info::{PortableRegistry, TypeDef};
16use scale_value::Value;
17use subxt_metadata::{StorageEntryType, StorageHasher};
18
/// The `(hasher, key type id)` pairs that make up the key of a storage entry, in the order in
/// which [`StorageKey`] implementations consume them via [`StorageHashers::iter`].
#[derive(Debug)]
pub struct StorageHashers {
    /// One entry per key part: the hasher applied to that part and the type id used to
    /// encode/decode it.
    hashers_and_ty_ids: Vec<(StorageHasher, u32)>,
}
25
26impl StorageHashers {
27 pub fn new(storage_entry: &StorageEntryType, types: &PortableRegistry) -> Result<Self, Error> {
30 let mut hashers_and_ty_ids = vec![];
31 if let StorageEntryType::Map {
32 hashers, key_ty, ..
33 } = storage_entry
34 {
35 let ty = types
36 .resolve(*key_ty)
37 .ok_or(MetadataError::TypeNotFound(*key_ty))?;
38
39 if let TypeDef::Tuple(tuple) = &ty.type_def {
40 if hashers.len() == 1 {
41 let hasher = hashers[0];
43 for f in tuple.fields.iter() {
44 hashers_and_ty_ids.push((hasher, f.id));
45 }
46 } else if hashers.len() < tuple.fields.len() {
47 return Err(StorageAddressError::WrongNumberOfHashers {
48 hashers: hashers.len(),
49 fields: tuple.fields.len(),
50 }
51 .into());
52 } else {
53 for (i, f) in tuple.fields.iter().enumerate() {
54 hashers_and_ty_ids.push((hashers[i], f.id));
55 }
56 }
57 } else {
58 if hashers.len() != 1 {
59 return Err(StorageAddressError::WrongNumberOfHashers {
60 hashers: hashers.len(),
61 fields: 1,
62 }
63 .into());
64 }
65 hashers_and_ty_ids.push((hashers[0], *key_ty));
66 };
67 }
68
69 Ok(Self { hashers_and_ty_ids })
70 }
71
72 pub fn iter(&self) -> StorageHashersIter<'_> {
74 StorageHashersIter {
75 hashers: self,
76 idx: 0,
77 }
78 }
79}
80
/// An iterator over the `(hasher, type id)` pairs of a [`StorageHashers`]. It is passed to
/// [`StorageKey`] implementations so that each key part can take the pair(s) it needs.
#[derive(Debug)]
pub struct StorageHashersIter<'a> {
    /// The pairs being iterated over.
    hashers: &'a StorageHashers,
    /// Index of the next pair to yield.
    idx: usize,
}
88
89impl StorageHashersIter<'_> {
90 fn next_or_err(&mut self) -> Result<(StorageHasher, u32), Error> {
91 self.next().ok_or_else(|| {
92 StorageAddressError::TooManyKeys {
93 expected: self.hashers.hashers_and_ty_ids.len(),
94 }
95 .into()
96 })
97 }
98}
99
100impl Iterator for StorageHashersIter<'_> {
101 type Item = (StorageHasher, u32);
102
103 fn next(&mut self) -> Option<Self::Item> {
104 let item = self.hashers.hashers_and_ty_ids.get(self.idx).copied()?;
105 self.idx += 1;
106 Some(item)
107 }
108}
109
110impl ExactSizeIterator for StorageHashersIter<'_> {
111 fn len(&self) -> usize {
112 self.hashers.hashers_and_ty_ids.len() - self.idx
113 }
114}
115
/// Something that can be used as (part of) the key of a storage entry.
pub trait StorageKey {
    /// Appends the hashed encoding of `self` to `bytes`. The hasher and type id for each key
    /// part that `self` encodes are taken from `hashers`.
    fn encode_storage_key(
        &self,
        bytes: &mut Vec<u8>,
        hashers: &mut StorageHashersIter,
        types: &PortableRegistry,
    ) -> Result<(), Error>;

    /// Decodes `Self` from the front of `bytes`, taking hashers/type ids from `hashers`.
    ///
    /// The implementations in this module advance `bytes` past everything they consume.
    fn decode_storage_key(
        bytes: &mut &[u8],
        hashers: &mut StorageHashersIter,
        types: &PortableRegistry,
    ) -> Result<Self, Error>
    where
        Self: Sized + 'static;
}
138
139impl StorageKey for () {
142 fn encode_storage_key(
143 &self,
144 _bytes: &mut Vec<u8>,
145 hashers: &mut StorageHashersIter,
146 _types: &PortableRegistry,
147 ) -> Result<(), Error> {
148 _ = hashers.next_or_err();
149 Ok(())
150 }
151
152 fn decode_storage_key(
153 bytes: &mut &[u8],
154 hashers: &mut StorageHashersIter,
155 types: &PortableRegistry,
156 ) -> Result<Self, Error> {
157 let (hasher, ty_id) = match hashers.next_or_err() {
158 Ok((hasher, ty_id)) => (hasher, ty_id),
159 Err(_) if bytes.is_empty() => return Ok(()),
160 Err(err) => return Err(err),
161 };
162 consume_hash_returning_key_bytes(bytes, hasher, ty_id, types)?;
163 Ok(())
164 }
165}
166
/// A storage key part for a statically-typed `K`, stored as the bytes produced by
/// `codec::Encode` (see [`StaticStorageKey::new`]). `K` itself is only tracked at the type
/// level.
#[derive_where(Clone, Debug, PartialOrd, PartialEq, Eq)]
pub struct StaticStorageKey<K: ?Sized> {
    /// The encoded key bytes.
    bytes: Static<Encoded>,
    /// Ties the stored bytes to `K` without holding a `K`.
    _marker: core::marker::PhantomData<K>,
}
174
175impl<K: codec::Encode + ?Sized> StaticStorageKey<K> {
176 pub fn new(key: &K) -> Self {
178 StaticStorageKey {
179 bytes: Static(Encoded(key.encode())),
180 _marker: core::marker::PhantomData,
181 }
182 }
183}
184
185impl<K: codec::Decode> StaticStorageKey<K> {
186 pub fn decoded(&self) -> Result<K, Error> {
188 let decoded = K::decode(&mut self.bytes())?;
189 Ok(decoded)
190 }
191}
192
193impl<K: ?Sized> StaticStorageKey<K> {
194 pub fn bytes(&self) -> &[u8] {
196 &self.bytes.0 .0
197 }
198}
199
200impl<K: ?Sized> StorageKey for StaticStorageKey<K> {
202 fn encode_storage_key(
203 &self,
204 bytes: &mut Vec<u8>,
205 hashers: &mut StorageHashersIter,
206 types: &PortableRegistry,
207 ) -> Result<(), Error> {
208 let (hasher, ty_id) = hashers.next_or_err()?;
209 let encoded_value = self.bytes.encode_as_type(ty_id, types)?;
210 hash_bytes(&encoded_value, hasher, bytes);
211 Ok(())
212 }
213
214 fn decode_storage_key(
215 bytes: &mut &[u8],
216 hashers: &mut StorageHashersIter,
217 types: &PortableRegistry,
218 ) -> Result<Self, Error>
219 where
220 Self: Sized + 'static,
221 {
222 let (hasher, ty_id) = hashers.next_or_err()?;
223 let key_bytes = consume_hash_returning_key_bytes(bytes, hasher, ty_id, types)?;
224
225 let Some(key_bytes) = key_bytes else {
227 return Err(StorageAddressError::HasherCannotReconstructKey { ty_id, hasher }.into());
228 };
229
230 let key = StaticStorageKey {
232 bytes: Static(Encoded(key_bytes.to_vec())),
233 _marker: core::marker::PhantomData::<K>,
234 };
235 Ok(key)
236 }
237}
238
239impl StorageKey for Vec<scale_value::Value> {
240 fn encode_storage_key(
241 &self,
242 bytes: &mut Vec<u8>,
243 hashers: &mut StorageHashersIter,
244 types: &PortableRegistry,
245 ) -> Result<(), Error> {
246 for value in self.iter() {
247 let (hasher, ty_id) = hashers.next_or_err()?;
248 let encoded_value = value.encode_as_type(ty_id, types)?;
249 hash_bytes(&encoded_value, hasher, bytes);
250 }
251 Ok(())
252 }
253
254 fn decode_storage_key(
255 bytes: &mut &[u8],
256 hashers: &mut StorageHashersIter,
257 types: &PortableRegistry,
258 ) -> Result<Self, Error>
259 where
260 Self: Sized + 'static,
261 {
262 let mut result: Vec<scale_value::Value> = vec![];
263 for (hasher, ty_id) in hashers.by_ref() {
264 match consume_hash_returning_key_bytes(bytes, hasher, ty_id, types)? {
265 Some(value_bytes) => {
266 let value =
267 scale_value::scale::decode_as_type(&mut &*value_bytes, ty_id, types)?;
268
269 result.push(value.remove_context());
270 }
271 None => {
272 result.push(Value::unnamed_composite([]));
273 }
274 }
275 }
276
277 if !bytes.is_empty() {
279 return Err(StorageAddressError::TooManyBytes.into());
280 }
281
282 Ok(result)
283 }
284}
285
286fn consume_hash_returning_key_bytes<'a>(
289 bytes: &mut &'a [u8],
290 hasher: StorageHasher,
291 ty_id: u32,
292 types: &PortableRegistry,
293) -> Result<Option<&'a [u8]>, Error> {
294 let bytes_to_strip = hasher.len_excluding_key();
296 if bytes.len() < bytes_to_strip {
297 return Err(StorageAddressError::NotEnoughBytes.into());
298 }
299 *bytes = &bytes[bytes_to_strip..];
300
301 let before_key = *bytes;
303 if hasher.ends_with_key() {
304 scale_decode::visitor::decode_with_visitor(
305 bytes,
306 ty_id,
307 types,
308 IgnoreVisitor::<PortableRegistry>::new(),
309 )
310 .map_err(|err| Error::Decode(err.into()))?;
311 let key_bytes = &before_key[..before_key.len() - bytes.len()];
313
314 Ok(Some(key_bytes))
315 } else {
316 Ok(None)
318 }
319}
320
// Implements `StorageKey` for a tuple of `StorageKey` types. Each invocation passes pairs of
// `TypeParam index`, e.g. `impl_tuples!(A 0, B 1)`.
macro_rules! impl_tuples {
    ($($ty:ident $n:tt),+) => {{
        impl<$($ty: StorageKey),+> StorageKey for ($( $ty ),+) {
            fn encode_storage_key(
                &self,
                bytes: &mut Vec<u8>,
                hashers: &mut StorageHashersIter,
                types: &PortableRegistry,
            ) -> Result<(), Error> {
                // Encode each element in tuple order, sharing the same hasher iterator so each
                // element takes the next hasher(s) in line.
                $( self.$n.encode_storage_key(bytes, hashers, types)?; )+
                Ok(())
            }

            fn decode_storage_key(
                bytes: &mut &[u8],
                hashers: &mut StorageHashersIter,
                types: &PortableRegistry,
            ) -> Result<Self, Error>
            where
                Self: Sized + 'static,
            {
                // Decode each element in tuple order from the same byte cursor and hashers.
                Ok( ( $( $ty::decode_storage_key(bytes, hashers, types)?, )+ ) )
            }
        }
    }};
}
348
// `StorageKey` implementations for tuples of 2 up to 8 elements. Elements may themselves be
// tuples, so deeper keys can be built by nesting.
#[rustfmt::skip]
const _: () = {
    impl_tuples!(A 0, B 1);
    impl_tuples!(A 0, B 1, C 2);
    impl_tuples!(A 0, B 1, C 2, D 3);
    impl_tuples!(A 0, B 1, C 2, D 3, E 4);
    impl_tuples!(A 0, B 1, C 2, D 3, E 4, F 5);
    impl_tuples!(A 0, B 1, C 2, D 3, E 4, F 5, G 6);
    impl_tuples!(A 0, B 1, C 2, D 3, E 4, F 5, G 6, H 7);
};
359
#[cfg(test)]
mod tests {

    use codec::Encode;
    use scale_info::{meta_type, PortableRegistry, Registry, TypeInfo};
    use subxt_metadata::StorageHasher;

    use crate::utils::Era;

    use alloc::string::String;
    use alloc::vec;
    use alloc::vec::Vec;

    use super::{StaticStorageKey, StorageKey};

    /// Test helper that builds raw storage-key bytes together with the type registry and the
    /// `(hasher, type id)` pairs needed to decode them again.
    struct KeyBuilder {
        // Registry that the type of every added value is registered in.
        registry: Registry,
        // Accumulated key bytes.
        bytes: Vec<u8>,
        // One `(hasher, type id)` pair per added value, in insertion order.
        hashers_and_ty_ids: Vec<(StorageHasher, u32)>,
    }

    impl KeyBuilder {
        /// Starts with an empty registry, no bytes and no hashers.
        fn new() -> KeyBuilder {
            KeyBuilder {
                registry: Registry::new(),
                bytes: vec![],
                hashers_and_ty_ids: vec![],
            }
        }

        /// Registers `T` and records `hasher` for it. Then appends `hasher.len_excluding_key()`
        /// zero bytes as a stand-in for the hash, followed by the encoding of `value`.
        ///
        /// The value's encoding is appended even for hashers whose `ends_with_key()` is false.
        /// The decoder does not skip those bytes, so this only lines up when the value encodes
        /// to zero bytes (as `()` does in the test below).
        fn add<T: TypeInfo + Encode + 'static>(mut self, value: T, hasher: StorageHasher) -> Self {
            let id = self.registry.register_type(&meta_type::<T>()).id;

            self.hashers_and_ty_ids.push((hasher, id));
            for _i in 0..hasher.len_excluding_key() {
                self.bytes.push(0);
            }
            value.encode_to(&mut self.bytes);
            self
        }

        /// Converts the registry into a `PortableRegistry` and returns it together with the
        /// bytes and the hashers.
        fn build(self) -> (PortableRegistry, Vec<u8>, Vec<(StorageHasher, u32)>) {
            (self.registry.into(), self.bytes, self.hashers_and_ty_ids)
        }
    }

    /// Decodes the same key bytes through three differently nested tuple key types, for every
    /// combination of hashers, and checks that all of them recover the same values.
    #[test]
    fn storage_key_decoding_fuzz() {
        // Every hasher; only used for the leading `()` part, which has no bytes to recover.
        let hashers = [
            StorageHasher::Blake2_128,
            StorageHasher::Blake2_128Concat,
            StorageHasher::Blake2_256,
            StorageHasher::Identity,
            StorageHasher::Twox128,
            StorageHasher::Twox256,
            StorageHasher::Twox64Concat,
        ];

        // Hashers used for the parts that must be decodable back into values.
        let key_preserving_hashers = [
            StorageHasher::Blake2_128Concat,
            StorageHasher::Identity,
            StorageHasher::Twox64Concat,
        ];

        // Flat layout: one tuple element per key part.
        type T4A = (
            (),
            StaticStorageKey<u32>,
            StaticStorageKey<String>,
            StaticStorageKey<Era>,
        );
        // Same key parts, with the middle two grouped in a nested tuple.
        type T4B = (
            (),
            (StaticStorageKey<u32>, StaticStorageKey<String>),
            StaticStorageKey<Era>,
        );
        // Same key parts, split into two pairs.
        type T4C = (
            ((), StaticStorageKey<u32>),
            (StaticStorageKey<String>, StaticStorageKey<Era>),
        );

        let era = Era::Immortal;
        for h0 in hashers {
            for h1 in key_preserving_hashers {
                for h2 in key_preserving_hashers {
                    for h3 in key_preserving_hashers {
                        let (types, bytes, hashers_and_ty_ids) = KeyBuilder::new()
                            .add((), h0)
                            .add(13u32, h1)
                            .add("Hello", h2)
                            .add(era, h3)
                            .build();

                        // Constructed directly so the test does not depend on metadata.
                        let hashers = super::StorageHashers { hashers_and_ty_ids };
                        let keys_a =
                            T4A::decode_storage_key(&mut &bytes[..], &mut hashers.iter(), &types)
                                .unwrap();

                        let keys_b =
                            T4B::decode_storage_key(&mut &bytes[..], &mut hashers.iter(), &types)
                                .unwrap();

                        let keys_c =
                            T4C::decode_storage_key(&mut &bytes[..], &mut hashers.iter(), &types)
                                .unwrap();

                        assert_eq!(keys_a.1.decoded().unwrap(), 13);
                        assert_eq!(keys_b.1 .0.decoded().unwrap(), 13);
                        assert_eq!(keys_c.0 .1.decoded().unwrap(), 13);

                        assert_eq!(keys_a.2.decoded().unwrap(), "Hello");
                        assert_eq!(keys_b.1 .1.decoded().unwrap(), "Hello");
                        assert_eq!(keys_c.1 .0.decoded().unwrap(), "Hello");
                        assert_eq!(keys_a.3.decoded().unwrap(), era);
                        assert_eq!(keys_b.2.decoded().unwrap(), era);
                        assert_eq!(keys_c.1 .1.decoded().unwrap(), era);
                    }
                }
            }
        }
    }
}