unc_primitives/
shard_layout.rs

1use crate::hash::CryptoHash;
2use crate::types::{AccountId, NumShards};
3use borsh::{BorshDeserialize, BorshSerialize};
4use std::collections::HashMap;
5use std::{fmt, str};
6use unc_primitives_core::types::ShardId;
7
/// This file implements two data structures: `ShardLayout` and `ShardUId`.
9///
10/// `ShardLayout`
11/// A versioned struct that contains all information needed to assign accounts
/// to shards. Because of re-sharding, the chain may use different shard layouts to
13/// split shards at different times.
14/// Currently, `ShardLayout` is stored as part of `EpochConfig`, which is generated each epoch
15/// given the epoch protocol version.
16/// In mainnet/testnet, we use two shard layouts since re-sharding has only happened once.
17/// It is stored as part of genesis config, see default_simple_nightshade_shard_layout()
18/// Below is an overview for some important functionalities of ShardLayout interface.
19///
20/// `version`
/// `ShardLayout` has a version number. The version number should be incremented whenever the sharding changes.
22/// This guarantees the version number is unique across different shard layouts, which in turn guarantees
23/// `ShardUId` is different across shards from different shard layouts, as `ShardUId` includes
24/// `version` and `shard_id`
25///
26/// `get_parent_shard_id` and `get_split_shard_ids`
27/// `ShardLayout` also includes information needed for resharding. In particular, it encodes
28/// which shards from the previous shard layout split to which shards in the following shard layout.
29/// If shard A in shard layout 0 splits to shard B and C in shard layout 1,
30/// we call shard A the parent shard of shard B and C.
31/// Note that a shard can only have one parent shard. For example, the following case will be prohibited,
32/// a shard C in shard layout 1 contains accounts in both shard A and B in shard layout 0.
33/// Parent/split shard information can be accessed through these two functions.
34///
35/// `account_id_to_shard_id`
36///  Maps an account to the shard that it belongs to given a shard_layout
37///
38/// `ShardUId`
39/// `ShardUId` is a unique representation for shards from different shard layouts.
/// Compared to `ShardId`, which is just an ordinal number ranging from 0 to NUM_SHARDS-1,
/// `ShardUId` provides a way to uniquely identify shards when shard layouts may change across epochs.
/// This is important because we store states indexed by shards in our database, so we need a
/// way to uniquely identify a shard even when shards change across epochs.
44/// Another difference between `ShardUId` and `ShardId` is that `ShardUId` should only exist in
45/// a node's internal state while `ShardId` can be exposed to outside APIs and used in protocol
46/// level information (for example, `ShardChunkHeader` contains `ShardId` instead of `ShardUId`)
47
///
/// `ShardVersion` itself is the integer type of a layout's version number; it is carried by
/// both `ShardLayout` variants and by `ShardUId`.
pub type ShardVersion = u32;
49
/// Versioned description of how accounts are assigned to shards.
///
/// `V0` assigns accounts by hashing the account id, while `V1` assigns them by ranges delimited
/// by boundary accounts (see `account_id_to_shard_id`).
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
pub enum ShardLayout {
    /// Hash-based layout, see `ShardLayoutV0`.
    V0(ShardLayoutV0),
    /// Range-based layout that may carry resharding information, see `ShardLayoutV1`.
    V1(ShardLayoutV1),
}
55
/// A shard layout that maps accounts evenly across all shards -- by calculating the hash of the
/// account id modulo the number of shards. This is added to capture the old
/// `account_id_to_shard_id` algorithm, to keep backward compatibility for some existing tests.
/// A `ShardLayoutV0` carries no parent shard information, meaning it can only be the first
/// shard layout a chain uses.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct ShardLayoutV0 {
    /// Number of shards; accounts are mapped evenly across all of them
    num_shards: NumShards,
    /// Version of the shard layout, this is useful for uniquely identifying the shard layout
    version: ShardVersion,
}
68
/// A map from each shard of the last shard layout to the shards that it splits into in this
/// shard layout.
/// Instead of using a map, we just use a vec here because shard_id ranges from 0 to num_shards-1;
/// the index into the outer vec is the parent shard id.
/// For example, if a shard layout with only shard 0 splits into shards 0, 1, 2, 3, the
/// `ShardSplitMap` will be `[[0, 1, 2, 3]]`
type ShardSplitMap = Vec<Vec<ShardId>>;
74
/// A shard layout that assigns contiguous ranges of account ids to shards and optionally
/// records how the shards of the previous layout were split to form this one.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct ShardLayoutV1 {
    /// The boundary accounts are the accounts on boundaries between shards.
    /// Each shard contains a range of accounts from one boundary account to
    /// another - or the smallest or largest account possible. The total
    /// number of shards is equal to the number of boundary accounts plus 1.
    boundary_accounts: Vec<AccountId>,
    /// Maps shards from the last shard layout to shards that it splits to in this shard layout,
    /// Useful for constructing states for the shards.
    /// None for the genesis shard layout
    shards_split_map: Option<ShardSplitMap>,
    /// Maps shard in this shard layout to their parent shard
    /// Since shard_ids always range from 0 to num_shards - 1, we use vec instead of a hashmap.
    /// `ShardLayout::v1` derives this from `shards_split_map` and leaves it `None` when that
    /// map is `None`.
    to_parent_shard_map: Option<Vec<ShardId>>,
    /// Version of the shard layout, this is useful for uniquely identifying the shard layout
    version: ShardVersion,
}
92
/// Errors returned by `ShardLayout` queries.
#[derive(Debug)]
pub enum ShardLayoutError {
    /// The given `shard_id` does not exist in the shard layout that was queried.
    InvalidShardIdError { shard_id: ShardId },
}
97
98impl ShardLayout {
99    /* Some constructors */
100    pub fn v0_single_shard() -> Self {
101        Self::v0(1, 0)
102    }
103
104    /// Return a V0 Shardlayout
105    pub fn v0(num_shards: NumShards, version: ShardVersion) -> Self {
106        Self::V0(ShardLayoutV0 { num_shards, version })
107    }
108
109    /// Return a V1 Shardlayout
110    pub fn v1(
111        boundary_accounts: Vec<AccountId>,
112        shards_split_map: Option<ShardSplitMap>,
113        version: ShardVersion,
114    ) -> Self {
115        let to_parent_shard_map = if let Some(shards_split_map) = &shards_split_map {
116            let mut to_parent_shard_map = HashMap::new();
117            let num_shards = (boundary_accounts.len() + 1) as NumShards;
118            for (parent_shard_id, shard_ids) in shards_split_map.iter().enumerate() {
119                for &shard_id in shard_ids {
120                    let prev = to_parent_shard_map.insert(shard_id, parent_shard_id as ShardId);
121                    assert!(prev.is_none(), "no shard should appear in the map twice");
122                    assert!(shard_id < num_shards, "shard id should be valid");
123                }
124            }
125            Some((0..num_shards).map(|shard_id| to_parent_shard_map[&shard_id]).collect())
126        } else {
127            None
128        };
129        Self::V1(ShardLayoutV1 {
130            boundary_accounts,
131            shards_split_map,
132            to_parent_shard_map,
133            version,
134        })
135    }
136
137    /// Returns a V0 ShardLayout. It is only used in tests
138    pub fn v0_test() -> Self {
139        ShardLayout::v0(1, 0)
140    }
141
142    /// Given a parent shard id, return the shard uids for the shards in the current shard layout that
143    /// are split from this parent shard. If this shard layout has no parent shard layout, return None
144    pub fn get_children_shards_uids(&self, parent_shard_id: ShardId) -> Option<Vec<ShardUId>> {
145        self.get_children_shards_ids(parent_shard_id).map(|shards| {
146            shards.into_iter().map(|id| ShardUId::from_shard_id_and_layout(id, self)).collect()
147        })
148    }
149
150    /// Given a parent shard id, return the shard ids for the shards in the current shard layout that
151    /// are split from this parent shard. If this shard layout has no parent shard layout, return None
152    pub fn get_children_shards_ids(&self, parent_shard_id: ShardId) -> Option<Vec<ShardId>> {
153        match self {
154            Self::V0(_) => None,
155            Self::V1(v1) => match &v1.shards_split_map {
156                Some(shards_split_map) => shards_split_map.get(parent_shard_id as usize).cloned(),
157                None => None,
158            },
159        }
160    }
161
162    /// Return the parent shard id for a given shard in the shard layout
163    /// Only calls this function for shard layout that has parent shard layouts
164    /// Returns error if `shard_id` is an invalid shard id in the current layout
165    /// Panics if `self` has no parent shard layout
166    pub fn get_parent_shard_id(&self, shard_id: ShardId) -> Result<ShardId, ShardLayoutError> {
167        if !self.shard_ids().any(|id| id == shard_id) {
168            return Err(ShardLayoutError::InvalidShardIdError { shard_id });
169        }
170        let parent_shard_id = match self {
171            Self::V0(_) => panic!("shard layout has no parent shard"),
172            Self::V1(v1) => match &v1.to_parent_shard_map {
173                // we can safely unwrap here because the construction of to_parent_shard_map guarantees
174                // that every shard has a parent shard
175                Some(to_parent_shard_map) => *to_parent_shard_map.get(shard_id as usize).unwrap(),
176                None => panic!("shard_layout has no parent shard"),
177            },
178        };
179        Ok(parent_shard_id)
180    }
181
182    #[inline]
183    pub fn version(&self) -> ShardVersion {
184        match self {
185            Self::V0(v0) => v0.version,
186            Self::V1(v1) => v1.version,
187        }
188    }
189
190    fn num_shards(&self) -> NumShards {
191        match self {
192            Self::V0(v0) => v0.num_shards,
193            Self::V1(v1) => (v1.boundary_accounts.len() + 1) as NumShards,
194        }
195    }
196
197    pub fn shard_ids(&self) -> impl Iterator<Item = ShardId> {
198        0..self.num_shards()
199    }
200
201    /// Returns an iterator that iterates over all the shard uids for all the
202    /// shards in the shard layout
203    pub fn shard_uids(&self) -> impl Iterator<Item = ShardUId> + '_ {
204        self.shard_ids().map(|shard_id| ShardUId::from_shard_id_and_layout(shard_id, self))
205    }
206}
207
208/// Maps an account to the shard that it belongs to given a shard_layout
209/// For V0, maps according to hash of account id
210/// For V1, accounts are divided to ranges, each range of account is mapped to a shard.
211pub fn account_id_to_shard_id(account_id: &AccountId, shard_layout: &ShardLayout) -> ShardId {
212    match shard_layout {
213        ShardLayout::V0(ShardLayoutV0 { num_shards, .. }) => {
214            let hash = CryptoHash::hash_bytes(account_id.as_bytes());
215            let (bytes, _) = stdx::split_array::<32, 8, 24>(hash.as_bytes());
216            u64::from_le_bytes(*bytes) % num_shards
217        }
218        ShardLayout::V1(ShardLayoutV1 { boundary_accounts, .. }) => {
219            // Note: As we scale up the number of shards we can consider
220            // changing this method to do a binary search rather than linear
221            // scan. For the time being, with only 4 shards, this is perfectly fine.
222            let mut shard_id: ShardId = 0;
223            for boundary_account in boundary_accounts {
224                if account_id < boundary_account {
225                    break;
226                }
227                shard_id += 1;
228            }
229            shard_id
230        }
231    }
232}
233
234/// Maps an account to the shard that it belongs to given a shard_layout
235pub fn account_id_to_shard_uid(account_id: &AccountId, shard_layout: &ShardLayout) -> ShardUId {
236    ShardUId::from_shard_id_and_layout(
237        account_id_to_shard_id(account_id, shard_layout),
238        shard_layout,
239    )
240}
241
/// ShardUId is a unique representation for shards from different shard layouts.
///
/// Its text form is `s<shard_id>.v<version>` (see the `Display` and `FromStr` impls), which is
/// also what the serde `Serialize` impl emits.
#[derive(BorshSerialize, BorshDeserialize, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ShardUId {
    /// Version of the shard layout the shard belongs to
    pub version: ShardVersion,
    /// Id of the shard within that shard layout
    pub shard_id: u32,
}
248
249impl ShardUId {
250    pub fn single_shard() -> Self {
251        Self { version: 0, shard_id: 0 }
252    }
253
254    /// Byte representation of the shard uid
255    pub fn to_bytes(&self) -> [u8; 8] {
256        let mut res = [0; 8];
257        res[0..4].copy_from_slice(&u32::to_le_bytes(self.version));
258        res[4..].copy_from_slice(&u32::to_le_bytes(self.shard_id));
259        res
260    }
261
262    pub fn next_shard_prefix(shard_uid_bytes: &[u8; 8]) -> [u8; 8] {
263        let mut result = *shard_uid_bytes;
264        for i in (0..8).rev() {
265            if result[i] == u8::MAX {
266                result[i] = 0;
267            } else {
268                result[i] += 1;
269                return result;
270            }
271        }
272        panic!("Next shard prefix for shard bytes {shard_uid_bytes:?} does not exist");
273    }
274
275    /// Constructs a shard uid from shard id and a shard layout
276    pub fn from_shard_id_and_layout(shard_id: ShardId, shard_layout: &ShardLayout) -> Self {
277        assert!(shard_layout.shard_ids().any(|i| i == shard_id));
278        Self { shard_id: shard_id as u32, version: shard_layout.version() }
279    }
280
281    /// Returns shard id
282    pub fn shard_id(&self) -> ShardId {
283        ShardId::from(self.shard_id)
284    }
285}
286
287impl TryFrom<&[u8]> for ShardUId {
288    type Error = Box<dyn std::error::Error + Send + Sync>;
289
290    /// Deserialize `bytes` to shard uid
291    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
292        if bytes.len() != 8 {
293            return Err("incorrect length for ShardUId".into());
294        }
295        let version = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
296        let shard_id = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
297        Ok(Self { version, shard_id })
298    }
299}
300
301/// Returns the byte representation for (block, shard_uid)
302pub fn get_block_shard_uid(block_hash: &CryptoHash, shard_uid: &ShardUId) -> Vec<u8> {
303    let mut res = Vec::with_capacity(40);
304    res.extend_from_slice(block_hash.as_ref());
305    res.extend_from_slice(&shard_uid.to_bytes());
306    res
307}
308
309/// Deserialize from a byte representation to (block, shard_uid)
310#[allow(unused)]
311pub fn get_block_shard_uid_rev(
312    key: &[u8],
313) -> Result<(CryptoHash, ShardUId), Box<dyn std::error::Error + Send + Sync>> {
314    if key.len() != 40 {
315        return Err(
316            std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid key length").into()
317        );
318    }
319    let block_hash = CryptoHash::try_from(&key[..32])?;
320    let shard_id = ShardUId::try_from(&key[32..])?;
321    Ok((block_hash, shard_id))
322}
323
324impl fmt::Display for ShardUId {
325    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326        write!(f, "s{}.v{}", self.shard_id, self.version)
327    }
328}
329
330impl fmt::Debug for ShardUId {
331    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
332        fmt::Display::fmt(self, f)
333    }
334}
335
336impl str::FromStr for ShardUId {
337    type Err = String;
338
339    fn from_str(s: &str) -> Result<Self, Self::Err> {
340        let (shard_str, version_str) = s
341            .split_once(".")
342            .ok_or_else(|| "shard version and number must be separated by \".\"".to_string())?;
343
344        let version = version_str
345            .strip_prefix("v")
346            .ok_or_else(|| "shard version must start with \"v\"".to_string())?
347            .parse::<ShardVersion>()
348            .map_err(|e| format!("shard version after \"v\" must be a number, {e}"))?;
349
350        let shard_str = shard_str
351            .strip_prefix("s")
352            .ok_or_else(|| "shard id must start with \"s\"".to_string())?;
353        let shard_id = shard_str
354            .parse::<u32>()
355            .map_err(|e| format!("shard id after \"s\" must be a number, {e}"))?;
356
357        Ok(ShardUId { shard_id, version })
358    }
359}
360
361impl<'de> serde::Deserialize<'de> for ShardUId {
362    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
363    where
364        D: serde::Deserializer<'de>,
365    {
366        deserializer.deserialize_any(ShardUIdVisitor)
367    }
368}
369
370impl serde::Serialize for ShardUId {
371    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
372    where
373        S: serde::Serializer,
374    {
375        serializer.serialize_str(&self.to_string())
376    }
377}
378
379struct ShardUIdVisitor;
380impl<'de> serde::de::Visitor<'de> for ShardUIdVisitor {
381    type Value = ShardUId;
382
383    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
384        write!(
385            formatter,
386            "either string format of `ShardUId` like s0v1 for shard 0 version 1, or a map"
387        )
388    }
389
390    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
391    where
392        E: serde::de::Error,
393    {
394        v.parse().map_err(|e| E::custom(e))
395    }
396
397    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
398    where
399        A: serde::de::MapAccess<'de>,
400    {
401        // custom struct deserialization for backwards compatibility
402        // TODO(#7894): consider removing this code after checking
403        // `ShardUId` is nowhere serialized in the old format
404        let mut version = None;
405        let mut shard_id = None;
406
407        while let Some((field, value)) = map.next_entry()? {
408            match field {
409                "version" => version = Some(value),
410                "shard_id" => shard_id = Some(value),
411                _ => return Err(serde::de::Error::unknown_field(field, &["version", "shard_id"])),
412            }
413        }
414
415        match (version, shard_id) {
416            (None, _) => Err(serde::de::Error::missing_field("version")),
417            (_, None) => Err(serde::de::Error::missing_field("shard_id")),
418            (Some(version), Some(shard_id)) => Ok(ShardUId { version, shard_id }),
419        }
420    }
421}
422
#[cfg(test)]
mod tests {
    use crate::shard_layout::{account_id_to_shard_id, ShardLayout, ShardLayoutV1, ShardUId};
    use rand::distributions::Alphanumeric;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use std::collections::HashMap;
    use unc_primitives_core::types::{AccountId, ShardId};

    use super::{ShardSplitMap, ShardVersion};

    // The old ShardLayoutV1, before fixed shards were removed. tests only
    #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
    pub struct OldShardLayoutV1 {
        /// num_shards = fixed_shards.len() + boundary_accounts.len() + 1
        /// Each account and all sub-accounts map to the shard of position in this array.
        fixed_shards: Vec<AccountId>,
        /// The rest are divided by boundary_accounts to ranges, each range is mapped to a shard
        boundary_accounts: Vec<AccountId>,
        /// Maps shards from the last shard layout to shards that it splits to in this shard layout,
        /// Useful for constructing states for the shards.
        /// None for the genesis shard layout
        shards_split_map: Option<ShardSplitMap>,
        /// Maps shard in this shard layout to their parent shard
        /// Since shard_ids always range from 0 to num_shards - 1, we use vec instead of a hashmap
        to_parent_shard_map: Option<Vec<ShardId>>,
        /// Version of the shard layout, this is useful for uniquely identify the shard layout
        version: ShardVersion,
    }

    // 1000 pseudo-random account ids (fixed seed) must land in valid shards and reproduce the
    // recorded per-shard distribution.
    #[test]
    fn test_shard_layout_v0() {
        let num_shards = 4;
        let shard_layout = ShardLayout::v0(num_shards, 0);
        let mut shard_id_distribution: HashMap<_, _> =
            shard_layout.shard_ids().map(|shard_id| (shard_id, 0)).collect();
        let mut rng = StdRng::from_seed([0; 32]);
        for _ in 0..1000 {
            let s: Vec<u8> = (&mut rng).sample_iter(&Alphanumeric).take(10).collect();
            let s = String::from_utf8(s).unwrap();
            let account_id = s.to_lowercase().parse().unwrap();
            let shard_id = account_id_to_shard_id(&account_id, &shard_layout);
            assert!(shard_id < num_shards);
            *shard_id_distribution.get_mut(&shard_id).unwrap() += 1;
        }
        let expected_distribution: HashMap<_, _> =
            [(0, 247), (1, 268), (2, 233), (3, 252)].into_iter().collect();
        assert_eq!(shard_id_distribution, expected_distribution);
    }

    // Checks children/parent lookups and the account -> shard mapping of a V1 layout.
    #[test]
    fn test_shard_layout_v1() {
        let shard_layout = ShardLayout::v1(
            parse_account_ids(&["aurora", "bar", "foo", "foo.baz", "paz"]),
            Some(vec![vec![0, 1, 2], vec![3, 4, 5]]),
            1,
        );
        assert_eq!(
            shard_layout.get_children_shards_uids(0).unwrap(),
            (0..3).map(|x| ShardUId { version: 1, shard_id: x }).collect::<Vec<_>>()
        );
        assert_eq!(
            shard_layout.get_children_shards_uids(1).unwrap(),
            (3..6).map(|x| ShardUId { version: 1, shard_id: x }).collect::<Vec<_>>()
        );
        for x in 0..3 {
            assert_eq!(shard_layout.get_parent_shard_id(x).unwrap(), 0);
            assert_eq!(shard_layout.get_parent_shard_id(x + 3).unwrap(), 1);
        }

        let expected_shards: [(&str, ShardId); 15] = [
            ("aurora", 1),
            ("foo.aurora", 3),
            ("bar.foo.aurora", 2),
            ("bar", 2),
            ("bar.bar", 2),
            ("foo", 3),
            ("baz.foo", 2),
            ("foo.baz", 4),
            ("a.foo.baz", 0),
            ("aaa", 0),
            ("abc", 0),
            ("bbb", 2),
            ("foo.goo", 4),
            ("goo", 4),
            ("zoo", 5),
        ];
        for (account, expected_shard_id) in expected_shards {
            let account_id: AccountId = account.parse().unwrap();
            assert_eq!(
                account_id_to_shard_id(&account_id, &shard_layout),
                expected_shard_id,
                "unexpected shard for account {account}"
            );
        }
    }

    // check that after removing the fixed shards from the shard layout v1
    // the fixed shards are skipped in deserialization
    // this should be the default as long as serde(deny_unknown_fields) is not set
    #[test]
    fn test_remove_fixed_shards() {
        let old = OldShardLayoutV1 {
            fixed_shards: vec![],
            boundary_accounts: parse_account_ids(&["aaa", "bbb"]),
            shards_split_map: Some(vec![vec![0, 1, 2]]),
            to_parent_shard_map: Some(vec![0, 0, 0]),
            version: 1,
        };
        let json = serde_json::to_string_pretty(&old).unwrap();

        let new = serde_json::from_str::<ShardLayoutV1>(json.as_str()).unwrap();
        assert_eq!(old.boundary_accounts, new.boundary_accounts);
        assert_eq!(old.shards_split_map, new.shards_split_map);
        assert_eq!(old.to_parent_shard_map, new.to_parent_shard_map);
        assert_eq!(old.version, new.version);
    }

    // Parses every string as an `AccountId`, panicking on the first invalid one.
    fn parse_account_ids(ids: &[&str]) -> Vec<AccountId> {
        ids.iter().map(|id| id.parse().unwrap()).collect()
    }

    // Pins the JSON form of a V0 layout with an inline snapshot.
    #[test]
    fn test_shard_layout_all() {
        let v0 = ShardLayout::v0(1, 0);

        insta::assert_snapshot!(serde_json::to_string_pretty(&v0).unwrap(), @r###"
        {
          "V0": {
            "num_shards": 1,
            "version": 0
          }
        }
        "###);
    }
}
549        "###);
550    }
551}