use crate::hash::CryptoHash;
use crate::types::{AccountId, NumShards};
use borsh::{BorshDeserialize, BorshSerialize};
use std::collections::HashMap;
use std::{fmt, str};
use unc_primitives_core::types::ShardId;
/// Version number of a shard layout. It is copied into every [`ShardUId`]
/// produced from that layout.
pub type ShardVersion = u32;
/// Describes how accounts are partitioned into shards.
///
/// - `V0` assigns an account to a shard by hashing its id (see `account_id_to_shard_id`).
/// - `V1` assigns accounts by comparing them against a list of boundary accounts, and can
///   optionally record which parent shard each shard was split from.
///
/// Serialized with serde's default externally-tagged enum representation, e.g.
/// `{"V0": {"num_shards": 1, "version": 0}}`.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
pub enum ShardLayout {
    V0(ShardLayoutV0),
    V1(ShardLayoutV1),
}
/// Hash-based layout: an account's shard is derived from its hashed id modulo `num_shards`.
// Field order is part of the serialized form (pinned by a snapshot test); do not reorder.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct ShardLayoutV0 {
    /// Number of shards; shard ids are `0..num_shards`.
    num_shards: NumShards,
    /// Layout version stamped into every `ShardUId` built from this layout.
    version: ShardVersion,
}
/// `map[parent_shard_id]` lists the ids of the shards that parent shard was split into.
type ShardSplitMap = Vec<Vec<ShardId>>;
/// Boundary-based layout with `boundary_accounts.len() + 1` shards.
// Serialized struct: field names and order are part of the format.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct ShardLayoutV1 {
    /// Accounts that start a new shard. An account belongs to the shard whose index equals
    /// the position of the first boundary strictly greater than it (or `len()` if none).
    /// Presumably sorted ascending; the lookup stops at the first greater boundary.
    boundary_accounts: Vec<AccountId>,
    /// Optional parent -> children mapping describing how the previous layout was split.
    shards_split_map: Option<ShardSplitMap>,
    /// Optional child -> parent mapping: `to_parent_shard_map[shard_id]` is the parent id.
    /// `ShardLayout::v1` derives it from `shards_split_map`.
    to_parent_shard_map: Option<Vec<ShardId>>,
    /// Layout version stamped into every `ShardUId` built from this layout.
    version: ShardVersion,
}
/// Errors returned by fallible `ShardLayout` queries.
#[derive(Debug)]
pub enum ShardLayoutError {
    /// `shard_id` is not one of the layout's shard ids (`0..num_shards`).
    InvalidShardIdError { shard_id: ShardId },
}
impl ShardLayout {
    /// Returns a `V0` layout with a single shard and version 0.
    pub fn v0_single_shard() -> Self {
        Self::v0(1, 0)
    }

    /// Builds a hash-based `V0` layout with shard ids `0..num_shards`.
    pub fn v0(num_shards: NumShards, version: ShardVersion) -> Self {
        Self::V0(ShardLayoutV0 { num_shards, version })
    }

    /// Builds a boundary-based `V1` layout with `boundary_accounts.len() + 1` shards.
    ///
    /// When `shards_split_map` is given, the child -> parent map is derived from it. Every
    /// shard of the new layout must appear exactly once as a child.
    ///
    /// # Panics
    ///
    /// Panics when `shards_split_map` is given and any of the following holds:
    /// - it lists a shard id `>= num_shards`;
    /// - it lists a shard more than once;
    /// - some shard in `0..num_shards` has no parent.
    pub fn v1(
        boundary_accounts: Vec<AccountId>,
        shards_split_map: Option<ShardSplitMap>,
        version: ShardVersion,
    ) -> Self {
        let to_parent_shard_map = shards_split_map.as_ref().map(|shards_split_map| {
            let num_shards = (boundary_accounts.len() + 1) as NumShards;
            let mut to_parent_shard_map = HashMap::new();
            for (parent_shard_id, shard_ids) in shards_split_map.iter().enumerate() {
                for &shard_id in shard_ids {
                    assert!(shard_id < num_shards, "shard id should be valid");
                    let prev = to_parent_shard_map.insert(shard_id, parent_shard_id as ShardId);
                    assert!(prev.is_none(), "no shard should appear in the map twice");
                }
            }
            // Densify into a Vec indexed by shard id. A shard that was never listed as a child
            // is a malformed split map; report it explicitly instead of an opaque index panic.
            (0..num_shards)
                .map(|shard_id| match to_parent_shard_map.get(&shard_id) {
                    Some(&parent_shard_id) => parent_shard_id,
                    None => panic!("shard {shard_id} does not appear in shards_split_map"),
                })
                .collect::<Vec<ShardId>>()
        });
        Self::V1(ShardLayoutV1 { boundary_accounts, shards_split_map, to_parent_shard_map, version })
    }

    /// Single-shard `V0` layout; identical to `v0_single_shard()`.
    pub fn v0_test() -> Self {
        ShardLayout::v0(1, 0)
    }

    /// Like `get_children_shards_ids`, but returns the children as `ShardUId`s of this layout.
    ///
    /// # Panics
    ///
    /// Panics if a listed child id is not a valid shard of this layout.
    pub fn get_children_shards_uids(&self, parent_shard_id: ShardId) -> Option<Vec<ShardUId>> {
        self.get_children_shards_ids(parent_shard_id).map(|shards| {
            shards.into_iter().map(|id| ShardUId::from_shard_id_and_layout(id, self)).collect()
        })
    }

    /// Returns the ids of the shards that `parent_shard_id` was split into.
    ///
    /// Returns `None` in any of these cases:
    /// - the layout is `V0`;
    /// - the layout has no split map;
    /// - `parent_shard_id` is not a parent in the split map.
    pub fn get_children_shards_ids(&self, parent_shard_id: ShardId) -> Option<Vec<ShardId>> {
        match self {
            Self::V0(_) => None,
            Self::V1(v1) => {
                // An id that does not fit in `usize` cannot index the map; avoid a lossy cast
                // that could alias a valid index on 32-bit targets.
                let index = usize::try_from(parent_shard_id).ok()?;
                v1.shards_split_map.as_ref()?.get(index).cloned()
            }
        }
    }

    /// Returns the id of the shard in the previous layout that `shard_id` was split from.
    ///
    /// # Errors
    ///
    /// Returns `ShardLayoutError::InvalidShardIdError` if `shard_id` is not in `0..num_shards`.
    ///
    /// # Panics
    ///
    /// Panics if the layout is `V0`, or is `V1` without a child -> parent map.
    pub fn get_parent_shard_id(&self, shard_id: ShardId) -> Result<ShardId, ShardLayoutError> {
        // Shard ids are exactly `0..num_shards`, so a range check is equivalent to
        // scanning `shard_ids()` and runs in O(1).
        if shard_id >= self.num_shards() {
            return Err(ShardLayoutError::InvalidShardIdError { shard_id });
        }
        let parent_shard_id = match self {
            Self::V0(_) => panic!("shard layout has no parent shard"),
            Self::V1(v1) => match &v1.to_parent_shard_map {
                Some(to_parent_shard_map) => to_parent_shard_map
                    .get(shard_id as usize)
                    .copied()
                    .expect("to_parent_shard_map must have an entry for every shard"),
                None => panic!("shard_layout has no parent shard"),
            },
        };
        Ok(parent_shard_id)
    }

    /// Returns the layout version.
    #[inline]
    pub fn version(&self) -> ShardVersion {
        match self {
            Self::V0(v0) => v0.version,
            Self::V1(v1) => v1.version,
        }
    }

    /// Number of shards in this layout.
    fn num_shards(&self) -> NumShards {
        match self {
            Self::V0(v0) => v0.num_shards,
            Self::V1(v1) => (v1.boundary_accounts.len() + 1) as NumShards,
        }
    }

    /// Iterates over all shard ids of this layout, i.e. `0..num_shards`.
    pub fn shard_ids(&self) -> impl Iterator<Item = ShardId> {
        0..self.num_shards()
    }

    /// Iterates over the `ShardUId` of every shard, tagged with this layout's version.
    pub fn shard_uids(&self) -> impl Iterator<Item = ShardUId> + '_ {
        self.shard_ids().map(|shard_id| ShardUId::from_shard_id_and_layout(shard_id, self))
    }
}
/// Maps an account to the id of the shard it lives in under `shard_layout`.
///
/// - `V0`: the first 8 bytes of `CryptoHash::hash_bytes(account_id)` are read as a
///   little-endian `u64` and reduced modulo `num_shards`. This panics when
///   `num_shards == 0`, because the modulo divides by zero.
/// - `V1`: the shard id is the index of the first boundary account strictly greater than
///   `account_id`, or the number of boundaries if there is none.
pub fn account_id_to_shard_id(account_id: &AccountId, shard_layout: &ShardLayout) -> ShardId {
    match shard_layout {
        ShardLayout::V0(v0) => {
            let digest = CryptoHash::hash_bytes(account_id.as_bytes());
            let (head, _rest) = stdx::split_array::<32, 8, 24>(digest.as_bytes());
            u64::from_le_bytes(*head) % v0.num_shards
        }
        ShardLayout::V1(v1) => {
            let boundaries = &v1.boundary_accounts;
            let index = boundaries
                .iter()
                .position(|boundary| account_id < boundary)
                .unwrap_or(boundaries.len());
            index as ShardId
        }
    }
}
/// Maps an account to the `ShardUId` of its shard under `shard_layout`.
/// The returned id carries the layout's version.
pub fn account_id_to_shard_uid(account_id: &AccountId, shard_layout: &ShardLayout) -> ShardUId {
    let shard_id = account_id_to_shard_id(account_id, shard_layout);
    ShardUId::from_shard_id_and_layout(shard_id, shard_layout)
}
/// Layout-qualified shard identifier: a shard id together with the version of the layout
/// it belongs to.
///
/// Textual form is `s{shard_id}.v{version}` (see the `Display` / `FromStr` impls).
/// Byte form (`to_bytes`) is `version` little-endian followed by `shard_id` little-endian.
// Borsh encodes fields in declaration order and the derived `Ord` compares them in that
// order (version first), so do not reorder these fields.
#[derive(BorshSerialize, BorshDeserialize, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ShardUId {
    pub version: ShardVersion,
    /// Stored as `u32`, narrower than `ShardId`; use `shard_id()` for the widened value.
    pub shard_id: u32,
}
impl ShardUId {
pub fn single_shard() -> Self {
Self { version: 0, shard_id: 0 }
}
pub fn to_bytes(&self) -> [u8; 8] {
let mut res = [0; 8];
res[0..4].copy_from_slice(&u32::to_le_bytes(self.version));
res[4..].copy_from_slice(&u32::to_le_bytes(self.shard_id));
res
}
pub fn next_shard_prefix(shard_uid_bytes: &[u8; 8]) -> [u8; 8] {
let mut result = *shard_uid_bytes;
for i in (0..8).rev() {
if result[i] == u8::MAX {
result[i] = 0;
} else {
result[i] += 1;
return result;
}
}
panic!("Next shard prefix for shard bytes {shard_uid_bytes:?} does not exist");
}
pub fn from_shard_id_and_layout(shard_id: ShardId, shard_layout: &ShardLayout) -> Self {
assert!(shard_layout.shard_ids().any(|i| i == shard_id));
Self { shard_id: shard_id as u32, version: shard_layout.version() }
}
pub fn shard_id(&self) -> ShardId {
ShardId::from(self.shard_id)
}
}
impl TryFrom<&[u8]> for ShardUId {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
if bytes.len() != 8 {
return Err("incorrect length for ShardUId".into());
}
let version = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
let shard_id = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
Ok(Self { version, shard_id })
}
}
/// Builds a 40-byte key: the block hash bytes followed by `shard_uid.to_bytes()`.
pub fn get_block_shard_uid(block_hash: &CryptoHash, shard_uid: &ShardUId) -> Vec<u8> {
    let hash_bytes: &[u8] = block_hash.as_ref();
    [hash_bytes, &shard_uid.to_bytes()[..]].concat()
}
/// Inverse of `get_block_shard_uid`: splits a 40-byte key into the block hash (first 32
/// bytes) and the `ShardUId` (last 8 bytes).
///
/// # Errors
///
/// - Fails with an `InvalidInput` I/O error if `key` is not exactly 40 bytes.
/// - Otherwise propagates decoding errors from `CryptoHash` / `ShardUId`.
// NOTE(review): presumably has no callers in this crate, hence `allow(unused)`; confirm.
#[allow(unused)]
pub fn get_block_shard_uid_rev(
    key: &[u8],
) -> Result<(CryptoHash, ShardUId), Box<dyn std::error::Error + Send + Sync>> {
    if key.len() != 40 {
        let err = std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid key length");
        return Err(err.into());
    }
    let (hash_bytes, uid_bytes) = key.split_at(32);
    let block_hash = CryptoHash::try_from(hash_bytes)?;
    let shard_uid = ShardUId::try_from(uid_bytes)?;
    Ok((block_hash, shard_uid))
}
impl fmt::Display for ShardUId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "s{}.v{}", self.shard_id, self.version)
}
}
impl fmt::Debug for ShardUId {
    /// Debug output is identical to the `Display` form, e.g. `s0.v1`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
impl str::FromStr for ShardUId {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (shard_str, version_str) = s
.split_once(".")
.ok_or_else(|| "shard version and number must be separated by \".\"".to_string())?;
let version = version_str
.strip_prefix("v")
.ok_or_else(|| "shard version must start with \"v\"".to_string())?
.parse::<ShardVersion>()
.map_err(|e| format!("shard version after \"v\" must be a number, {e}"))?;
let shard_str = shard_str
.strip_prefix("s")
.ok_or_else(|| "shard id must start with \"s\"".to_string())?;
let shard_id = shard_str
.parse::<u32>()
.map_err(|e| format!("shard id after \"s\" must be a number, {e}"))?;
Ok(ShardUId { shard_id, version })
}
}
impl<'de> serde::Deserialize<'de> for ShardUId {
    /// Accepts two input shapes via `ShardUIdVisitor`:
    /// - the `s{shard_id}.v{version}` string form;
    /// - a map with `version` and `shard_id` keys.
    ///
    /// Because this uses `deserialize_any`, it requires a self-describing format (e.g. JSON).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        serde::Deserializer::deserialize_any(deserializer, ShardUIdVisitor)
    }
}
impl serde::Serialize for ShardUId {
    /// Serializes as the `Display` string, e.g. `"s0.v1"`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered = self.to_string();
        serializer.serialize_str(rendered.as_str())
    }
}
/// Serde visitor behind `ShardUId`'s `Deserialize` impl. It accepts either the string form
/// parsed by `FromStr` or a map with `version` and `shard_id` entries.
struct ShardUIdVisitor;
impl<'de> serde::de::Visitor<'de> for ShardUIdVisitor {
type Value = ShardUId;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(
formatter,
"either string format of `ShardUId` like s0v1 for shard 0 version 1, or a map"
)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
v.parse().map_err(|e| E::custom(e))
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut version = None;
let mut shard_id = None;
while let Some((field, value)) = map.next_entry()? {
match field {
"version" => version = Some(value),
"shard_id" => shard_id = Some(value),
_ => return Err(serde::de::Error::unknown_field(field, &["version", "shard_id"])),
}
}
match (version, shard_id) {
(None, _) => Err(serde::de::Error::missing_field("version")),
(_, None) => Err(serde::de::Error::missing_field("shard_id")),
(Some(version), Some(shard_id)) => Ok(ShardUId { version, shard_id }),
}
}
}
// Unit tests for shard layouts, account -> shard mapping and layout (de)serialization.
#[cfg(test)]
mod tests {
    use crate::shard_layout::{account_id_to_shard_id, ShardLayout, ShardLayoutV1, ShardUId};
    use rand::distributions::Alphanumeric;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use std::collections::HashMap;
    use unc_primitives_core::types::{AccountId, ShardId};
    use super::{ShardSplitMap, ShardVersion};

    // Older shape of `ShardLayoutV1` that still carried a `fixed_shards` field. Used by
    // `test_remove_fixed_shards` to check that JSON written in that shape still deserializes
    // into the current `ShardLayoutV1`.
    #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
    pub struct OldShardLayoutV1 {
        fixed_shards: Vec<AccountId>,
        boundary_accounts: Vec<AccountId>,
        shards_split_map: Option<ShardSplitMap>,
        to_parent_shard_map: Option<Vec<ShardId>>,
        version: ShardVersion,
    }

    // Hashes 1000 random lowercase alphanumeric account ids into a 4-shard V0 layout.
    // The expected per-shard counts are pinned to the fixed RNG seed and to the V0 hash;
    // they change if either changes.
    #[test]
    fn test_shard_layout_v0() {
        let num_shards = 4;
        let shard_layout = ShardLayout::v0(num_shards, 0);
        let mut shard_id_distribution: HashMap<_, _> =
            shard_layout.shard_ids().map(|shard_id| (shard_id, 0)).collect();
        let mut rng = StdRng::from_seed([0; 32]);
        for _i in 0..1000 {
            let s: Vec<u8> = (&mut rng).sample_iter(&Alphanumeric).take(10).collect();
            let s = String::from_utf8(s).unwrap();
            let account_id = s.to_lowercase().parse().unwrap();
            let shard_id = account_id_to_shard_id(&account_id, &shard_layout);
            assert!(shard_id < num_shards);
            *shard_id_distribution.get_mut(&shard_id).unwrap() += 1;
        }
        let expected_distribution: HashMap<_, _> =
            [(0, 247), (1, 268), (2, 233), (3, 252)].into_iter().collect();
        assert_eq!(shard_id_distribution, expected_distribution);
    }

    // Five boundaries give six shards. Parent 0 split into shards 0..3 and parent 1 into
    // 3..6. The account assertions below imply these shard ranges:
    //   0: < "aurora"
    //   1: ["aurora", "bar")
    //   2: ["bar", "foo")
    //   3: ["foo", "foo.baz")
    //   4: ["foo.baz", "paz")
    //   5: >= "paz"
    #[test]
    fn test_shard_layout_v1() {
        let shard_layout = ShardLayout::v1(
            parse_account_ids(&["aurora", "bar", "foo", "foo.baz", "paz"]),
            Some(vec![vec![0, 1, 2], vec![3, 4, 5]]),
            1,
        );
        assert_eq!(
            shard_layout.get_children_shards_uids(0).unwrap(),
            (0..3).map(|x| ShardUId { version: 1, shard_id: x }).collect::<Vec<_>>()
        );
        assert_eq!(
            shard_layout.get_children_shards_uids(1).unwrap(),
            (3..6).map(|x| ShardUId { version: 1, shard_id: x }).collect::<Vec<_>>()
        );
        for x in 0..3 {
            assert_eq!(shard_layout.get_parent_shard_id(x).unwrap(), 0);
            assert_eq!(shard_layout.get_parent_shard_id(x + 3).unwrap(), 1);
        }
        assert_eq!(account_id_to_shard_id(&"aurora".parse().unwrap(), &shard_layout), 1);
        assert_eq!(account_id_to_shard_id(&"foo.aurora".parse().unwrap(), &shard_layout), 3);
        assert_eq!(account_id_to_shard_id(&"bar.foo.aurora".parse().unwrap(), &shard_layout), 2);
        assert_eq!(account_id_to_shard_id(&"bar".parse().unwrap(), &shard_layout), 2);
        assert_eq!(account_id_to_shard_id(&"bar.bar".parse().unwrap(), &shard_layout), 2);
        assert_eq!(account_id_to_shard_id(&"foo".parse().unwrap(), &shard_layout), 3);
        assert_eq!(account_id_to_shard_id(&"baz.foo".parse().unwrap(), &shard_layout), 2);
        assert_eq!(account_id_to_shard_id(&"foo.baz".parse().unwrap(), &shard_layout), 4);
        assert_eq!(account_id_to_shard_id(&"a.foo.baz".parse().unwrap(), &shard_layout), 0);
        assert_eq!(account_id_to_shard_id(&"aaa".parse().unwrap(), &shard_layout), 0);
        assert_eq!(account_id_to_shard_id(&"abc".parse().unwrap(), &shard_layout), 0);
        assert_eq!(account_id_to_shard_id(&"bbb".parse().unwrap(), &shard_layout), 2);
        assert_eq!(account_id_to_shard_id(&"foo.goo".parse().unwrap(), &shard_layout), 4);
        assert_eq!(account_id_to_shard_id(&"goo".parse().unwrap(), &shard_layout), 4);
        assert_eq!(account_id_to_shard_id(&"zoo".parse().unwrap(), &shard_layout), 5);
    }

    // JSON serialized from `OldShardLayoutV1` (with the extra `fixed_shards` field) must
    // deserialize into `ShardLayoutV1` with every shared field preserved. This relies on
    // serde ignoring unknown fields by default.
    #[test]
    fn test_remove_fixed_shards() {
        let old = OldShardLayoutV1 {
            fixed_shards: vec![],
            boundary_accounts: parse_account_ids(&["aaa", "bbb"]),
            shards_split_map: Some(vec![vec![0, 1, 2]]),
            to_parent_shard_map: Some(vec![0, 0, 0]),
            version: 1,
        };
        let json = serde_json::to_string_pretty(&old).unwrap();
        // Debug output only; shown with `cargo test -- --nocapture`.
        println!("json");
        println!("{json:#?}");
        let new = serde_json::from_str::<ShardLayoutV1>(json.as_str()).unwrap();
        assert_eq!(old.boundary_accounts, new.boundary_accounts);
        assert_eq!(old.shards_split_map, new.shards_split_map);
        assert_eq!(old.to_parent_shard_map, new.to_parent_shard_map);
        assert_eq!(old.version, new.version);
    }

    // Parses string literals into `AccountId`s; panics on an invalid id.
    fn parse_account_ids(ids: &[&str]) -> Vec<AccountId> {
        ids.into_iter().map(|a| a.parse().unwrap()).collect()
    }

    // Pins the JSON shape of a V0 layout (externally tagged enum). The snapshot literal is
    // managed by insta and must not be hand-edited.
    #[test]
    fn test_shard_layout_all() {
        let v0 = ShardLayout::v0(1, 0);
        insta::assert_snapshot!(serde_json::to_string_pretty(&v0).unwrap(), @r###"
{
"V0": {
"num_shards": 1,
"version": 0
}
}
"###);
    }
}