1use crate::hash::CryptoHash;
2use crate::types::{AccountId, NumShards};
3use borsh::{BorshDeserialize, BorshSerialize};
4use std::collections::HashMap;
5use std::{fmt, str};
6use unc_primitives_core::types::ShardId;
7
/// Version number of a [`ShardLayout`]; also carried by every [`ShardUId`] built from it.
pub type ShardVersion = u32;
49
/// A versioned description of how accounts are assigned to shards.
///
/// See [`account_id_to_shard_id`] for the exact assignment rule of each version.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
pub enum ShardLayout {
    /// Fixed number of shards; an account's shard is derived from a hash of its id
    /// modulo the number of shards.
    V0(ShardLayoutV0),
    /// Shards delimited by boundary accounts, optionally with parent/child split info.
    V1(ShardLayoutV1),
}
55
/// Hash-based shard layout with a fixed number of shards.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct ShardLayoutV0 {
    /// Number of shards; valid shard ids are `0..num_shards`.
    num_shards: NumShards,
    /// Version of this layout, copied into every `ShardUId` built from it.
    version: ShardVersion,
}
68
/// Split information: entry `i` lists the shard ids that parent shard `i` maps to.
type ShardSplitMap = Vec<Vec<ShardId>>;
74
/// Boundary-account based shard layout.
///
/// No `deny_unknown_fields` is set, so serde ignores unknown fields when deserializing;
/// `test_remove_fixed_shards` relies on this to load JSON that still has `fixed_shards`.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct ShardLayoutV1 {
    /// Accounts separating consecutive shards. The layout has `boundary_accounts.len() + 1`
    /// shards, and an account's shard is the number of leading boundaries that are `<=` it.
    /// NOTE(review): that only yields contiguous ranges if this list is sorted, which is
    /// not validated in this file.
    boundary_accounts: Vec<AccountId>,
    /// Entry `i` lists the shard ids of this layout whose parent is shard `i`;
    /// `None` when the layout carries no split information.
    shards_split_map: Option<ShardSplitMap>,
    /// Inverse of `shards_split_map`, indexed by shard id of this layout and holding the
    /// parent shard id; computed by `ShardLayout::v1`.
    to_parent_shard_map: Option<Vec<ShardId>>,
    /// Version of this layout, copied into every `ShardUId` built from it.
    version: ShardVersion,
}
92
93#[derive(Debug)]
94pub enum ShardLayoutError {
95 InvalidShardIdError { shard_id: ShardId },
96}
97
98impl ShardLayout {
99 pub fn v0_single_shard() -> Self {
101 Self::v0(1, 0)
102 }
103
104 pub fn v0(num_shards: NumShards, version: ShardVersion) -> Self {
106 Self::V0(ShardLayoutV0 { num_shards, version })
107 }
108
109 pub fn v1(
111 boundary_accounts: Vec<AccountId>,
112 shards_split_map: Option<ShardSplitMap>,
113 version: ShardVersion,
114 ) -> Self {
115 let to_parent_shard_map = if let Some(shards_split_map) = &shards_split_map {
116 let mut to_parent_shard_map = HashMap::new();
117 let num_shards = (boundary_accounts.len() + 1) as NumShards;
118 for (parent_shard_id, shard_ids) in shards_split_map.iter().enumerate() {
119 for &shard_id in shard_ids {
120 let prev = to_parent_shard_map.insert(shard_id, parent_shard_id as ShardId);
121 assert!(prev.is_none(), "no shard should appear in the map twice");
122 assert!(shard_id < num_shards, "shard id should be valid");
123 }
124 }
125 Some((0..num_shards).map(|shard_id| to_parent_shard_map[&shard_id]).collect())
126 } else {
127 None
128 };
129 Self::V1(ShardLayoutV1 {
130 boundary_accounts,
131 shards_split_map,
132 to_parent_shard_map,
133 version,
134 })
135 }
136
137 pub fn v0_test() -> Self {
139 ShardLayout::v0(1, 0)
140 }
141
142 pub fn get_children_shards_uids(&self, parent_shard_id: ShardId) -> Option<Vec<ShardUId>> {
145 self.get_children_shards_ids(parent_shard_id).map(|shards| {
146 shards.into_iter().map(|id| ShardUId::from_shard_id_and_layout(id, self)).collect()
147 })
148 }
149
150 pub fn get_children_shards_ids(&self, parent_shard_id: ShardId) -> Option<Vec<ShardId>> {
153 match self {
154 Self::V0(_) => None,
155 Self::V1(v1) => match &v1.shards_split_map {
156 Some(shards_split_map) => shards_split_map.get(parent_shard_id as usize).cloned(),
157 None => None,
158 },
159 }
160 }
161
162 pub fn get_parent_shard_id(&self, shard_id: ShardId) -> Result<ShardId, ShardLayoutError> {
167 if !self.shard_ids().any(|id| id == shard_id) {
168 return Err(ShardLayoutError::InvalidShardIdError { shard_id });
169 }
170 let parent_shard_id = match self {
171 Self::V0(_) => panic!("shard layout has no parent shard"),
172 Self::V1(v1) => match &v1.to_parent_shard_map {
173 Some(to_parent_shard_map) => *to_parent_shard_map.get(shard_id as usize).unwrap(),
176 None => panic!("shard_layout has no parent shard"),
177 },
178 };
179 Ok(parent_shard_id)
180 }
181
182 #[inline]
183 pub fn version(&self) -> ShardVersion {
184 match self {
185 Self::V0(v0) => v0.version,
186 Self::V1(v1) => v1.version,
187 }
188 }
189
190 fn num_shards(&self) -> NumShards {
191 match self {
192 Self::V0(v0) => v0.num_shards,
193 Self::V1(v1) => (v1.boundary_accounts.len() + 1) as NumShards,
194 }
195 }
196
197 pub fn shard_ids(&self) -> impl Iterator<Item = ShardId> {
198 0..self.num_shards()
199 }
200
201 pub fn shard_uids(&self) -> impl Iterator<Item = ShardUId> + '_ {
204 self.shard_ids().map(|shard_id| ShardUId::from_shard_id_and_layout(shard_id, self))
205 }
206}
207
208pub fn account_id_to_shard_id(account_id: &AccountId, shard_layout: &ShardLayout) -> ShardId {
212 match shard_layout {
213 ShardLayout::V0(ShardLayoutV0 { num_shards, .. }) => {
214 let hash = CryptoHash::hash_bytes(account_id.as_bytes());
215 let (bytes, _) = stdx::split_array::<32, 8, 24>(hash.as_bytes());
216 u64::from_le_bytes(*bytes) % num_shards
217 }
218 ShardLayout::V1(ShardLayoutV1 { boundary_accounts, .. }) => {
219 let mut shard_id: ShardId = 0;
223 for boundary_account in boundary_accounts {
224 if account_id < boundary_account {
225 break;
226 }
227 shard_id += 1;
228 }
229 shard_id
230 }
231 }
232}
233
234pub fn account_id_to_shard_uid(account_id: &AccountId, shard_layout: &ShardLayout) -> ShardUId {
236 ShardUId::from_shard_id_and_layout(
237 account_id_to_shard_id(account_id, shard_layout),
238 shard_layout,
239 )
240}
241
/// Versioned shard identifier: a shard id together with the version of the
/// [`ShardLayout`] it belongs to.
///
/// Field order is significant. The derived Borsh encoding and the derived `Ord`
/// (version first, then shard id) both follow declaration order. The string form used by
/// `Display`/`FromStr`/serde is `s{shard_id}.v{version}`.
#[derive(BorshSerialize, BorshDeserialize, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ShardUId {
    /// Version of the shard layout this shard belongs to.
    pub version: ShardVersion,
    /// Shard id within that layout, stored as `u32`.
    pub shard_id: u32,
}
248
249impl ShardUId {
250 pub fn single_shard() -> Self {
251 Self { version: 0, shard_id: 0 }
252 }
253
254 pub fn to_bytes(&self) -> [u8; 8] {
256 let mut res = [0; 8];
257 res[0..4].copy_from_slice(&u32::to_le_bytes(self.version));
258 res[4..].copy_from_slice(&u32::to_le_bytes(self.shard_id));
259 res
260 }
261
262 pub fn next_shard_prefix(shard_uid_bytes: &[u8; 8]) -> [u8; 8] {
263 let mut result = *shard_uid_bytes;
264 for i in (0..8).rev() {
265 if result[i] == u8::MAX {
266 result[i] = 0;
267 } else {
268 result[i] += 1;
269 return result;
270 }
271 }
272 panic!("Next shard prefix for shard bytes {shard_uid_bytes:?} does not exist");
273 }
274
275 pub fn from_shard_id_and_layout(shard_id: ShardId, shard_layout: &ShardLayout) -> Self {
277 assert!(shard_layout.shard_ids().any(|i| i == shard_id));
278 Self { shard_id: shard_id as u32, version: shard_layout.version() }
279 }
280
281 pub fn shard_id(&self) -> ShardId {
283 ShardId::from(self.shard_id)
284 }
285}
286
287impl TryFrom<&[u8]> for ShardUId {
288 type Error = Box<dyn std::error::Error + Send + Sync>;
289
290 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
292 if bytes.len() != 8 {
293 return Err("incorrect length for ShardUId".into());
294 }
295 let version = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
296 let shard_id = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
297 Ok(Self { version, shard_id })
298 }
299}
300
301pub fn get_block_shard_uid(block_hash: &CryptoHash, shard_uid: &ShardUId) -> Vec<u8> {
303 let mut res = Vec::with_capacity(40);
304 res.extend_from_slice(block_hash.as_ref());
305 res.extend_from_slice(&shard_uid.to_bytes());
306 res
307}
308
309#[allow(unused)]
311pub fn get_block_shard_uid_rev(
312 key: &[u8],
313) -> Result<(CryptoHash, ShardUId), Box<dyn std::error::Error + Send + Sync>> {
314 if key.len() != 40 {
315 return Err(
316 std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid key length").into()
317 );
318 }
319 let block_hash = CryptoHash::try_from(&key[..32])?;
320 let shard_id = ShardUId::try_from(&key[32..])?;
321 Ok((block_hash, shard_id))
322}
323
324impl fmt::Display for ShardUId {
325 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326 write!(f, "s{}.v{}", self.shard_id, self.version)
327 }
328}
329
330impl fmt::Debug for ShardUId {
331 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
332 fmt::Display::fmt(self, f)
333 }
334}
335
336impl str::FromStr for ShardUId {
337 type Err = String;
338
339 fn from_str(s: &str) -> Result<Self, Self::Err> {
340 let (shard_str, version_str) = s
341 .split_once(".")
342 .ok_or_else(|| "shard version and number must be separated by \".\"".to_string())?;
343
344 let version = version_str
345 .strip_prefix("v")
346 .ok_or_else(|| "shard version must start with \"v\"".to_string())?
347 .parse::<ShardVersion>()
348 .map_err(|e| format!("shard version after \"v\" must be a number, {e}"))?;
349
350 let shard_str = shard_str
351 .strip_prefix("s")
352 .ok_or_else(|| "shard id must start with \"s\"".to_string())?;
353 let shard_id = shard_str
354 .parse::<u32>()
355 .map_err(|e| format!("shard id after \"s\" must be a number, {e}"))?;
356
357 Ok(ShardUId { shard_id, version })
358 }
359}
360
361impl<'de> serde::Deserialize<'de> for ShardUId {
362 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
363 where
364 D: serde::Deserializer<'de>,
365 {
366 deserializer.deserialize_any(ShardUIdVisitor)
367 }
368}
369
370impl serde::Serialize for ShardUId {
371 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
372 where
373 S: serde::Serializer,
374 {
375 serializer.serialize_str(&self.to_string())
376 }
377}
378
379struct ShardUIdVisitor;
380impl<'de> serde::de::Visitor<'de> for ShardUIdVisitor {
381 type Value = ShardUId;
382
383 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
384 write!(
385 formatter,
386 "either string format of `ShardUId` like s0v1 for shard 0 version 1, or a map"
387 )
388 }
389
390 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
391 where
392 E: serde::de::Error,
393 {
394 v.parse().map_err(|e| E::custom(e))
395 }
396
397 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
398 where
399 A: serde::de::MapAccess<'de>,
400 {
401 let mut version = None;
405 let mut shard_id = None;
406
407 while let Some((field, value)) = map.next_entry()? {
408 match field {
409 "version" => version = Some(value),
410 "shard_id" => shard_id = Some(value),
411 _ => return Err(serde::de::Error::unknown_field(field, &["version", "shard_id"])),
412 }
413 }
414
415 match (version, shard_id) {
416 (None, _) => Err(serde::de::Error::missing_field("version")),
417 (_, None) => Err(serde::de::Error::missing_field("shard_id")),
418 (Some(version), Some(shard_id)) => Ok(ShardUId { version, shard_id }),
419 }
420 }
421}
422
#[cfg(test)]
mod tests {
    use crate::shard_layout::{account_id_to_shard_id, ShardLayout, ShardLayoutV1, ShardUId};
    use rand::distributions::Alphanumeric;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use std::collections::HashMap;
    use unc_primitives_core::types::{AccountId, ShardId};

    use super::{ShardLayoutError, ShardSplitMap, ShardVersion};

    /// Older shape of `ShardLayoutV1` that still carried a `fixed_shards` field.
    /// `test_remove_fixed_shards` checks that its JSON parses as the current struct.
    #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
    pub struct OldShardLayoutV1 {
        fixed_shards: Vec<AccountId>,
        boundary_accounts: Vec<AccountId>,
        shards_split_map: Option<ShardSplitMap>,
        to_parent_shard_map: Option<Vec<ShardId>>,
        version: ShardVersion,
    }

    #[test]
    fn test_shard_layout_v0() {
        let num_shards = 4;
        let shard_layout = ShardLayout::v0(num_shards, 0);
        let mut shard_id_distribution: HashMap<_, _> =
            shard_layout.shard_ids().map(|shard_id| (shard_id, 0)).collect();
        let mut rng = StdRng::from_seed([0; 32]);
        for _i in 0..1000 {
            let s: Vec<u8> = (&mut rng).sample_iter(&Alphanumeric).take(10).collect();
            let s = String::from_utf8(s).unwrap();
            let account_id = s.to_lowercase().parse().unwrap();
            let shard_id = account_id_to_shard_id(&account_id, &shard_layout);
            assert!(shard_id < num_shards);
            *shard_id_distribution.get_mut(&shard_id).unwrap() += 1;
        }
        let expected_distribution: HashMap<_, _> =
            [(0, 247), (1, 268), (2, 233), (3, 252)].into_iter().collect();
        assert_eq!(shard_id_distribution, expected_distribution);
    }

    #[test]
    fn test_shard_layout_v1() {
        let shard_layout = ShardLayout::v1(
            parse_account_ids(&["aurora", "bar", "foo", "foo.baz", "paz"]),
            Some(vec![vec![0, 1, 2], vec![3, 4, 5]]),
            1,
        );
        assert_eq!(
            shard_layout.get_children_shards_uids(0).unwrap(),
            (0..3).map(|x| ShardUId { version: 1, shard_id: x }).collect::<Vec<_>>()
        );
        assert_eq!(
            shard_layout.get_children_shards_uids(1).unwrap(),
            (3..6).map(|x| ShardUId { version: 1, shard_id: x }).collect::<Vec<_>>()
        );
        for x in 0..3 {
            assert_eq!(shard_layout.get_parent_shard_id(x).unwrap(), 0);
            assert_eq!(shard_layout.get_parent_shard_id(x + 3).unwrap(), 1);
        }

        assert_eq!(account_id_to_shard_id(&"aurora".parse().unwrap(), &shard_layout), 1);
        assert_eq!(account_id_to_shard_id(&"foo.aurora".parse().unwrap(), &shard_layout), 3);
        assert_eq!(account_id_to_shard_id(&"bar.foo.aurora".parse().unwrap(), &shard_layout), 2);
        assert_eq!(account_id_to_shard_id(&"bar".parse().unwrap(), &shard_layout), 2);
        assert_eq!(account_id_to_shard_id(&"bar.bar".parse().unwrap(), &shard_layout), 2);
        assert_eq!(account_id_to_shard_id(&"foo".parse().unwrap(), &shard_layout), 3);
        assert_eq!(account_id_to_shard_id(&"baz.foo".parse().unwrap(), &shard_layout), 2);
        assert_eq!(account_id_to_shard_id(&"foo.baz".parse().unwrap(), &shard_layout), 4);
        assert_eq!(account_id_to_shard_id(&"a.foo.baz".parse().unwrap(), &shard_layout), 0);

        assert_eq!(account_id_to_shard_id(&"aaa".parse().unwrap(), &shard_layout), 0);
        assert_eq!(account_id_to_shard_id(&"abc".parse().unwrap(), &shard_layout), 0);
        assert_eq!(account_id_to_shard_id(&"bbb".parse().unwrap(), &shard_layout), 2);
        assert_eq!(account_id_to_shard_id(&"foo.goo".parse().unwrap(), &shard_layout), 4);
        assert_eq!(account_id_to_shard_id(&"goo".parse().unwrap(), &shard_layout), 4);
        assert_eq!(account_id_to_shard_id(&"zoo".parse().unwrap(), &shard_layout), 5);
    }

    #[test]
    fn test_shard_layout_invalid_and_missing_ids() {
        let shard_layout =
            ShardLayout::v1(parse_account_ids(&["aaa", "bbb"]), Some(vec![vec![0, 1, 2]]), 1);
        assert_eq!(shard_layout.get_parent_shard_id(2).unwrap(), 0);
        assert!(matches!(
            shard_layout.get_parent_shard_id(3),
            Err(ShardLayoutError::InvalidShardIdError { shard_id: 3 })
        ));
        assert_eq!(shard_layout.get_children_shards_ids(1), None);
        assert_eq!(ShardLayout::v0(4, 0).get_children_shards_ids(0), None);
    }

    /// JSON written with the legacy `fixed_shards` field must still load as `ShardLayoutV1`.
    #[test]
    fn test_remove_fixed_shards() {
        let old = OldShardLayoutV1 {
            fixed_shards: vec![],
            boundary_accounts: parse_account_ids(&["aaa", "bbb"]),
            shards_split_map: Some(vec![vec![0, 1, 2]]),
            to_parent_shard_map: Some(vec![0, 0, 0]),
            version: 1,
        };
        let json = serde_json::to_string_pretty(&old).unwrap();

        let new = serde_json::from_str::<ShardLayoutV1>(json.as_str()).unwrap();
        assert_eq!(old.boundary_accounts, new.boundary_accounts);
        assert_eq!(old.shards_split_map, new.shards_split_map);
        assert_eq!(old.to_parent_shard_map, new.to_parent_shard_map);
        assert_eq!(old.version, new.version);
    }

    #[test]
    fn test_shard_uid_string_roundtrip() {
        let shard_uid = ShardUId { version: 3, shard_id: 7 };
        let s = shard_uid.to_string();
        assert_eq!(s, "s7.v3");
        assert_eq!(format!("{shard_uid:?}"), "s7.v3");
        assert_eq!(s.parse::<ShardUId>().unwrap(), shard_uid);
        assert!("s7v3".parse::<ShardUId>().is_err());
        assert!("7.v3".parse::<ShardUId>().is_err());
        assert!("s7.3".parse::<ShardUId>().is_err());
    }

    #[test]
    fn test_shard_uid_bytes_roundtrip() {
        let shard_uid = ShardUId { version: 1, shard_id: 258 };
        let bytes = shard_uid.to_bytes();
        assert_eq!(bytes, [1, 0, 0, 0, 2, 1, 0, 0]);
        assert_eq!(ShardUId::try_from(&bytes[..]).unwrap(), shard_uid);
        assert!(ShardUId::try_from(&bytes[..7]).is_err());
    }

    #[test]
    fn test_next_shard_prefix() {
        assert_eq!(ShardUId::next_shard_prefix(&[0; 8]), [0, 0, 0, 0, 0, 0, 0, 1]);
        assert_eq!(
            ShardUId::next_shard_prefix(&[0, 0, 0, 0, 0, 0, 1, 255]),
            [0, 0, 0, 0, 0, 0, 2, 0]
        );
    }

    fn parse_account_ids(ids: &[&str]) -> Vec<AccountId> {
        ids.iter().map(|a| a.parse().unwrap()).collect()
    }

    #[test]
    fn test_shard_layout_all() {
        let v0 = ShardLayout::v0(1, 0);

        insta::assert_snapshot!(serde_json::to_string_pretty(&v0).unwrap(), @r###"
        {
          "V0": {
            "num_shards": 1,
            "version": 0
          }
        }
        "###);
    }
}