Skip to main content

st_protocol/
address.rs

1//! Network addressing for daemon routing
2//!
3//! ## Address Prefix Format
4//!
5//! Single byte prefix for routing:
6//! - `0x00` = local daemon (Unix socket /run/st.sock)
7//! - `0x01-0x7F` = cached host index (up to 127 known hosts)
8//! - `0x80-0xFE` = inline address follows (len = byte - 0x80)
9//! - `0xFF` = broadcast/discover
10
11#[cfg(feature = "std")]
12use std::collections::HashMap;
13
14#[cfg(feature = "std")]
15extern crate std as alloc;
16
17#[cfg(all(feature = "alloc", not(feature = "std")))]
18extern crate alloc;
19
/// Network address for daemon communication.
///
/// Every variant maps onto one range of the single routing prefix byte
/// described in the module docs; `encode` / `encode_to` / `decode` convert
/// between this enum and that wire form.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Address {
    /// Local daemon via Unix socket (prefix `0x00`).
    Local,
    /// Cached host by index. Only `1..=127` fits the `0x01..=0x7F` prefix
    /// range; any other value would be encoded as a different variant.
    Cached(u8),
    /// Inline address string (hostname:port or IP:port), encoded as the
    /// prefix `0x80 + len` followed by the string bytes.
    Inline(AddressString),
    /// Broadcast/discover all daemons (prefix `0xFF`).
    Broadcast,
}

/// Longest inline address the prefix byte can describe (`0xFE - 0x80`).
const MAX_INLINE_LEN: usize = 126;

/// Inline address string (max 126 bytes) held in a fixed-size buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AddressString {
    /// UTF-8 bytes live in `data[..len]`; everything past `len` stays zero,
    /// so the derived `PartialEq` agrees with comparing the strings.
    data: [u8; MAX_INLINE_LEN],
    len: usize,
}

impl AddressString {
    /// Create from string.
    ///
    /// Returns `None` if `s` is longer than 126 bytes (UTF-8 byte length).
    pub fn new(s: &str) -> Option<Self> {
        let bytes = s.as_bytes();
        if bytes.len() > MAX_INLINE_LEN {
            return None;
        }
        let mut data = [0u8; MAX_INLINE_LEN];
        data[..bytes.len()].copy_from_slice(bytes);
        Some(AddressString {
            data,
            len: bytes.len(),
        })
    }

    /// Get as string slice.
    pub fn as_str(&self) -> &str {
        // The bytes always come from a `&str` in `new`, so this cannot fail;
        // the fallback just avoids introducing a panic path.
        core::str::from_utf8(self.as_bytes()).unwrap_or("")
    }

    /// Get as bytes (only the used `len` bytes, no padding).
    pub fn as_bytes(&self) -> &[u8] {
        &self.data[..self.len]
    }

    /// Length in bytes.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}

impl Address {
    /// Number of bytes this address occupies on the wire:
    /// 1 for every variant, plus the string length for `Inline`.
    pub fn encoded_len(&self) -> usize {
        match self {
            Address::Inline(addr) => 1 + addr.len(),
            Address::Local | Address::Cached(_) | Address::Broadcast => 1,
        }
    }

    /// The routing prefix byte for this address.
    fn prefix(&self) -> u8 {
        match self {
            Address::Local => 0x00,
            Address::Cached(idx) => {
                // Outside 1..=127 the byte would decode as Local, Inline or
                // Broadcast instead of the intended cached host.
                debug_assert!(
                    (0x01..=0x7F).contains(idx),
                    "cached host index {} is outside 1..=127",
                    idx
                );
                *idx
            }
            // len <= 126 (enforced by AddressString::new), so this is <= 0xFE.
            Address::Inline(addr) => 0x80 + addr.len() as u8,
            Address::Broadcast => 0xFF,
        }
    }

    /// Encode address as prefix byte(s) into a freshly allocated vector.
    ///
    /// # Panics
    /// In debug builds, panics for `Cached` indices outside `1..=127`.
    #[cfg(any(feature = "std", feature = "alloc"))]
    pub fn encode(&self) -> alloc::vec::Vec<u8> {
        let mut out = alloc::vec![0u8; self.encoded_len()];
        let written = self.encode_to(&mut out);
        debug_assert_eq!(written, out.len());
        out
    }

    /// Encode address as prefix byte(s) into `buf`, returning bytes written
    /// (always `self.encoded_len()`). Available with or without an allocator.
    ///
    /// # Panics
    /// Panics if `buf` is shorter than `self.encoded_len()`. In debug builds,
    /// also panics for `Cached` indices outside `1..=127`.
    pub fn encode_to(&self, buf: &mut [u8]) -> usize {
        let n = self.encoded_len();
        let out = &mut buf[..n];
        out[0] = self.prefix();
        if let Address::Inline(addr) = self {
            out[1..].copy_from_slice(addr.as_bytes());
        }
        n
    }

    /// Decode address from prefix byte(s).
    ///
    /// Returns the address and the number of bytes consumed, or `None` if
    /// `data` is empty, an inline address is truncated, or its bytes are
    /// not valid UTF-8.
    pub fn decode(data: &[u8]) -> Option<(Self, usize)> {
        let (&first, rest) = data.split_first()?;
        let addr = match first {
            0x00 => Address::Local,
            0x01..=0x7F => Address::Cached(first),
            0x80..=0xFE => {
                let len = usize::from(first - 0x80);
                let bytes = rest.get(..len)?;
                let s = core::str::from_utf8(bytes).ok()?;
                Address::Inline(AddressString::new(s)?)
            }
            0xFF => Address::Broadcast,
        };
        let consumed = addr.encoded_len();
        Some((addr, consumed))
    }

    /// Check if this is a local address.
    pub fn is_local(&self) -> bool {
        matches!(self, Address::Local)
    }

    /// Check if this is a remote address, i.e. anything other than `Local`
    /// (note that this includes `Broadcast`).
    pub fn is_remote(&self) -> bool {
        !self.is_local()
    }
}
150
/// Host cache for remembered remote daemons.
///
/// Assigns each `hostname:port` string a one-byte index in `1..=127`, the
/// range used by `Address::Cached` (index 0 is reserved for the local daemon).
#[cfg(feature = "std")]
#[derive(Debug, Clone)]
pub struct HostCache {
    /// Index -> (hostname:port, display_name)
    hosts: HashMap<u8, (String, String)>,
    /// Hostname -> index (reverse lookup)
    by_name: HashMap<String, u8>,
    /// Next never-assigned index; once it exceeds 127, indices freed by
    /// `remove` are reused instead.
    next_index: u8,
}

#[cfg(feature = "std")]
impl Default for HostCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "std")]
impl HostCache {
    /// Highest index that fits the cached-host prefix range (`0x7F`).
    const MAX_INDEX: u8 = 127;

    /// Create empty cache.
    pub fn new() -> Self {
        HostCache {
            hosts: HashMap::new(),
            by_name: HashMap::new(),
            next_index: 1, // 0 is reserved for local
        }
    }

    /// Add a host, or update the display name of a host already cached.
    ///
    /// Returns the host's index (the existing one if `host` is already
    /// cached), or `None` when all 127 indices are in use.
    pub fn add(&mut self, host: &str, name: &str) -> Option<u8> {
        if let Some(&idx) = self.by_name.get(host) {
            if let Some(entry) = self.hosts.get_mut(&idx) {
                entry.1 = name.to_string();
            }
            return Some(idx);
        }

        let idx = self.allocate_index()?;
        self.hosts.insert(idx, (host.to_string(), name.to_string()));
        self.by_name.insert(host.to_string(), idx);
        Some(idx)
    }

    /// Pick an index for a new host: the next never-used index while any
    /// remain, otherwise the lowest index freed by `remove`. Preferring fresh
    /// indices avoids immediately reassigning an index a peer may still hold.
    fn allocate_index(&mut self) -> Option<u8> {
        if self.next_index <= Self::MAX_INDEX {
            let idx = self.next_index;
            self.next_index += 1;
            Some(idx)
        } else {
            (1..=Self::MAX_INDEX).find(|idx| !self.hosts.contains_key(idx))
        }
    }

    /// Lookup by index; returns `(hostname:port, display_name)`.
    pub fn get(&self, idx: u8) -> Option<&(String, String)> {
        self.hosts.get(&idx)
    }

    /// Lookup index by hostname (`hostname:port` string).
    pub fn get_by_name(&self, host: &str) -> Option<u8> {
        self.by_name.get(host).copied()
    }

    /// Remove a host by index; does nothing if the index is not in use.
    pub fn remove(&mut self, idx: u8) {
        if let Some((host, _)) = self.hosts.remove(&idx) {
            self.by_name.remove(&host);
        }
    }

    /// List all hosts as `(index, hostname:port, display_name)`, in
    /// arbitrary (hash-map) order.
    pub fn list(&self) -> impl Iterator<Item = (u8, &str, &str)> {
        self.hosts
            .iter()
            .map(|(&idx, (host, name))| (idx, host.as_str(), name.as_str()))
    }

    /// Number of cached hosts.
    pub fn len(&self) -> usize {
        self.hosts.len()
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.hosts.is_empty()
    }

    /// Resolve address to a connection string.
    ///
    /// `Local` yields `"local"`, `Cached` yields the cached `hostname:port`
    /// (or `None` if the index is unknown), `Inline` yields its string, and
    /// `Broadcast` yields `None`.
    pub fn resolve(&self, addr: &Address) -> Option<String> {
        match addr {
            Address::Local => Some("local".to_string()),
            Address::Cached(idx) => self.get(*idx).map(|(host, _)| host.clone()),
            Address::Inline(s) => Some(s.as_str().to_string()),
            Address::Broadcast => None, // Cannot resolve broadcast
        }
    }
}
239            Address::Inline(s) => Some(s.as_str().to_string()),
240            Address::Broadcast => None, // Cannot resolve broadcast
241        }
242    }
243}
244
245#[cfg(test)]
246mod tests {
247    use super::*;
248
249    #[test]
250    fn test_local_address() {
251        let addr = Address::Local;
252        let encoded = addr.encode();
253        assert_eq!(encoded, vec![0x00]);
254
255        let (decoded, len) = Address::decode(&encoded).unwrap();
256        assert_eq!(decoded, Address::Local);
257        assert_eq!(len, 1);
258    }
259
260    #[test]
261    fn test_cached_address() {
262        let addr = Address::Cached(5);
263        let encoded = addr.encode();
264        assert_eq!(encoded, vec![0x05]);
265
266        let (decoded, len) = Address::decode(&encoded).unwrap();
267        assert_eq!(decoded, Address::Cached(5));
268        assert_eq!(len, 1);
269    }
270
271    #[test]
272    fn test_inline_address() {
273        let addr = Address::Inline(AddressString::new("192.168.1.5:28428").unwrap());
274        let encoded = addr.encode();
275
276        // First byte: 0x80 + 16 = 0x90
277        assert_eq!(encoded[0], 0x90);
278
279        let (decoded, len) = Address::decode(&encoded).unwrap();
280        if let Address::Inline(s) = decoded {
281            assert_eq!(s.as_str(), "192.168.1.5:28428");
282        } else {
283            panic!("expected inline address");
284        }
285        assert_eq!(len, 17);
286    }
287
288    #[test]
289    fn test_broadcast() {
290        let addr = Address::Broadcast;
291        let encoded = addr.encode();
292        assert_eq!(encoded, vec![0xFF]);
293    }
294
295    #[cfg(feature = "std")]
296    #[test]
297    fn test_host_cache() {
298        let mut cache = HostCache::new();
299
300        let idx1 = cache.add("server1.local:28428", "Server 1").unwrap();
301        let idx2 = cache.add("server2.local:28428", "Server 2").unwrap();
302
303        assert_eq!(idx1, 1);
304        assert_eq!(idx2, 2);
305
306        // Duplicate returns same index
307        let idx1_again = cache.add("server1.local:28428", "Server 1").unwrap();
308        assert_eq!(idx1_again, idx1);
309
310        // Lookup
311        assert_eq!(cache.get_by_name("server1.local:28428"), Some(1));
312        assert_eq!(cache.get(1).unwrap().0, "server1.local:28428");
313    }
314}