
yscv_model/tcp_transport.rs

//! TCP-based transport for multi-node gradient exchange.
//!
//! One node acts as coordinator (server), others connect as workers.
//! Protocol: simple length-prefixed f32 array exchange.

use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::{Arc, Mutex};

use crate::ModelError;

/// TCP-based transport for multi-node gradient exchange.
///
/// One node acts as **coordinator** (rank 0, TCP server) while all other
/// nodes act as **workers** (TCP clients). The coordinator calls
/// [`TcpTransport::coordinator`] which binds to a socket and blocks until
/// `world_size - 1` workers have connected. Each worker calls
/// [`TcpTransport::worker`] with the coordinator's address and its own rank.
///
/// # Wire protocol
///
/// Every message is a length-prefixed `f32` array:
///
/// 1. 4-byte little-endian `u32` element count.
/// 2. `count * 4` bytes of little-endian `f32` values.
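///
/// For example, sending the single value `1.0_f32` puts 8 bytes on the wire:
/// `01 00 00 00` (count = 1 as a LE `u32`) followed by `00 00 80 3F`
/// (`1.0` encoded as a LE IEEE-754 `f32`).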
///
/// # Rank coordination
///
/// Workers announce their rank as a 4-byte LE `u32` immediately after the
/// TCP handshake so the coordinator can place each connection in the correct
/// slot.
///
/// # Usage
///
/// For single-machine testing use [`loopback_pair`] which creates a
/// coordinator + worker pair over `127.0.0.1` on a random port.
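///
/// # Example
///
/// A minimal two-node sketch; the addresses are placeholders and each call
/// runs in its own process:
///
/// ```ignore
/// // Process 0 (coordinator, rank 0): bind and block until one worker connects.
/// let coord = TcpTransport::coordinator("0.0.0.0:9000", 2)?;
///
/// // Process 1 (worker, rank 1): connect to the coordinator.
/// let worker = TcpTransport::worker("coordinator-host:9000", 1)?;
///
/// // Both sides can now call send/recv or allreduce_sum.
/// ```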
pub struct TcpTransport {
    #[allow(dead_code)]
    role: NodeRole,
    peers: Vec<Arc<Mutex<TcpStream>>>,
    rank: usize,
    world_size: usize,
}

/// Describes whether this node is the coordinator (rank 0) or a worker.
///
/// The coordinator binds a TCP listener and waits for workers to connect.
/// Workers initiate connections to the coordinator. After the initial
/// handshake every peer can send and receive gradient data symmetrically.
#[derive(Debug, Clone)]
pub enum NodeRole {
    /// The coordinator listens on this address for incoming worker connections.
    /// Always corresponds to rank 0.
    Coordinator { bind_addr: String },
    /// A worker connects to the coordinator at this address.
    /// Rank must be in `1..world_size`.
    Worker { coordinator_addr: String },
}

impl TcpTransport {
    /// Create a coordinator node that listens for worker connections.
    ///
    /// Blocks until `world_size - 1` workers have connected.
    /// The coordinator is always rank 0.
    pub fn coordinator(bind_addr: &str, world_size: usize) -> Result<Self, ModelError> {
        if world_size == 0 {
            return Err(ModelError::TransportError(
                "world_size must be > 0".to_string(),
            ));
        }

        let listener = TcpListener::bind(bind_addr).map_err(|e| {
            ModelError::TransportError(format!("coordinator failed to bind {bind_addr}: {e}"))
        })?;

        // Connections to workers: after the accept loop, peers[i] holds the
        // stream for worker rank i + 1 (the coordinator itself has no entry).
        let mut peers = Vec::with_capacity(world_size - 1);

        // Pre-allocate Option slots indexed by rank so workers may connect in
        // any order; each worker announces its rank as a u32 LE upon connecting.
        let mut peer_slots: Vec<Option<Arc<Mutex<TcpStream>>>> =
            (0..world_size).map(|_| None).collect();

        for _ in 1..world_size {
            let (mut stream, _addr) = listener.accept().map_err(|e| {
                ModelError::TransportError(format!("coordinator accept failed: {e}"))
            })?;

            stream
                .set_nodelay(true)
                .map_err(|e| ModelError::TransportError(format!("set_nodelay failed: {e}")))?;

            // Read the worker's rank.
            let mut rank_buf = [0u8; 4];
            stream.read_exact(&mut rank_buf).map_err(|e| {
                ModelError::TransportError(format!("failed to read worker rank: {e}"))
            })?;
            let worker_rank = u32::from_le_bytes(rank_buf) as usize;

            if worker_rank == 0 || worker_rank >= world_size {
                return Err(ModelError::TransportError(format!(
                    "invalid worker rank {worker_rank}"
                )));
            }
            if peer_slots[worker_rank].is_some() {
                return Err(ModelError::TransportError(format!(
                    "duplicate worker rank {worker_rank}"
                )));
            }

            peer_slots[worker_rank] = Some(Arc::new(Mutex::new(stream)));
        }

        // Compact the slots into the peers vec: peers[i] holds the connection
        // to worker rank i + 1.
        for slot in peer_slots.iter_mut() {
            if let Some(s) = slot.take() {
                peers.push(s);
            }
        }

        Ok(Self {
            role: NodeRole::Coordinator {
                bind_addr: bind_addr.to_string(),
            },
            peers,
            rank: 0,
            world_size,
        })
    }

    /// Create a worker node that connects to the coordinator.
    pub fn worker(coordinator_addr: &str, rank: usize) -> Result<Self, ModelError> {
        if rank == 0 {
            return Err(ModelError::TransportError(
                "worker rank must be > 0; use coordinator() for rank 0".to_string(),
            ));
        }

        let mut stream = TcpStream::connect(coordinator_addr).map_err(|e| {
            ModelError::TransportError(format!(
                "worker rank {rank} failed to connect to {coordinator_addr}: {e}"
            ))
        })?;

        stream
            .set_nodelay(true)
            .map_err(|e| ModelError::TransportError(format!("set_nodelay failed: {e}")))?;

        // Announce our rank.
        stream
            .write_all(&(rank as u32).to_le_bytes())
            .map_err(|e| ModelError::TransportError(format!("failed to send rank: {e}")))?;

        let peers = vec![Arc::new(Mutex::new(stream))];

        Ok(Self {
            role: NodeRole::Worker {
                coordinator_addr: coordinator_addr.to_string(),
            },
            peers,
            rank,
            // The protocol does not transmit world_size yet; 0 is a sentinel
            // that makes world_size() fall back to peers.len() + 1.
            world_size: 0,
        })
    }

    /// Send a tensor's data to a specific peer.
    ///
    /// `peer` is an index into the peers list (0-based).
    pub fn send(&self, peer: usize, data: &[f32]) -> Result<(), ModelError> {
        if peer >= self.peers.len() {
            return Err(ModelError::TransportError(format!(
                "peer index {peer} out of range (have {} peers)",
                self.peers.len()
            )));
        }
        let stream = &self.peers[peer];
        let mut stream = stream
            .lock()
            .map_err(|_| ModelError::TransportError("lock poisoned".into()))?;
        let len = data.len() as u32;
        stream
            .write_all(&len.to_le_bytes())
            .map_err(|e| ModelError::TransportError(e.to_string()))?;
        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
        stream
            .write_all(&bytes)
            .map_err(|e| ModelError::TransportError(e.to_string()))?;
        Ok(())
    }

    /// Receive tensor data from a specific peer.
    ///
    /// `peer` is an index into the peers list (0-based).
    pub fn recv(&self, peer: usize) -> Result<Vec<f32>, ModelError> {
        if peer >= self.peers.len() {
            return Err(ModelError::TransportError(format!(
                "peer index {peer} out of range (have {} peers)",
                self.peers.len()
            )));
        }
        let stream = &self.peers[peer];
        let mut stream = stream
            .lock()
            .map_err(|_| ModelError::TransportError("lock poisoned".into()))?;
        let mut len_buf = [0u8; 4];
        stream
            .read_exact(&mut len_buf)
            .map_err(|e| ModelError::TransportError(e.to_string()))?;
        let len = u32::from_le_bytes(len_buf) as usize;
        let mut buf = vec![0u8; len * 4];
        stream
            .read_exact(&mut buf)
            .map_err(|e| ModelError::TransportError(e.to_string()))?;
        let data: Vec<f32> = buf
            .chunks(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        Ok(data)
    }

    /// All-reduce: sum gradients across all nodes.
    ///
    /// Implements a simple reduce-then-broadcast pattern: the coordinator
    /// collects data from all workers, computes the element-wise sum, and
    /// broadcasts the result back.
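    ///
    /// # Example
    ///
    /// A sketch using [`loopback_pair`]; the two calls run on separate
    /// threads because each side blocks until its peer participates:
    ///
    /// ```ignore
    /// let (coord, worker) = loopback_pair()?;
    /// let handle = std::thread::spawn(move || {
    ///     let mut data = vec![4.0_f32, 5.0, 6.0];
    ///     worker.allreduce_sum(&mut data).unwrap();
    ///     data
    /// });
    /// let mut data = vec![1.0_f32, 2.0, 3.0];
    /// coord.allreduce_sum(&mut data)?;
    /// assert_eq!(data, vec![5.0, 7.0, 9.0]);    // coordinator's view
    /// assert_eq!(handle.join().unwrap(), data); // worker sees the same sum
    /// ```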
    pub fn allreduce_sum(&self, data: &mut [f32]) -> Result<(), ModelError> {
        let n = self.peers.len() + 1; // total nodes = peers + self
        if n <= 1 {
            return Ok(());
        }

        if self.rank == 0 {
            // Coordinator: receive from each worker, sum, broadcast result.
            let mut acc: Vec<f32> = data.to_vec();
            for peer_idx in 0..self.peers.len() {
                let remote = self.recv(peer_idx)?;
                if remote.len() != acc.len() {
                    return Err(ModelError::TransportError(format!(
                        "allreduce length mismatch: expected {}, got {}",
                        acc.len(),
                        remote.len()
                    )));
                }
                for (a, r) in acc.iter_mut().zip(remote.iter()) {
                    *a += r;
                }
            }
            // Broadcast summed result back to all workers.
            for peer_idx in 0..self.peers.len() {
                self.send(peer_idx, &acc)?;
            }
            // Update local data in place.
            data.copy_from_slice(&acc);
        } else {
            // Worker: send local data to coordinator, receive summed result.
            self.send(0, data)?;
            let result = self.recv(0)?;
            if result.len() != data.len() {
                return Err(ModelError::TransportError(format!(
                    "allreduce length mismatch: expected {}, got {}",
                    data.len(),
                    result.len()
                )));
            }
            data.copy_from_slice(&result);
        }

        Ok(())
    }

    /// Returns the rank of this transport instance.
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Returns the world size (total number of nodes).
    pub fn world_size(&self) -> usize {
        // The coordinator records world_size at construction time. Workers do
        // not learn it from the protocol yet, so they fall back to
        // peers.len() + 1, which is only accurate for a two-node setup where
        // the worker's sole peer is the coordinator.
        if self.world_size > 0 {
            self.world_size
        } else {
            self.peers.len() + 1
        }
    }
}

/// Wrapper that uses a [`TcpTransport`] for gradient aggregation.
pub struct TcpAllReduceAggregator {
    transport: TcpTransport,
}

impl TcpAllReduceAggregator {
    /// Create a new aggregator wrapping the given TCP transport.
    pub fn new(transport: TcpTransport) -> Self {
        Self { transport }
    }

    /// All-reduce (sum) raw f32 slices across all connected nodes.
    pub fn allreduce_sum(&self, data: &mut [f32]) -> Result<(), ModelError> {
        self.transport.allreduce_sum(data)
    }
}

/// Create a loopback TCP transport pair for testing.
///
/// Starts a coordinator on a random port on localhost and connects a single
/// worker to it. Returns `(coordinator, worker)`.
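///
/// # Example
///
/// A quick round trip over the pair (mirrors the unit tests below):
///
/// ```ignore
/// let (coord, worker) = loopback_pair()?;
/// coord.send(0, &[1.0, 2.0])?;
/// assert_eq!(worker.recv(0)?, vec![1.0, 2.0]);
/// ```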
pub fn loopback_pair() -> Result<(TcpTransport, TcpTransport), ModelError> {
    // Bind to port 0 to get a random available port.
    let listener = TcpListener::bind("127.0.0.1:0").map_err(|e| {
        ModelError::TransportError(format!("failed to bind loopback listener: {e}"))
    })?;
    let port = listener
        .local_addr()
        .map_err(|e| ModelError::TransportError(format!("failed to get local addr: {e}")))?
        .port();

    // coordinator() binds its own listener, but here we need the port of the
    // already-bound listener, so the coordinator side is built manually below.
    // The worker runs in a separate thread because connect and accept both block.

    // Spawn the worker connection in a thread.
    let addr = format!("127.0.0.1:{port}");
    let addr_clone = addr.clone();
    let worker_handle = std::thread::spawn(move || -> Result<TcpTransport, ModelError> {
        TcpTransport::worker(&addr_clone, 1)
    });

    // Accept the single worker connection as coordinator.
    // Build coordinator manually from the existing listener.
    let (mut stream, _) = listener
        .accept()
        .map_err(|e| ModelError::TransportError(format!("loopback accept failed: {e}")))?;
    stream
        .set_nodelay(true)
        .map_err(|e| ModelError::TransportError(format!("set_nodelay failed: {e}")))?;

    // Read the worker's rank announcement.
    let mut rank_buf = [0u8; 4];
    stream
        .read_exact(&mut rank_buf)
        .map_err(|e| ModelError::TransportError(format!("failed to read worker rank: {e}")))?;

    let coordinator = TcpTransport {
        role: NodeRole::Coordinator { bind_addr: addr },
        peers: vec![Arc::new(Mutex::new(stream))],
        rank: 0,
        world_size: 2,
    };

    let worker = worker_handle
        .join()
        .map_err(|_| ModelError::TransportError("worker thread panicked".to_string()))??;

    Ok((coordinator, worker))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tcp_loopback_send_recv() {
        let (coord, worker) = loopback_pair().unwrap();

        let send_data: Vec<f32> = vec![1.0, 2.5, -3.0, 0.0];

        // Coordinator sends to worker (peer index 0 = the single worker).
        coord.send(0, &send_data).unwrap();

        // Worker receives from coordinator (peer index 0 = the coordinator).
        let received = worker.recv(0).unwrap();

        assert_eq!(received, send_data);
    }

    #[test]
    fn tcp_loopback_allreduce() {
        let (coord, worker) = loopback_pair().unwrap();

        // Coordinator has [1.0, 2.0, 3.0], worker has [4.0, 5.0, 6.0].
        // After allreduce_sum, both should have [5.0, 7.0, 9.0].

        let coord_handle = std::thread::spawn(move || -> Result<Vec<f32>, ModelError> {
            let mut data = vec![1.0, 2.0, 3.0];
            coord.allreduce_sum(&mut data)?;
            Ok(data)
        });

        let worker_handle = std::thread::spawn(move || -> Result<Vec<f32>, ModelError> {
            let mut data = vec![4.0, 5.0, 6.0];
            worker.allreduce_sum(&mut data)?;
            Ok(data)
        });

        let coord_result = coord_handle.join().unwrap().unwrap();
        let worker_result = worker_handle.join().unwrap().unwrap();

        assert_eq!(coord_result, vec![5.0, 7.0, 9.0]);
        assert_eq!(worker_result, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn tcp_transport_rank_world_size() {
        let (coord, worker) = loopback_pair().unwrap();

        assert_eq!(coord.rank(), 0);
        assert_eq!(coord.world_size(), 2);

        assert_eq!(worker.rank(), 1);
        assert_eq!(worker.world_size(), 2);
    }
}