// torsh_distributed/rpc.rs
//! Remote Procedure Call (RPC) framework for distributed training
//!
//! This module provides a complete RPC framework for distributed training,
//! supporting remote function calls, remote references, and distributed
//! computation patterns.

7use crate::{TorshDistributedError, TorshResult};
8use log::info;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex};
12use std::time::Duration;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::net::{TcpListener, TcpStream};
15use tokio::sync::{mpsc, oneshot, RwLock};
16use uuid::Uuid;
17
// Type aliases for complex types

/// In-flight outbound requests, keyed by request UUID. Each entry holds the
/// oneshot sender that delivers the response: serialized payload on success,
/// error string on failure.
type PendingRequestMap = Arc<Mutex<HashMap<String, oneshot::Sender<Result<Vec<u8>, String>>>>>;

/// Registry of remotely-callable functions, keyed by name. Handlers are
/// type-erased: serialized argument bytes in, serialized result bytes (or an
/// error string) out.
type FunctionRegistry =
    Arc<RwLock<HashMap<String, Box<dyn Fn(&[u8]) -> Result<Vec<u8>, String> + Send + Sync>>>>;
22
/// Configuration options for the RPC backend.
#[derive(Debug, Clone)]
pub struct RpcBackendOptions {
    /// Number of worker threads used by the backend.
    pub num_worker_threads: usize,
    /// Timeout budget for outstanding RPC calls.
    pub rpc_timeout: Duration,
    /// Rendezvous/init method string (e.g. "env://").
    pub init_method: String,
    /// Size in bytes of the per-connection read buffer.
    pub buffer_size: usize,
    /// Maximum number of simultaneous connections.
    pub max_connections: usize,
}

impl Default for RpcBackendOptions {
    /// Defaults: 4 worker threads, 60 s RPC timeout, "env://" init method,
    /// 8 KiB buffers, and at most 100 connections.
    fn default() -> Self {
        RpcBackendOptions {
            num_worker_threads: 4,
            rpc_timeout: Duration::from_secs(60),
            init_method: "env://".to_owned(),
            buffer_size: 8192,
            max_connections: 100,
        }
    }
}
44
/// RPC message types
///
/// NOTE(review): oxicode's standard configuration appears to encode enum
/// variants positionally (bincode-style), which would make variant order part
/// of the wire format — do not reorder variants; confirm against oxicode docs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RpcMessage {
    /// Function call request
    FunctionCall {
        /// Request UUID; echoed back in the matching response.
        id: String,
        /// Name the callee looks up in its function registry.
        function_name: String,
        args: Vec<u8>, // Serialized arguments
    },
    /// Function call response
    FunctionResponse {
        /// UUID of the request this answers.
        id: String,
        result: Result<Vec<u8>, String>, // Serialized result or error
    },
    /// Remote reference creation
    RemoteRef {
        /// Request UUID; echoed back in the matching response.
        id: String,
        /// Name of the function whose result the callee should store.
        function_name: String,
        args: Vec<u8>,
        /// Caller-chosen id under which the callee stores the result.
        rref_id: String,
    },
    /// Remote reference response
    RemoteRefResponse {
        /// UUID of the request this answers.
        id: String,
        result: Result<String, String>, // RRef ID or error
    },
    /// Remote reference deletion
    DeleteRRef { rref_id: String },
    /// Ping message for health checks
    Ping,
    /// Pong response
    Pong,
}
78
/// Handle to a value owned by a (possibly remote) worker.
///
/// Only the reference id and the owner's rank travel with the handle; the
/// payload itself stays on the owning worker. `PhantomData` records the
/// logical value type without storing one.
#[derive(Debug, Clone)]
pub struct RRef<T> {
    id: String,
    owner_rank: u32,
    _phantom: std::marker::PhantomData<T>,
}

impl<T> RRef<T> {
    /// Builds a handle from a reference id and the owning worker's rank.
    pub fn new(id: String, owner_rank: u32) -> Self {
        RRef {
            id,
            owner_rank,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Identifier of the referenced value.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Rank of the worker that owns the value.
    pub fn owner_rank(&self) -> u32 {
        self.owner_rank
    }
}
104
/// RPC worker state
///
/// Shared between the accept-loop task and the public API functions; always
/// accessed through `Arc<Mutex<Option<RpcWorker>>>`.
struct RpcWorker {
    // This worker's rank within the group.
    rank: u32,
    // Total number of workers in the group.
    world_size: u32,
    // Open outbound TCP connections to the other workers, keyed by rank.
    connections: Arc<RwLock<HashMap<u32, TcpStream>>>,
    // In-flight outbound requests awaiting a response, keyed by request UUID.
    pending_requests: PendingRequestMap,
    // Values held on behalf of remote callers, keyed by rref id.
    remote_refs: Arc<RwLock<HashMap<String, Box<dyn std::any::Any + Send + Sync>>>>,
    // Functions callable from remote workers.
    function_registry: FunctionRegistry,
    // Signals the accept loop to stop; populated during `init_rpc`.
    shutdown_tx: Option<mpsc::Sender<()>>,
}
115
/// Global RPC worker instance
///
/// Process-wide singleton used outside of tests; set exactly once by
/// `init_rpc` (OnceCell rejects a second initialization).
static RPC_WORKER: once_cell::sync::OnceCell<Arc<Mutex<Option<RpcWorker>>>> =
    once_cell::sync::OnceCell::new();

// Test-only RPC worker for isolated testing.
// Thread-local so concurrently running tests (each on its own thread) do not
// share worker state; `get_rpc_worker` consults this before the global.
#[cfg(test)]
thread_local! {
    static TEST_RPC_WORKER: std::cell::RefCell<Option<Arc<Mutex<Option<RpcWorker>>>>> = const { std::cell::RefCell::new(None) };
}
125
126/// Get the global RPC worker
127fn get_rpc_worker() -> TorshResult<Arc<Mutex<Option<RpcWorker>>>> {
128    #[cfg(test)]
129    {
130        // In tests, try thread-local first
131        let local_worker = TEST_RPC_WORKER.with(|w| w.borrow().clone());
132        if let Some(worker) = local_worker {
133            return Ok(worker);
134        }
135    }
136
137    RPC_WORKER
138        .get()
139        .ok_or(TorshDistributedError::BackendNotInitialized)
140        .cloned()
141}
142
143/// Initialize RPC framework
144pub async fn init_rpc(
145    name: &str,
146    rank: u32,
147    world_size: u32,
148    _options: RpcBackendOptions,
149) -> TorshResult<()> {
150    // Initialize the RPC worker
151    let worker = RpcWorker {
152        rank,
153        world_size,
154        connections: Arc::new(RwLock::new(HashMap::new())),
155        pending_requests: Arc::new(Mutex::new(HashMap::new())),
156        remote_refs: Arc::new(RwLock::new(HashMap::new())),
157        function_registry: Arc::new(RwLock::new(HashMap::new())),
158        shutdown_tx: None,
159    };
160
161    let worker_arc = Arc::new(Mutex::new(Some(worker)));
162
163    // Set the worker (test-local in tests, global otherwise)
164    #[cfg(test)]
165    {
166        TEST_RPC_WORKER.with(|w| *w.borrow_mut() = Some(worker_arc.clone()));
167    }
168
169    #[cfg(not(test))]
170    {
171        RPC_WORKER
172            .set(worker_arc.clone())
173            .map_err(|_| TorshDistributedError::backend_error("rpc", "RPC already initialized"))?;
174    }
175
176    // Start RPC server with dynamic port allocation for tests
177    let base_port = if cfg!(test) {
178        // Use a wider range of ports for testing to avoid conflicts
179        29600 + (std::process::id() % 1000) + rank * 100
180    } else {
181        29600 + rank
182    };
183
184    // Try multiple ports if the first one fails
185    let mut listener = None;
186    for port_offset in 0..10 {
187        let listen_addr = format!("127.0.0.1:{}", base_port + port_offset);
188        match TcpListener::bind(&listen_addr).await {
189            Ok(l) => {
190                info!(
191                    "[RPC] Worker '{}' (rank {}) starting on {}",
192                    name, rank, listen_addr
193                );
194                listener = Some(l);
195                break;
196            }
197            Err(e) => {
198                if port_offset == 9 {
199                    return Err(TorshDistributedError::communication_error(
200                        "rpc_server",
201                        format!("Failed to bind after trying 10 ports: {}", e),
202                    ));
203                }
204            }
205        }
206    }
207
208    let listener = listener.expect("listener should be successfully bound");
209
210    let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
211
212    // Update worker with shutdown channel
213    {
214        let mut worker_guard = worker_arc.lock().expect("lock should not be poisoned");
215        if let Some(ref mut worker) = *worker_guard {
216            worker.shutdown_tx = Some(shutdown_tx);
217        }
218    }
219
220    // Spawn server task
221    let worker_for_server = worker_arc.clone();
222    tokio::spawn(async move {
223        loop {
224            tokio::select! {
225                result = listener.accept() => {
226                    match result {
227                        Ok((stream, addr)) => {
228                            info!("[RPC] Accepted connection from {}", addr);
229                            let worker_clone = worker_for_server.clone();
230                            tokio::spawn(handle_connection(stream, worker_clone));
231                        }
232                        Err(e) => {
233                            info!("[RPC] Failed to accept connection: {}", e);
234                        }
235                    }
236                }
237                _ = shutdown_rx.recv() => {
238                    info!("[RPC] Server shutting down");
239                    break;
240                }
241            }
242        }
243    });
244
245    // Wait a bit for other workers to start
246    tokio::time::sleep(Duration::from_millis(100)).await;
247
248    // Connect to other workers (skip for single-worker setups)
249    if world_size > 1 {
250        for other_rank in 0..world_size {
251            if other_rank != rank {
252                let target_addr = format!("127.0.0.1:{}", base_port + other_rank);
253
254                // Retry connection with exponential backoff
255                let mut retries = 0;
256                let max_retries = if cfg!(test) { 3 } else { 10 }; // Fewer retries in tests
257
258                while retries < max_retries {
259                    match TcpStream::connect(&target_addr).await {
260                        Ok(stream) => {
261                            info!(
262                                "[RPC] Connected to worker {} at {}",
263                                other_rank, target_addr
264                            );
265                            let connections = {
266                                let worker_guard =
267                                    worker_arc.lock().expect("lock should not be poisoned");
268                                worker_guard
269                                    .as_ref()
270                                    .expect("worker should be initialized")
271                                    .connections
272                                    .clone()
273                            };
274                            let mut connections_guard = connections.write().await;
275                            connections_guard.insert(other_rank, stream);
276                            break;
277                        }
278                        Err(e) => {
279                            retries += 1;
280                            let delay = Duration::from_millis(100 * (1 << retries.min(3)));
281                            tokio::time::sleep(delay).await;
282                            if retries == max_retries {
283                                // In tests, just log the error and continue
284                                if cfg!(test) {
285                                    info!(
286                                        "[RPC] Failed to connect to worker {} (test mode): {}",
287                                        other_rank, e
288                                    );
289                                    break;
290                                } else {
291                                    return Err(TorshDistributedError::communication_error(
292                                        "rpc_connect",
293                                        format!(
294                                            "Failed to connect to worker {}: {}",
295                                            other_rank, e
296                                        ),
297                                    ));
298                                }
299                            }
300                        }
301                    }
302                }
303            }
304        }
305    }
306
307    info!(
308        "[RPC] Worker '{}' (rank {}) initialized successfully",
309        name, rank
310    );
311    Ok(())
312}
313
/// Handle incoming RPC connection
///
/// Reads from the socket until it closes, decoding each read as a single
/// `RpcMessage` and dispatching it. Decode/dispatch errors are logged and the
/// loop continues; read errors terminate the connection.
///
/// NOTE(review): each `read` is assumed to deliver exactly one complete
/// encoded message, and the buffer is fixed at 8192 bytes (the configured
/// `buffer_size` is not used here). TCP does not preserve message boundaries,
/// so split or coalesced messages would fail to decode — confirm payloads
/// stay small, or add length-prefix framing on both ends.
async fn handle_connection(mut stream: TcpStream, worker: Arc<Mutex<Option<RpcWorker>>>) {
    let mut buffer = vec![0u8; 8192];

    loop {
        match stream.read(&mut buffer).await {
            Ok(0) => break, // Connection closed
            Ok(n) => {
                let data = &buffer[..n];

                // Try to deserialize the message
                let result: Result<(RpcMessage, usize), _> =
                    oxicode::serde::decode_from_slice(data, oxicode::config::standard());
                match result {
                    Ok((message, _)) => {
                        if let Err(e) = handle_rpc_message(message, &mut stream, &worker).await {
                            info!("[RPC] Error handling message: {}", e);
                        }
                    }
                    Err(e) => {
                        info!("[RPC] Failed to deserialize message: {}", e);
                    }
                }
            }
            Err(e) => {
                info!("[RPC] Connection error: {}", e);
                break;
            }
        }
    }
}
345
346/// Handle a specific RPC message
347async fn handle_rpc_message(
348    message: RpcMessage,
349    stream: &mut TcpStream,
350    worker: &Arc<Mutex<Option<RpcWorker>>>,
351) -> TorshResult<()> {
352    match message {
353        RpcMessage::FunctionCall {
354            id,
355            function_name,
356            args,
357        } => {
358            // Get a clone of the function registry to avoid holding locks across await
359            let function_registry = {
360                let worker_guard = worker.lock().expect("lock should not be poisoned");
361                let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
362                worker_ref.function_registry.clone()
363            };
364
365            let result = {
366                let registry = function_registry.read().await;
367                if let Some(func) = registry.get(&function_name) {
368                    func(&args)
369                } else {
370                    Err(format!("Function '{}' not found", function_name))
371                }
372            };
373
374            let response = RpcMessage::FunctionResponse { id, result };
375            let response_data =
376                oxicode::serde::encode_to_vec(&response, oxicode::config::standard()).map_err(
377                    |e| {
378                        TorshDistributedError::communication_error(
379                            "rpc",
380                            format!("Serialization error: {}", e),
381                        )
382                    },
383                )?;
384
385            stream.write_all(&response_data).await.map_err(|e| {
386                TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
387            })?;
388        }
389
390        RpcMessage::RemoteRef {
391            id,
392            function_name,
393            args,
394            rref_id,
395        } => {
396            // Get clones to avoid holding locks across await
397            let (function_registry, remote_refs) = {
398                let worker_guard = worker.lock().expect("lock should not be poisoned");
399                let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
400                (
401                    worker_ref.function_registry.clone(),
402                    worker_ref.remote_refs.clone(),
403                )
404            };
405
406            let result = {
407                let registry = function_registry.read().await;
408                if let Some(func) = registry.get(&function_name) {
409                    match func(&args) {
410                        Ok(result_data) => {
411                            // Store the result as a remote reference
412                            let mut refs = remote_refs.write().await;
413                            // For simplicity, store as Vec<u8>
414                            refs.insert(rref_id.clone(), Box::new(result_data));
415                            Ok(rref_id)
416                        }
417                        Err(e) => Err(e),
418                    }
419                } else {
420                    Err(format!("Function '{}' not found", function_name))
421                }
422            };
423
424            let response = RpcMessage::RemoteRefResponse { id, result };
425            let response_data =
426                oxicode::serde::encode_to_vec(&response, oxicode::config::standard()).map_err(
427                    |e| {
428                        TorshDistributedError::communication_error(
429                            "rpc",
430                            format!("Serialization error: {}", e),
431                        )
432                    },
433                )?;
434
435            stream.write_all(&response_data).await.map_err(|e| {
436                TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
437            })?;
438        }
439
440        RpcMessage::DeleteRRef { rref_id } => {
441            let remote_refs = {
442                let worker_guard = worker.lock().expect("lock should not be poisoned");
443                let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
444                worker_ref.remote_refs.clone()
445            };
446
447            let mut refs = remote_refs.write().await;
448            refs.remove(&rref_id);
449        }
450
451        RpcMessage::Ping => {
452            let response = RpcMessage::Pong;
453            let response_data =
454                oxicode::serde::encode_to_vec(&response, oxicode::config::standard()).map_err(
455                    |e| {
456                        TorshDistributedError::communication_error(
457                            "rpc",
458                            format!("Serialization error: {}", e),
459                        )
460                    },
461                )?;
462
463            stream.write_all(&response_data).await.map_err(|e| {
464                TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
465            })?;
466        }
467
468        _ => {
469            // Handle responses by forwarding to pending requests
470            if let RpcMessage::FunctionResponse { id, result } = message {
471                let worker_guard = worker.lock().expect("lock should not be poisoned");
472                let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
473                let mut pending = worker_ref
474                    .pending_requests
475                    .lock()
476                    .expect("lock should not be poisoned");
477
478                if let Some(sender) = pending.remove(&id) {
479                    let _ = sender.send(result);
480                }
481            }
482        }
483    }
484
485    Ok(())
486}
487
488/// Shutdown RPC framework
489pub async fn shutdown() -> TorshResult<()> {
490    let worker_arc = get_rpc_worker()?;
491
492    let (shutdown_tx, remote_refs) = {
493        let mut worker_guard = worker_arc.lock().expect("lock should not be poisoned");
494        if let Some(worker) = worker_guard.take() {
495            (worker.shutdown_tx, Some(worker.remote_refs))
496        } else {
497            (None, None)
498        }
499    };
500
501    if let Some(shutdown_tx) = shutdown_tx {
502        let _ = shutdown_tx.send(()).await;
503    }
504
505    if let Some(remote_refs) = remote_refs {
506        // Clear all remote references
507        let mut refs = remote_refs.write().await;
508        refs.clear();
509    }
510
511    info!("[RPC] Framework shut down successfully");
512    Ok(())
513}
514
515/// Register a function for remote execution
516pub async fn register_function<F, Args, Ret>(name: &str, func: F) -> TorshResult<()>
517where
518    F: Fn(Args) -> Result<Ret, String> + Send + Sync + 'static,
519    Args: for<'de> Deserialize<'de> + 'static,
520    Ret: Serialize + 'static,
521{
522    let worker_arc = get_rpc_worker()?;
523    let function_registry = {
524        let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
525        let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
526        worker_ref.function_registry.clone()
527    };
528
529    let wrapper = move |args_bytes: &[u8]| -> Result<Vec<u8>, String> {
530        let (args, _): (Args, usize) =
531            oxicode::serde::decode_from_slice(args_bytes, oxicode::config::standard())
532                .map_err(|e| format!("Deserialization error: {}", e))?;
533
534        let result = func(args)?;
535
536        oxicode::serde::encode_to_vec(&result, oxicode::config::standard())
537            .map_err(|e| format!("Serialization error: {}", e))
538    };
539
540    let mut registry = function_registry.write().await;
541    registry.insert(name.to_string(), Box::new(wrapper));
542
543    Ok(())
544}
545
546/// Call a remote function
547pub async fn rpc_async<Args, Ret>(to: u32, function_name: &str, args: Args) -> TorshResult<Ret>
548where
549    Args: Serialize,
550    Ret: for<'de> Deserialize<'de>,
551{
552    let worker_arc = get_rpc_worker()?;
553
554    // Serialize arguments
555    let args_bytes =
556        oxicode::serde::encode_to_vec(&args, oxicode::config::standard()).map_err(|e| {
557            TorshDistributedError::communication_error("rpc", format!("Serialization error: {}", e))
558        })?;
559
560    // Generate request ID
561    let request_id = Uuid::new_v4().to_string();
562
563    // Create message
564    let message = RpcMessage::FunctionCall {
565        id: request_id.clone(),
566        function_name: function_name.to_string(),
567        args: args_bytes,
568    };
569
570    // Serialize message
571    let message_bytes = oxicode::serde::encode_to_vec(&message, oxicode::config::standard())
572        .map_err(|e| {
573            TorshDistributedError::communication_error("rpc", format!("Serialization error: {}", e))
574        })?;
575
576    // Get clones to avoid holding locks across await
577    let (connections, pending_requests) = {
578        let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
579        let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
580        (
581            worker_ref.connections.clone(),
582            worker_ref.pending_requests.clone(),
583        )
584    };
585
586    // Create response channel and register pending request
587    let (response_tx, response_rx) = oneshot::channel();
588    {
589        let mut pending = pending_requests
590            .lock()
591            .expect("lock should not be poisoned");
592        pending.insert(request_id, response_tx);
593    }
594
595    // Get connection and send message
596    {
597        let mut connections_guard = connections.write().await;
598        let connection = connections_guard.get_mut(&to).ok_or_else(|| {
599            TorshDistributedError::communication_error(
600                "rpc",
601                format!("No connection to worker {}", to),
602            )
603        })?;
604
605        connection.write_all(&message_bytes).await.map_err(|e| {
606            TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
607        })?;
608    }
609
610    // Wait for response with timeout
611    let result = tokio::time::timeout(Duration::from_secs(60), response_rx)
612        .await
613        .map_err(|_| TorshDistributedError::communication_error("rpc", "RPC timeout"))?
614        .map_err(|_| {
615            TorshDistributedError::communication_error("rpc", "Response channel closed")
616        })?;
617
618    match result {
619        Ok(result_bytes) => {
620            let (value, _): (Ret, usize) =
621                oxicode::serde::decode_from_slice(&result_bytes, oxicode::config::standard())
622                    .map_err(|e| {
623                        TorshDistributedError::communication_error(
624                            "rpc",
625                            format!("Deserialization error: {}", e),
626                        )
627                    })?;
628            Ok(value)
629        }
630        Err(error_msg) => Err(TorshDistributedError::communication_error(
631            "rpc_remote",
632            format!("Remote function error: {}", error_msg),
633        )),
634    }
635}
636
637/// Get a remote reference
638pub async fn remote<Args, Ret>(to: u32, function_name: &str, args: Args) -> TorshResult<RRef<Ret>>
639where
640    Args: Serialize,
641    Ret: for<'de> Deserialize<'de> + 'static,
642{
643    let worker_arc = get_rpc_worker()?;
644
645    // Serialize arguments
646    let args_bytes =
647        oxicode::serde::encode_to_vec(&args, oxicode::config::standard()).map_err(|e| {
648            TorshDistributedError::communication_error("rpc", format!("Serialization error: {}", e))
649        })?;
650
651    // Generate request and RRef IDs
652    let request_id = Uuid::new_v4().to_string();
653    let rref_id = Uuid::new_v4().to_string();
654
655    // Create message
656    let message = RpcMessage::RemoteRef {
657        id: request_id.clone(),
658        function_name: function_name.to_string(),
659        args: args_bytes,
660        rref_id: rref_id.clone(),
661    };
662
663    // Serialize message
664    let message_bytes = oxicode::serde::encode_to_vec(&message, oxicode::config::standard())
665        .map_err(|e| {
666            TorshDistributedError::communication_error("rpc", format!("Serialization error: {}", e))
667        })?;
668
669    // Get clones to avoid holding locks across await
670    let (connections, pending_requests) = {
671        let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
672        let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
673        (
674            worker_ref.connections.clone(),
675            worker_ref.pending_requests.clone(),
676        )
677    };
678
679    // Create response channel and register pending request
680    let (response_tx, response_rx) = oneshot::channel();
681    {
682        let mut pending = pending_requests
683            .lock()
684            .expect("lock should not be poisoned");
685        pending.insert(request_id, response_tx);
686    }
687
688    // Get connection and send message
689    {
690        let mut connections_guard = connections.write().await;
691        let connection = connections_guard.get_mut(&to).ok_or_else(|| {
692            TorshDistributedError::communication_error(
693                "rpc",
694                format!("No connection to worker {}", to),
695            )
696        })?;
697
698        connection.write_all(&message_bytes).await.map_err(|e| {
699            TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
700        })?;
701    }
702
703    // Wait for response with timeout
704    let result = tokio::time::timeout(Duration::from_secs(60), response_rx)
705        .await
706        .map_err(|_| TorshDistributedError::communication_error("rpc", "RPC timeout"))?
707        .map_err(|_| {
708            TorshDistributedError::communication_error("rpc", "Response channel closed")
709        })?;
710
711    match result {
712        Ok(returned_rref_id) => {
713            let (actual_rref_id, _): (String, usize) =
714                oxicode::serde::decode_from_slice(&returned_rref_id, oxicode::config::standard())
715                    .map_err(|e| {
716                    TorshDistributedError::communication_error(
717                        "rpc",
718                        format!("Deserialization error: {}", e),
719                    )
720                })?;
721
722            Ok(RRef::new(actual_rref_id, to))
723        }
724        Err(error_msg) => Err(TorshDistributedError::communication_error(
725            "rpc_remote",
726            format!("Remote function error: {}", error_msg),
727        )),
728    }
729}
730
731/// Delete a remote reference
732pub async fn delete_rref<T>(rref: RRef<T>) -> TorshResult<()> {
733    let worker_arc = get_rpc_worker()?;
734    let connections = {
735        let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
736        let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
737        worker_ref.connections.clone()
738    };
739
740    let message = RpcMessage::DeleteRRef {
741        rref_id: rref.id().to_string(),
742    };
743
744    let message_bytes = oxicode::serde::encode_to_vec(&message, oxicode::config::standard())
745        .map_err(|e| {
746            TorshDistributedError::communication_error("rpc", format!("Serialization error: {}", e))
747        })?;
748
749    let mut connections_guard = connections.write().await;
750    if let Some(connection) = connections_guard.get_mut(&rref.owner_rank()) {
751        connection.write_all(&message_bytes).await.map_err(|e| {
752            TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
753        })?;
754    }
755
756    Ok(())
757}
758
759/// Check if RPC framework is initialized
760pub fn is_initialized() -> bool {
761    #[cfg(test)]
762    {
763        let local_worker = TEST_RPC_WORKER.with(|w| w.borrow().clone());
764        if local_worker.is_some() {
765            return true;
766        }
767    }
768
769    RPC_WORKER.get().is_some()
770}
771
/// Reset RPC framework (for testing only)
///
/// Clears this thread's test-local worker so the next test starts fresh;
/// does not touch the process-global `RPC_WORKER`.
#[cfg(test)]
pub fn reset_rpc() {
    TEST_RPC_WORKER.with(|w| *w.borrow_mut() = None);
}
777
778/// Get current worker rank
779pub fn get_worker_rank() -> TorshResult<u32> {
780    let worker_arc = get_rpc_worker()?;
781    let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
782    let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
783    Ok(worker_ref.rank)
784}
785
786/// Get world size
787pub fn get_world_size() -> TorshResult<u32> {
788    let worker_arc = get_rpc_worker()?;
789    let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
790    let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
791    Ok(worker_ref.world_size)
792}
793
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Argument payload used by the registerable test functions below.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct TestArgs {
        x: i32,
        y: i32,
    }

    /// Result payload; `sum` carries whatever the test function computed.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct TestResult {
        sum: i32,
    }

    // Registerable test function: returns x + y.
    fn add_function(args: TestArgs) -> Result<TestResult, String> {
        Ok(TestResult {
            sum: args.x + args.y,
        })
    }

    // Registerable test function: returns x * y (still in `sum`).
    fn multiply_function(args: TestArgs) -> Result<TestResult, String> {
        Ok(TestResult {
            sum: args.x * args.y,
        })
    }

    // Single-worker end-to-end: init, query rank/world-size, shut down.
    #[tokio::test]
    async fn test_rpc_initialization() -> TorshResult<()> {
        // Clear any thread-local worker left behind by another test.
        reset_rpc();

        let options = RpcBackendOptions::default();

        // Test initialization
        let result = init_rpc("test_worker", 0, 1, options).await;
        if let Err(e) = &result {
            info!("RPC initialization failed: {}", e);
        }
        assert!(result.is_ok());

        // Test that we can get worker info
        assert_eq!(get_worker_rank()?, 0);
        assert_eq!(get_world_size()?, 1);
        assert!(is_initialized());

        // Clean up
        shutdown().await?;
        reset_rpc();

        Ok(())
    }

    // Registering two functions makes both visible in the registry.
    #[tokio::test]
    async fn test_function_registration() -> TorshResult<()> {
        reset_rpc();

        let options = RpcBackendOptions::default();
        init_rpc("test_worker", 0, 1, options).await?;

        // Register a function
        register_function("add", add_function).await?;
        register_function("multiply", multiply_function).await?;

        // Verify functions are registered
        let function_registry = {
            let worker_arc = get_rpc_worker()?;
            let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
            let worker_ref = worker_guard.as_ref().unwrap();
            worker_ref.function_registry.clone()
        }; // Guard dropped here

        let registry = function_registry.read().await;
        assert!(registry.contains_key("add"));
        assert!(registry.contains_key("multiply"));
        drop(registry); // Release the registry lock

        shutdown().await?;
        reset_rpc();

        Ok(())
    }

    // Round-trips an RpcMessage through oxicode and compares field-by-field.
    #[tokio::test]
    async fn test_rpc_message_serialization() -> TorshResult<()> {
        let message = RpcMessage::FunctionCall {
            id: "test-123".to_string(),
            function_name: "add".to_string(),
            args: vec![1, 2, 3, 4],
        };

        // Test serialization
        let serialized =
            oxicode::serde::encode_to_vec(&message, oxicode::config::standard()).unwrap();

        // Test deserialization
        let (deserialized, _): (RpcMessage, usize) =
            oxicode::serde::decode_from_slice(&serialized, oxicode::config::standard()).unwrap();

        match (message, deserialized) {
            (
                RpcMessage::FunctionCall {
                    id: id1,
                    function_name: fn1,
                    args: args1,
                },
                RpcMessage::FunctionCall {
                    id: id2,
                    function_name: fn2,
                    args: args2,
                },
            ) => {
                assert_eq!(id1, id2);
                assert_eq!(fn1, fn2);
                assert_eq!(args1, args2);
            }
            _ => panic!("Message types don't match"),
        }

        Ok(())
    }

    // RRef accessors reflect the constructor arguments.
    #[tokio::test]
    async fn test_rref_creation() -> TorshResult<()> {
        let rref: RRef<TestResult> = RRef::new("test-id".to_string(), 42);

        assert_eq!(rref.id(), "test-id");
        assert_eq!(rref.owner_rank(), 42);

        Ok(())
    }

    // Default and custom option values are stored exactly as given.
    #[tokio::test]
    async fn test_rpc_backend_options() {
        let default_options = RpcBackendOptions::default();
        assert_eq!(default_options.num_worker_threads, 4);
        assert_eq!(default_options.rpc_timeout, Duration::from_secs(60));
        assert_eq!(default_options.init_method, "env://");
        assert_eq!(default_options.buffer_size, 8192);
        assert_eq!(default_options.max_connections, 100);

        let custom_options = RpcBackendOptions {
            num_worker_threads: 8,
            rpc_timeout: Duration::from_secs(30),
            init_method: "file://".to_string(),
            buffer_size: 4096,
            max_connections: 50,
        };

        assert_eq!(custom_options.num_worker_threads, 8);
        assert_eq!(custom_options.rpc_timeout, Duration::from_secs(30));
        assert_eq!(custom_options.init_method, "file://");
        assert_eq!(custom_options.buffer_size, 4096);
        assert_eq!(custom_options.max_connections, 50);
    }

    // With no init on this thread (and no global), nothing is initialized.
    #[test]
    fn test_rpc_not_initialized() {
        // Without initialization, we should get errors
        assert!(!is_initialized());
    }

    // Shutdown after init + registration completes without error.
    #[tokio::test]
    async fn test_rpc_shutdown_cleanup() -> TorshResult<()> {
        reset_rpc();

        let options = RpcBackendOptions::default();
        init_rpc("test_worker", 0, 1, options).await?;

        assert!(is_initialized());

        // Register some functions and remote refs
        register_function("add", add_function).await?;

        // Shutdown should clean everything up
        shutdown().await?;
        reset_rpc();

        Ok(())
    }
}