// yscv_model/tcp_transport.rs

use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::{Arc, Mutex};

use crate::ModelError;
12pub struct TcpTransport {
38 #[allow(dead_code)]
39 role: NodeRole,
40 peers: Vec<Arc<Mutex<TcpStream>>>,
41 rank: usize,
42 world_size: usize,
43}
44
/// Role a node plays in the TCP collective.
#[derive(Debug, Clone)]
pub enum NodeRole {
    /// Rank 0: listens on `bind_addr` and accepts worker connections.
    Coordinator { bind_addr: String },
    /// Rank > 0: dials the coordinator at `coordinator_addr`.
    Worker { coordinator_addr: String },
}
59
60impl TcpTransport {
61 pub fn coordinator(bind_addr: &str, world_size: usize) -> Result<Self, ModelError> {
66 if world_size == 0 {
67 return Err(ModelError::TransportError(
68 "world_size must be > 0".to_string(),
69 ));
70 }
71
72 let listener = TcpListener::bind(bind_addr).map_err(|e| {
73 ModelError::TransportError(format!("coordinator failed to bind {bind_addr}: {e}"))
74 })?;
75
76 let mut peers = Vec::with_capacity(world_size);
77 let mut peer_slots: Vec<Option<Arc<Mutex<TcpStream>>>> =
84 (0..world_size).map(|_| None).collect();
85
86 for _ in 1..world_size {
87 let (mut stream, _addr) = listener.accept().map_err(|e| {
88 ModelError::TransportError(format!("coordinator accept failed: {e}"))
89 })?;
90
91 stream
92 .set_nodelay(true)
93 .map_err(|e| ModelError::TransportError(format!("set_nodelay failed: {e}")))?;
94
95 let mut rank_buf = [0u8; 4];
97 stream.read_exact(&mut rank_buf).map_err(|e| {
98 ModelError::TransportError(format!("failed to read worker rank: {e}"))
99 })?;
100 let worker_rank = u32::from_le_bytes(rank_buf) as usize;
101
102 if worker_rank == 0 || worker_rank >= world_size {
103 return Err(ModelError::TransportError(format!(
104 "invalid worker rank {worker_rank}"
105 )));
106 }
107
108 peer_slots[worker_rank] = Some(Arc::new(Mutex::new(stream)));
109 }
110
111 for slot in peer_slots.iter_mut() {
113 if let Some(s) = slot.take() {
114 peers.push(s);
115 }
116 }
117
118 Ok(Self {
119 role: NodeRole::Coordinator {
120 bind_addr: bind_addr.to_string(),
121 },
122 peers,
123 rank: 0,
124 world_size,
125 })
126 }
127
128 pub fn worker(coordinator_addr: &str, rank: usize) -> Result<Self, ModelError> {
130 if rank == 0 {
131 return Err(ModelError::TransportError(
132 "worker rank must be > 0; use coordinator() for rank 0".to_string(),
133 ));
134 }
135
136 let mut stream = TcpStream::connect(coordinator_addr).map_err(|e| {
137 ModelError::TransportError(format!(
138 "worker rank {rank} failed to connect to {coordinator_addr}: {e}"
139 ))
140 })?;
141
142 stream
143 .set_nodelay(true)
144 .map_err(|e| ModelError::TransportError(format!("set_nodelay failed: {e}")))?;
145
146 stream
148 .write_all(&(rank as u32).to_le_bytes())
149 .map_err(|e| ModelError::TransportError(format!("failed to send rank: {e}")))?;
150
151 let peers = vec![Arc::new(Mutex::new(stream))];
152
153 Ok(Self {
154 role: NodeRole::Worker {
155 coordinator_addr: coordinator_addr.to_string(),
156 },
157 peers,
158 rank,
159 world_size: 0, })
161 }
162
163 pub fn send(&self, peer: usize, data: &[f32]) -> Result<(), ModelError> {
167 if peer >= self.peers.len() {
168 return Err(ModelError::TransportError(format!(
169 "peer index {peer} out of range (have {} peers)",
170 self.peers.len()
171 )));
172 }
173 let stream = &self.peers[peer];
174 let mut stream = stream
175 .lock()
176 .map_err(|_| ModelError::TransportError("lock poisoned".into()))?;
177 let len = data.len() as u32;
178 stream
179 .write_all(&len.to_le_bytes())
180 .map_err(|e| ModelError::TransportError(e.to_string()))?;
181 let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
182 stream
183 .write_all(&bytes)
184 .map_err(|e| ModelError::TransportError(e.to_string()))?;
185 Ok(())
186 }
187
188 pub fn recv(&self, peer: usize) -> Result<Vec<f32>, ModelError> {
192 if peer >= self.peers.len() {
193 return Err(ModelError::TransportError(format!(
194 "peer index {peer} out of range (have {} peers)",
195 self.peers.len()
196 )));
197 }
198 let stream = &self.peers[peer];
199 let mut stream = stream
200 .lock()
201 .map_err(|_| ModelError::TransportError("lock poisoned".into()))?;
202 let mut len_buf = [0u8; 4];
203 stream
204 .read_exact(&mut len_buf)
205 .map_err(|e| ModelError::TransportError(e.to_string()))?;
206 let len = u32::from_le_bytes(len_buf) as usize;
207 let mut buf = vec![0u8; len * 4];
208 stream
209 .read_exact(&mut buf)
210 .map_err(|e| ModelError::TransportError(e.to_string()))?;
211 let data: Vec<f32> = buf
212 .chunks(4)
213 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
214 .collect();
215 Ok(data)
216 }
217
218 pub fn allreduce_sum(&self, data: &mut [f32]) -> Result<(), ModelError> {
224 let n = self.peers.len() + 1; if n <= 1 {
226 return Ok(());
227 }
228
229 if self.rank == 0 {
230 let mut acc: Vec<f32> = data.to_vec();
232 for peer_idx in 0..self.peers.len() {
233 let remote = self.recv(peer_idx)?;
234 if remote.len() != acc.len() {
235 return Err(ModelError::TransportError(format!(
236 "allreduce length mismatch: expected {}, got {}",
237 acc.len(),
238 remote.len()
239 )));
240 }
241 for (a, r) in acc.iter_mut().zip(remote.iter()) {
242 *a += r;
243 }
244 }
245 for peer_idx in 0..self.peers.len() {
247 self.send(peer_idx, &acc)?;
248 }
249 data.copy_from_slice(&acc);
251 } else {
252 self.send(0, data)?;
254 let result = self.recv(0)?;
255 if result.len() != data.len() {
256 return Err(ModelError::TransportError(format!(
257 "allreduce length mismatch: expected {}, got {}",
258 data.len(),
259 result.len()
260 )));
261 }
262 data.copy_from_slice(&result);
263 }
264
265 Ok(())
266 }
267
268 pub fn rank(&self) -> usize {
270 self.rank
271 }
272
273 pub fn world_size(&self) -> usize {
275 if self.world_size > 0 {
278 self.world_size
279 } else {
280 self.peers.len() + 1
281 }
282 }
283}
284
285pub struct TcpAllReduceAggregator {
287 transport: TcpTransport,
288}
289
290impl TcpAllReduceAggregator {
291 pub fn new(transport: TcpTransport) -> Self {
293 Self { transport }
294 }
295
296 pub fn allreduce_sum(&self, data: &mut [f32]) -> Result<(), ModelError> {
298 self.transport.allreduce_sum(data)
299 }
300}
301
302pub fn loopback_pair() -> Result<(TcpTransport, TcpTransport), ModelError> {
307 let listener = TcpListener::bind("127.0.0.1:0").map_err(|e| {
309 ModelError::TransportError(format!("failed to bind loopback listener: {e}"))
310 })?;
311 let port = listener
312 .local_addr()
313 .map_err(|e| ModelError::TransportError(format!("failed to get local addr: {e}")))?
314 .port();
315
316 let addr = format!("127.0.0.1:{port}");
321 let addr_clone = addr.clone();
322 let worker_handle = std::thread::spawn(move || -> Result<TcpTransport, ModelError> {
323 TcpTransport::worker(&addr_clone, 1)
324 });
325
326 let (mut stream, _) = listener
329 .accept()
330 .map_err(|e| ModelError::TransportError(format!("loopback accept failed: {e}")))?;
331 stream
332 .set_nodelay(true)
333 .map_err(|e| ModelError::TransportError(format!("set_nodelay failed: {e}")))?;
334
335 let mut rank_buf = [0u8; 4];
337 stream
338 .read_exact(&mut rank_buf)
339 .map_err(|e| ModelError::TransportError(format!("failed to read worker rank: {e}")))?;
340
341 let coordinator = TcpTransport {
342 role: NodeRole::Coordinator { bind_addr: addr },
343 peers: vec![Arc::new(Mutex::new(stream))],
344 rank: 0,
345 world_size: 2,
346 };
347
348 let worker = worker_handle
349 .join()
350 .map_err(|_| ModelError::TransportError("worker thread panicked".to_string()))??;
351
352 Ok((coordinator, worker))
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip a frame from coordinator to worker over loopback.
    #[test]
    fn tcp_loopback_send_recv() {
        let (coord, worker) = loopback_pair().unwrap();

        let payload: Vec<f32> = vec![1.0, 2.5, -3.0, 0.0];
        coord.send(0, &payload).unwrap();
        let echoed = worker.recv(0).unwrap();

        assert_eq!(echoed, payload);
    }

    // Both sides must run allreduce concurrently, so drive each from its
    // own thread and check both observe the element-wise sum.
    #[test]
    fn tcp_loopback_allreduce() {
        let (coord, worker) = loopback_pair().unwrap();

        let coord_handle = std::thread::spawn(move || -> Result<Vec<f32>, ModelError> {
            let mut buf = vec![1.0, 2.0, 3.0];
            coord.allreduce_sum(&mut buf)?;
            Ok(buf)
        });

        let worker_handle = std::thread::spawn(move || -> Result<Vec<f32>, ModelError> {
            let mut buf = vec![4.0, 5.0, 6.0];
            worker.allreduce_sum(&mut buf)?;
            Ok(buf)
        });

        let from_coord = coord_handle.join().unwrap().unwrap();
        let from_worker = worker_handle.join().unwrap().unwrap();

        assert_eq!(from_coord, vec![5.0, 7.0, 9.0]);
        assert_eq!(from_worker, vec![5.0, 7.0, 9.0]);
    }

    // Ranks and world sizes reported by a loopback pair.
    #[test]
    fn tcp_transport_rank_world_size() {
        let (coord, worker) = loopback_pair().unwrap();

        assert_eq!(coord.rank(), 0);
        assert_eq!(coord.world_size(), 2);

        assert_eq!(worker.rank(), 1);
        assert_eq!(worker.world_size(), 2);
    }
}