Skip to main content

shardcache_client_rs/
routing.rs

1use std::io::Write;
2use std::net::{SocketAddr, ToSocketAddrs};
3
4use crate::error::{Result, ShardCacheClientError};
5
6/// Router for shard-owned SCNP direct ports.
7///
8/// The base address is the first shard-owned port. Shard `n` is expected at
9/// `base_port + n`; when the server also exposes a fanout port, this base is
10/// usually `SHARDCACHE_DIRECT_SHARD_BASE_PORT` or the fanout port + 1.
11#[derive(Debug, Clone, Copy)]
12pub struct ShardCacheDirectRouter {
13    base_addr: SocketAddr,
14    shard_count: usize,
15    shift: u32,
16    route_mode: ShardCacheRouteMode,
17}
18
19impl ShardCacheDirectRouter {
20    /// Creates a direct router for `shard_count` server shards.
21    pub fn new(addr: impl ToSocketAddrs, shard_count: usize) -> Result<Self> {
22        if shard_count == 0 || !shard_count.is_power_of_two() {
23            return Err(ShardCacheClientError::Config(format!(
24                "SCNP direct shard count must be a non-zero power of two: {shard_count}"
25            )));
26        }
27        let base_addr = resolve_one(addr)?;
28        Ok(Self {
29            base_addr,
30            shard_count,
31            shift: shift_for(shard_count),
32            route_mode: ShardCacheRouteMode::FullKey,
33        })
34    }
35
36    /// Sets how direct shard routing chooses the owning shard.
37    ///
38    /// `FullKey` is the normal point-key mode. `SessionPrefix` routes keys of
39    /// the form `s:<session>:c:<chunk>` by `s:<session>` while preserving the
40    /// full-key hash used for lookup within that shard.
41    pub fn with_route_mode(mut self, route_mode: ShardCacheRouteMode) -> Self {
42        self.route_mode = route_mode;
43        self
44    }
45
46    /// Returns the number of direct shard ports.
47    pub fn shard_count(&self) -> usize {
48        self.shard_count
49    }
50
51    /// Computes the routed SCNP metadata for `key`.
52    pub fn route_key(&self, key: &[u8]) -> ShardCacheRoute {
53        let key_hash = hash_key(key);
54        let route_hash = self.route_mode.route_hash(key, key_hash);
55        ShardCacheRoute {
56            key_hash,
57            key_tag: hash_key_tag_from_hash(key_hash),
58            shard_id: stripe_index(route_hash, self.shift),
59        }
60    }
61
62    /// Returns the socket address for `shard_id`.
63    pub fn shard_addr(&self, shard_id: usize) -> Result<SocketAddr> {
64        if shard_id >= self.shard_count {
65            return Err(ShardCacheClientError::Config(format!(
66                "SCNP direct shard {shard_id} out of range for {} shards",
67                self.shard_count
68            )));
69        }
70        let mut addr = self.base_addr;
71        let offset = u16::try_from(shard_id).map_err(|_| {
72            ShardCacheClientError::Config(format!("SCNP direct shard id exceeds u16: {shard_id}"))
73        })?;
74        let port = self.base_addr.port().checked_add(offset).ok_or_else(|| {
75            ShardCacheClientError::Config(format!(
76                "SCNP direct shard port overflows for shard {shard_id}"
77            ))
78        })?;
79        addr.set_port(port);
80        Ok(addr)
81    }
82}
83
/// Strategy a direct SCNP client uses to pick the owning shard for a key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardCacheRouteMode {
    /// Hash the entire key to choose the shard (the point-key default).
    FullKey,
    /// Hash only the `s:<session>` part of `s:<session>:c:<chunk>` keys, so
    /// all chunks of one session land on the same shard.
    SessionPrefix,
}
92
93impl ShardCacheRouteMode {
94    /// Parses the route mode used by benchmark and deployment knobs.
95    pub fn parse(value: &str) -> Result<Self> {
96        match value {
97            "full_key" | "full-key" | "point" => Ok(Self::FullKey),
98            "session_prefix" | "session-prefix" | "session" => Ok(Self::SessionPrefix),
99            other => Err(ShardCacheClientError::Config(format!(
100                "unknown SCNP direct route mode `{other}`; use full_key or session_prefix"
101            ))),
102        }
103    }
104
105    fn route_hash(self, key: &[u8], key_hash: u64) -> u64 {
106        match self {
107            Self::FullKey => key_hash,
108            Self::SessionPrefix => hash_key(session_route_prefix(key)),
109        }
110    }
111}
112
/// Routing metadata precomputed for a single key on the direct SCNP path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShardCacheRoute {
    /// Primary hash of the full key.
    pub key_hash: u64,
    /// Secondary fingerprint of the key, used by direct full-key lookups.
    pub key_tag: u64,
    /// Index of the shard that owns the key.
    pub shard_id: usize,
}
123
124impl ShardCacheRoute {
125    pub(crate) fn write_to<W: Write>(&self, w: &mut W) -> Result<()> {
126        w.write_all(&self.key_hash.to_le_bytes())?;
127        w.write_all(&(self.shard_id as u32).to_le_bytes())?;
128        w.write_all(&self.key_tag.to_le_bytes())?;
129        Ok(())
130    }
131}
132
133/// Computes shardcache's primary XXH3 key hash.
134pub fn hash_key(key: &[u8]) -> u64 {
135    xxhash_rust::xxh3::xxh3_64(key)
136}
137
138/// Computes shardcache's secondary key fingerprint.
139pub fn hash_key_tag(key: &[u8]) -> u64 {
140    hash_key_tag_from_hash(hash_key(key))
141}
142
/// Derives the secondary key fingerprint from an already-computed primary hash.
///
/// Only the most significant byte of `hash` is kept, so the result is always
/// in `0..=255`.
pub fn hash_key_tag_from_hash(hash: u64) -> u64 {
    let [top, ..] = hash.to_be_bytes();
    u64::from(top)
}
147
148fn session_route_prefix(key: &[u8]) -> &[u8] {
149    if !key.starts_with(b"s:") {
150        return key;
151    }
152
153    if let Some(index) = session_chunk_separator(key) {
154        return &key[..index];
155    }
156
157    key
158}
159
/// Returns the byte offset of the last `:c:` separator in `key`.
///
/// Returns `None` when there is no separator, including for keys shorter than
/// three bytes. Because the search runs from the end, a session id that itself
/// contains `:c:` stays inside the routing prefix.
#[inline(always)]
fn session_chunk_separator(key: &[u8]) -> Option<usize> {
    key.windows(3).rposition(|window| window == b":c:")
}
177
178/// Computes the shard index for `hash` and `shard_count`.
179pub fn shard_index(hash: u64, shard_count: usize) -> Result<usize> {
180    if shard_count == 0 || !shard_count.is_power_of_two() {
181        return Err(ShardCacheClientError::Config(format!(
182            "shard count must be a non-zero power of two: {shard_count}"
183        )));
184    }
185    Ok(stripe_index(hash, shift_for(shard_count)))
186}
187
/// Maps `hash` to a stripe (shard) index using a shift from [`shift_for`].
///
/// The top 7 bits of the hash are skipped and the next `64 - shift` bits
/// select the stripe. The arithmetic is done in `u64` rather than `usize`, so
/// 32- and 64-bit targets pick the same shard for the same hash.
fn stripe_index(hash: u64, shift: u32) -> usize {
    if shift >= u64::BITS {
        // Single shard; shifting a u64 by 64 would overflow.
        0
    } else {
        // The value is below `shard_count` (itself a `usize`), so it fits.
        ((hash << 7) >> shift) as usize
    }
}

/// Returns the right-shift that keeps `log2(shard_count)` bits of a `u64`.
///
/// `shard_count` must be a non-zero power of two.
fn shift_for(shard_count: usize) -> u32 {
    debug_assert!(shard_count > 0 && shard_count.is_power_of_two());
    u64::BITS - shard_count.trailing_zeros()
}
200
201fn resolve_one(addr: impl ToSocketAddrs) -> Result<SocketAddr> {
202    addr.to_socket_addrs()?.next().ok_or_else(|| {
203        ShardCacheClientError::Config("SCNP address resolved to no socket addresses".into())
204    })
205}