1use crate::{TorshDistributedError, TorshResult};
8use log::info;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex};
12use std::time::Duration;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::net::{TcpListener, TcpStream};
15use tokio::sync::{mpsc, oneshot, RwLock};
16use uuid::Uuid;
17
/// In-flight outbound requests keyed by request id; each oneshot sender
/// completes the waiting caller with a serialized result or an error string.
type PendingRequestMap = Arc<Mutex<HashMap<String, oneshot::Sender<Result<Vec<u8>, String>>>>>;
/// Remotely callable functions keyed by name; each function consumes
/// serialized argument bytes and produces serialized result bytes.
type FunctionRegistry =
    Arc<RwLock<HashMap<String, Box<dyn Fn(&[u8]) -> Result<Vec<u8>, String> + Send + Sync>>>>;
22
/// Configuration knobs for the RPC backend.
///
/// NOTE(review): `init_rpc` currently receives these as `_options` and ignores
/// them — timeouts and buffer sizes are fixed constants in this file.
#[derive(Debug, Clone)]
pub struct RpcBackendOptions {
    /// Number of worker threads intended to service RPC requests.
    pub num_worker_threads: usize,
    /// Intended per-call response timeout.
    pub rpc_timeout: Duration,
    /// Rendezvous / initialization method string (e.g. "env://").
    pub init_method: String,
    /// Intended per-connection I/O buffer size in bytes.
    pub buffer_size: usize,
    /// Intended cap on simultaneous connections.
    pub max_connections: usize,
}

impl Default for RpcBackendOptions {
    /// Defaults: 4 worker threads, 60 s timeout, "env://" rendezvous,
    /// 8 KiB buffers, at most 100 connections.
    fn default() -> Self {
        RpcBackendOptions {
            num_worker_threads: 4,
            rpc_timeout: Duration::from_secs(60),
            init_method: "env://".to_owned(),
            buffer_size: 8192,
            max_connections: 100,
        }
    }
}
44
/// Wire messages exchanged between RPC workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RpcMessage {
    /// Request: invoke the named registered function with serialized `args`.
    FunctionCall {
        // Correlation id echoed back in the matching response.
        id: String,
        function_name: String,
        args: Vec<u8>,
    },
    /// Response to `FunctionCall`: serialized return value or an error string.
    FunctionResponse {
        id: String,
        result: Result<Vec<u8>, String>,
    },
    /// Request: invoke a function remotely and keep the result on the callee,
    /// stored under `rref_id`.
    RemoteRef {
        id: String,
        function_name: String,
        args: Vec<u8>,
        rref_id: String,
    },
    /// Response to `RemoteRef`: the rref id on success, or an error string.
    RemoteRefResponse {
        id: String,
        result: Result<String, String>,
    },
    /// Ask the owner to drop the value stored under `rref_id`. No reply.
    DeleteRRef { rref_id: String },
    /// Liveness probe.
    Ping,
    /// Reply to `Ping`.
    Pong,
}
78
/// A typed handle to a value that lives on the worker identified by
/// `owner_rank`. Only the id and owner travel with the handle; the value
/// itself stays on the owning worker.
#[derive(Debug, Clone)]
pub struct RRef<T> {
    id: String,
    owner_rank: u32,
    _phantom: std::marker::PhantomData<T>,
}

impl<T> RRef<T> {
    /// Creates a handle for the remote value registered under `id` on
    /// worker `owner_rank`.
    pub fn new(id: String, owner_rank: u32) -> Self {
        RRef {
            id,
            owner_rank,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Unique identifier of the referenced remote value.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Rank of the worker that owns the underlying value.
    pub fn owner_rank(&self) -> u32 {
        self.owner_rank
    }
}
104
/// Mutable state for this process's RPC endpoint.
struct RpcWorker {
    /// This worker's rank within the group.
    rank: u32,
    /// Total number of workers in the group.
    world_size: u32,
    /// Outbound TCP connections to other ranks, keyed by rank.
    connections: Arc<RwLock<HashMap<u32, TcpStream>>>,
    /// Outbound requests awaiting responses, keyed by request id.
    pending_requests: PendingRequestMap,
    /// Values held locally on behalf of remote `RRef` holders, keyed by rref id.
    remote_refs: Arc<RwLock<HashMap<String, Box<dyn std::any::Any + Send + Sync>>>>,
    /// Functions callable by remote workers.
    function_registry: FunctionRegistry,
    /// Signals the accept loop to stop; populated by `init_rpc`.
    shutdown_tx: Option<mpsc::Sender<()>>,
}
115
/// Process-wide RPC worker singleton, set exactly once by `init_rpc` in
/// non-test builds (OnceCell cannot be reset afterwards).
static RPC_WORKER: once_cell::sync::OnceCell<Arc<Mutex<Option<RpcWorker>>>> =
    once_cell::sync::OnceCell::new();

/// Test-only, per-thread worker slot so tests running on different threads can
/// each initialize their own worker without contending on the global OnceCell.
#[cfg(test)]
thread_local! {
    static TEST_RPC_WORKER: std::cell::RefCell<Option<Arc<Mutex<Option<RpcWorker>>>>> = const { std::cell::RefCell::new(None) };
}
125
126fn get_rpc_worker() -> TorshResult<Arc<Mutex<Option<RpcWorker>>>> {
128 #[cfg(test)]
129 {
130 let local_worker = TEST_RPC_WORKER.with(|w| w.borrow().clone());
132 if let Some(worker) = local_worker {
133 return Ok(worker);
134 }
135 }
136
137 RPC_WORKER
138 .get()
139 .ok_or(TorshDistributedError::BackendNotInitialized)
140 .cloned()
141}
142
143pub async fn init_rpc(
145 name: &str,
146 rank: u32,
147 world_size: u32,
148 _options: RpcBackendOptions,
149) -> TorshResult<()> {
150 let worker = RpcWorker {
152 rank,
153 world_size,
154 connections: Arc::new(RwLock::new(HashMap::new())),
155 pending_requests: Arc::new(Mutex::new(HashMap::new())),
156 remote_refs: Arc::new(RwLock::new(HashMap::new())),
157 function_registry: Arc::new(RwLock::new(HashMap::new())),
158 shutdown_tx: None,
159 };
160
161 let worker_arc = Arc::new(Mutex::new(Some(worker)));
162
163 #[cfg(test)]
165 {
166 TEST_RPC_WORKER.with(|w| *w.borrow_mut() = Some(worker_arc.clone()));
167 }
168
169 #[cfg(not(test))]
170 {
171 RPC_WORKER
172 .set(worker_arc.clone())
173 .map_err(|_| TorshDistributedError::backend_error("rpc", "RPC already initialized"))?;
174 }
175
176 let base_port = if cfg!(test) {
178 29600 + (std::process::id() % 1000) + rank * 100
180 } else {
181 29600 + rank
182 };
183
184 let mut listener = None;
186 for port_offset in 0..10 {
187 let listen_addr = format!("127.0.0.1:{}", base_port + port_offset);
188 match TcpListener::bind(&listen_addr).await {
189 Ok(l) => {
190 info!(
191 "[RPC] Worker '{}' (rank {}) starting on {}",
192 name, rank, listen_addr
193 );
194 listener = Some(l);
195 break;
196 }
197 Err(e) => {
198 if port_offset == 9 {
199 return Err(TorshDistributedError::communication_error(
200 "rpc_server",
201 format!("Failed to bind after trying 10 ports: {}", e),
202 ));
203 }
204 }
205 }
206 }
207
208 let listener = listener.expect("listener should be successfully bound");
209
210 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
211
212 {
214 let mut worker_guard = worker_arc.lock().expect("lock should not be poisoned");
215 if let Some(ref mut worker) = *worker_guard {
216 worker.shutdown_tx = Some(shutdown_tx);
217 }
218 }
219
220 let worker_for_server = worker_arc.clone();
222 tokio::spawn(async move {
223 loop {
224 tokio::select! {
225 result = listener.accept() => {
226 match result {
227 Ok((stream, addr)) => {
228 info!("[RPC] Accepted connection from {}", addr);
229 let worker_clone = worker_for_server.clone();
230 tokio::spawn(handle_connection(stream, worker_clone));
231 }
232 Err(e) => {
233 info!("[RPC] Failed to accept connection: {}", e);
234 }
235 }
236 }
237 _ = shutdown_rx.recv() => {
238 info!("[RPC] Server shutting down");
239 break;
240 }
241 }
242 }
243 });
244
245 tokio::time::sleep(Duration::from_millis(100)).await;
247
248 if world_size > 1 {
250 for other_rank in 0..world_size {
251 if other_rank != rank {
252 let target_addr = format!("127.0.0.1:{}", base_port + other_rank);
253
254 let mut retries = 0;
256 let max_retries = if cfg!(test) { 3 } else { 10 }; while retries < max_retries {
259 match TcpStream::connect(&target_addr).await {
260 Ok(stream) => {
261 info!(
262 "[RPC] Connected to worker {} at {}",
263 other_rank, target_addr
264 );
265 let connections = {
266 let worker_guard =
267 worker_arc.lock().expect("lock should not be poisoned");
268 worker_guard
269 .as_ref()
270 .expect("worker should be initialized")
271 .connections
272 .clone()
273 };
274 let mut connections_guard = connections.write().await;
275 connections_guard.insert(other_rank, stream);
276 break;
277 }
278 Err(e) => {
279 retries += 1;
280 let delay = Duration::from_millis(100 * (1 << retries.min(3)));
281 tokio::time::sleep(delay).await;
282 if retries == max_retries {
283 if cfg!(test) {
285 info!(
286 "[RPC] Failed to connect to worker {} (test mode): {}",
287 other_rank, e
288 );
289 break;
290 } else {
291 return Err(TorshDistributedError::communication_error(
292 "rpc_connect",
293 format!(
294 "Failed to connect to worker {}: {}",
295 other_rank, e
296 ),
297 ));
298 }
299 }
300 }
301 }
302 }
303 }
304 }
305 }
306
307 info!(
308 "[RPC] Worker '{}' (rank {}) initialized successfully",
309 name, rank
310 );
311 Ok(())
312}
313
/// Per-connection receive loop: reads bytes from `stream`, decodes each read
/// as one `RpcMessage`, and dispatches it via `handle_rpc_message`.
///
/// NOTE(review): there is no length framing — this assumes every `read()`
/// returns exactly one complete encoded message and that messages fit in the
/// fixed 8 KiB buffer. Fragmented or coalesced TCP reads would fail to decode;
/// fixing that requires adding framing on the write side too — confirm the
/// intended wire protocol before changing.
async fn handle_connection(mut stream: TcpStream, worker: Arc<Mutex<Option<RpcWorker>>>) {
    let mut buffer = vec![0u8; 8192];

    loop {
        match stream.read(&mut buffer).await {
            // A zero-length read means the peer closed the connection.
            Ok(0) => break,
            Ok(n) => {
                let data = &buffer[..n];

                // decode_from_slice also reports bytes consumed; ignored here.
                let result: Result<(RpcMessage, usize), _> =
                    oxicode::serde::decode_from_slice(data, oxicode::config::standard());
                match result {
                    Ok((message, _)) => {
                        // Per-message errors are logged; the connection stays up.
                        if let Err(e) = handle_rpc_message(message, &mut stream, &worker).await {
                            info!("[RPC] Error handling message: {}", e);
                        }
                    }
                    Err(e) => {
                        info!("[RPC] Failed to deserialize message: {}", e);
                    }
                }
            }
            Err(e) => {
                info!("[RPC] Connection error: {}", e);
                break;
            }
        }
    }
}
345
346async fn handle_rpc_message(
348 message: RpcMessage,
349 stream: &mut TcpStream,
350 worker: &Arc<Mutex<Option<RpcWorker>>>,
351) -> TorshResult<()> {
352 match message {
353 RpcMessage::FunctionCall {
354 id,
355 function_name,
356 args,
357 } => {
358 let function_registry = {
360 let worker_guard = worker.lock().expect("lock should not be poisoned");
361 let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
362 worker_ref.function_registry.clone()
363 };
364
365 let result = {
366 let registry = function_registry.read().await;
367 if let Some(func) = registry.get(&function_name) {
368 func(&args)
369 } else {
370 Err(format!("Function '{}' not found", function_name))
371 }
372 };
373
374 let response = RpcMessage::FunctionResponse { id, result };
375 let response_data =
376 oxicode::serde::encode_to_vec(&response, oxicode::config::standard()).map_err(
377 |e| {
378 TorshDistributedError::communication_error(
379 "rpc",
380 format!("Serialization error: {}", e),
381 )
382 },
383 )?;
384
385 stream.write_all(&response_data).await.map_err(|e| {
386 TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
387 })?;
388 }
389
390 RpcMessage::RemoteRef {
391 id,
392 function_name,
393 args,
394 rref_id,
395 } => {
396 let (function_registry, remote_refs) = {
398 let worker_guard = worker.lock().expect("lock should not be poisoned");
399 let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
400 (
401 worker_ref.function_registry.clone(),
402 worker_ref.remote_refs.clone(),
403 )
404 };
405
406 let result = {
407 let registry = function_registry.read().await;
408 if let Some(func) = registry.get(&function_name) {
409 match func(&args) {
410 Ok(result_data) => {
411 let mut refs = remote_refs.write().await;
413 refs.insert(rref_id.clone(), Box::new(result_data));
415 Ok(rref_id)
416 }
417 Err(e) => Err(e),
418 }
419 } else {
420 Err(format!("Function '{}' not found", function_name))
421 }
422 };
423
424 let response = RpcMessage::RemoteRefResponse { id, result };
425 let response_data =
426 oxicode::serde::encode_to_vec(&response, oxicode::config::standard()).map_err(
427 |e| {
428 TorshDistributedError::communication_error(
429 "rpc",
430 format!("Serialization error: {}", e),
431 )
432 },
433 )?;
434
435 stream.write_all(&response_data).await.map_err(|e| {
436 TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
437 })?;
438 }
439
440 RpcMessage::DeleteRRef { rref_id } => {
441 let remote_refs = {
442 let worker_guard = worker.lock().expect("lock should not be poisoned");
443 let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
444 worker_ref.remote_refs.clone()
445 };
446
447 let mut refs = remote_refs.write().await;
448 refs.remove(&rref_id);
449 }
450
451 RpcMessage::Ping => {
452 let response = RpcMessage::Pong;
453 let response_data =
454 oxicode::serde::encode_to_vec(&response, oxicode::config::standard()).map_err(
455 |e| {
456 TorshDistributedError::communication_error(
457 "rpc",
458 format!("Serialization error: {}", e),
459 )
460 },
461 )?;
462
463 stream.write_all(&response_data).await.map_err(|e| {
464 TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
465 })?;
466 }
467
468 _ => {
469 if let RpcMessage::FunctionResponse { id, result } = message {
471 let worker_guard = worker.lock().expect("lock should not be poisoned");
472 let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
473 let mut pending = worker_ref
474 .pending_requests
475 .lock()
476 .expect("lock should not be poisoned");
477
478 if let Some(sender) = pending.remove(&id) {
479 let _ = sender.send(result);
480 }
481 }
482 }
483 }
484
485 Ok(())
486}
487
488pub async fn shutdown() -> TorshResult<()> {
490 let worker_arc = get_rpc_worker()?;
491
492 let (shutdown_tx, remote_refs) = {
493 let mut worker_guard = worker_arc.lock().expect("lock should not be poisoned");
494 if let Some(worker) = worker_guard.take() {
495 (worker.shutdown_tx, Some(worker.remote_refs))
496 } else {
497 (None, None)
498 }
499 };
500
501 if let Some(shutdown_tx) = shutdown_tx {
502 let _ = shutdown_tx.send(()).await;
503 }
504
505 if let Some(remote_refs) = remote_refs {
506 let mut refs = remote_refs.write().await;
508 refs.clear();
509 }
510
511 info!("[RPC] Framework shut down successfully");
512 Ok(())
513}
514
515pub async fn register_function<F, Args, Ret>(name: &str, func: F) -> TorshResult<()>
517where
518 F: Fn(Args) -> Result<Ret, String> + Send + Sync + 'static,
519 Args: for<'de> Deserialize<'de> + 'static,
520 Ret: Serialize + 'static,
521{
522 let worker_arc = get_rpc_worker()?;
523 let function_registry = {
524 let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
525 let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
526 worker_ref.function_registry.clone()
527 };
528
529 let wrapper = move |args_bytes: &[u8]| -> Result<Vec<u8>, String> {
530 let (args, _): (Args, usize) =
531 oxicode::serde::decode_from_slice(args_bytes, oxicode::config::standard())
532 .map_err(|e| format!("Deserialization error: {}", e))?;
533
534 let result = func(args)?;
535
536 oxicode::serde::encode_to_vec(&result, oxicode::config::standard())
537 .map_err(|e| format!("Serialization error: {}", e))
538 };
539
540 let mut registry = function_registry.write().await;
541 registry.insert(name.to_string(), Box::new(wrapper));
542
543 Ok(())
544}
545
546pub async fn rpc_async<Args, Ret>(to: u32, function_name: &str, args: Args) -> TorshResult<Ret>
548where
549 Args: Serialize,
550 Ret: for<'de> Deserialize<'de>,
551{
552 let worker_arc = get_rpc_worker()?;
553
554 let args_bytes =
556 oxicode::serde::encode_to_vec(&args, oxicode::config::standard()).map_err(|e| {
557 TorshDistributedError::communication_error("rpc", format!("Serialization error: {}", e))
558 })?;
559
560 let request_id = Uuid::new_v4().to_string();
562
563 let message = RpcMessage::FunctionCall {
565 id: request_id.clone(),
566 function_name: function_name.to_string(),
567 args: args_bytes,
568 };
569
570 let message_bytes = oxicode::serde::encode_to_vec(&message, oxicode::config::standard())
572 .map_err(|e| {
573 TorshDistributedError::communication_error("rpc", format!("Serialization error: {}", e))
574 })?;
575
576 let (connections, pending_requests) = {
578 let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
579 let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
580 (
581 worker_ref.connections.clone(),
582 worker_ref.pending_requests.clone(),
583 )
584 };
585
586 let (response_tx, response_rx) = oneshot::channel();
588 {
589 let mut pending = pending_requests
590 .lock()
591 .expect("lock should not be poisoned");
592 pending.insert(request_id, response_tx);
593 }
594
595 {
597 let mut connections_guard = connections.write().await;
598 let connection = connections_guard.get_mut(&to).ok_or_else(|| {
599 TorshDistributedError::communication_error(
600 "rpc",
601 format!("No connection to worker {}", to),
602 )
603 })?;
604
605 connection.write_all(&message_bytes).await.map_err(|e| {
606 TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
607 })?;
608 }
609
610 let result = tokio::time::timeout(Duration::from_secs(60), response_rx)
612 .await
613 .map_err(|_| TorshDistributedError::communication_error("rpc", "RPC timeout"))?
614 .map_err(|_| {
615 TorshDistributedError::communication_error("rpc", "Response channel closed")
616 })?;
617
618 match result {
619 Ok(result_bytes) => {
620 let (value, _): (Ret, usize) =
621 oxicode::serde::decode_from_slice(&result_bytes, oxicode::config::standard())
622 .map_err(|e| {
623 TorshDistributedError::communication_error(
624 "rpc",
625 format!("Deserialization error: {}", e),
626 )
627 })?;
628 Ok(value)
629 }
630 Err(error_msg) => Err(TorshDistributedError::communication_error(
631 "rpc_remote",
632 format!("Remote function error: {}", error_msg),
633 )),
634 }
635}
636
637pub async fn remote<Args, Ret>(to: u32, function_name: &str, args: Args) -> TorshResult<RRef<Ret>>
639where
640 Args: Serialize,
641 Ret: for<'de> Deserialize<'de> + 'static,
642{
643 let worker_arc = get_rpc_worker()?;
644
645 let args_bytes =
647 oxicode::serde::encode_to_vec(&args, oxicode::config::standard()).map_err(|e| {
648 TorshDistributedError::communication_error("rpc", format!("Serialization error: {}", e))
649 })?;
650
651 let request_id = Uuid::new_v4().to_string();
653 let rref_id = Uuid::new_v4().to_string();
654
655 let message = RpcMessage::RemoteRef {
657 id: request_id.clone(),
658 function_name: function_name.to_string(),
659 args: args_bytes,
660 rref_id: rref_id.clone(),
661 };
662
663 let message_bytes = oxicode::serde::encode_to_vec(&message, oxicode::config::standard())
665 .map_err(|e| {
666 TorshDistributedError::communication_error("rpc", format!("Serialization error: {}", e))
667 })?;
668
669 let (connections, pending_requests) = {
671 let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
672 let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
673 (
674 worker_ref.connections.clone(),
675 worker_ref.pending_requests.clone(),
676 )
677 };
678
679 let (response_tx, response_rx) = oneshot::channel();
681 {
682 let mut pending = pending_requests
683 .lock()
684 .expect("lock should not be poisoned");
685 pending.insert(request_id, response_tx);
686 }
687
688 {
690 let mut connections_guard = connections.write().await;
691 let connection = connections_guard.get_mut(&to).ok_or_else(|| {
692 TorshDistributedError::communication_error(
693 "rpc",
694 format!("No connection to worker {}", to),
695 )
696 })?;
697
698 connection.write_all(&message_bytes).await.map_err(|e| {
699 TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
700 })?;
701 }
702
703 let result = tokio::time::timeout(Duration::from_secs(60), response_rx)
705 .await
706 .map_err(|_| TorshDistributedError::communication_error("rpc", "RPC timeout"))?
707 .map_err(|_| {
708 TorshDistributedError::communication_error("rpc", "Response channel closed")
709 })?;
710
711 match result {
712 Ok(returned_rref_id) => {
713 let (actual_rref_id, _): (String, usize) =
714 oxicode::serde::decode_from_slice(&returned_rref_id, oxicode::config::standard())
715 .map_err(|e| {
716 TorshDistributedError::communication_error(
717 "rpc",
718 format!("Deserialization error: {}", e),
719 )
720 })?;
721
722 Ok(RRef::new(actual_rref_id, to))
723 }
724 Err(error_msg) => Err(TorshDistributedError::communication_error(
725 "rpc_remote",
726 format!("Remote function error: {}", error_msg),
727 )),
728 }
729}
730
731pub async fn delete_rref<T>(rref: RRef<T>) -> TorshResult<()> {
733 let worker_arc = get_rpc_worker()?;
734 let connections = {
735 let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
736 let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
737 worker_ref.connections.clone()
738 };
739
740 let message = RpcMessage::DeleteRRef {
741 rref_id: rref.id().to_string(),
742 };
743
744 let message_bytes = oxicode::serde::encode_to_vec(&message, oxicode::config::standard())
745 .map_err(|e| {
746 TorshDistributedError::communication_error("rpc", format!("Serialization error: {}", e))
747 })?;
748
749 let mut connections_guard = connections.write().await;
750 if let Some(connection) = connections_guard.get_mut(&rref.owner_rank()) {
751 connection.write_all(&message_bytes).await.map_err(|e| {
752 TorshDistributedError::communication_error("rpc", format!("Write error: {}", e))
753 })?;
754 }
755
756 Ok(())
757}
758
759pub fn is_initialized() -> bool {
761 #[cfg(test)]
762 {
763 let local_worker = TEST_RPC_WORKER.with(|w| w.borrow().clone());
764 if local_worker.is_some() {
765 return true;
766 }
767 }
768
769 RPC_WORKER.get().is_some()
770}
771
/// Test-only helper: clears this thread's worker slot so the next test can
/// call `init_rpc` from a clean state.
#[cfg(test)]
pub fn reset_rpc() {
    TEST_RPC_WORKER.with(|w| *w.borrow_mut() = None);
}
777
778pub fn get_worker_rank() -> TorshResult<u32> {
780 let worker_arc = get_rpc_worker()?;
781 let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
782 let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
783 Ok(worker_ref.rank)
784}
785
786pub fn get_world_size() -> TorshResult<u32> {
788 let worker_arc = get_rpc_worker()?;
789 let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
790 let worker_ref = worker_guard.as_ref().expect("worker should be initialized");
791 Ok(worker_ref.world_size)
792}
793
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    // Serializable argument pair consumed by the test RPC functions.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct TestArgs {
        x: i32,
        y: i32,
    }

    // Serializable result wrapper produced by the test RPC functions.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct TestResult {
        sum: i32,
    }

    // "add" handler: sums the two arguments.
    fn add_function(args: TestArgs) -> Result<TestResult, String> {
        Ok(TestResult {
            sum: args.x + args.y,
        })
    }

    // "multiply" handler: multiplies the two arguments (the product is still
    // carried in the `sum` field).
    fn multiply_function(args: TestArgs) -> Result<TestResult, String> {
        Ok(TestResult {
            sum: args.x * args.y,
        })
    }

    // Single-worker init/shutdown round trip plus accessor checks.
    #[tokio::test]
    async fn test_rpc_initialization() -> TorshResult<()> {
        reset_rpc();

        let options = RpcBackendOptions::default();

        let result = init_rpc("test_worker", 0, 1, options).await;
        if let Err(e) = &result {
            info!("RPC initialization failed: {}", e);
        }
        assert!(result.is_ok());

        assert_eq!(get_worker_rank()?, 0);
        assert_eq!(get_world_size()?, 1);
        assert!(is_initialized());

        shutdown().await?;
        reset_rpc();

        Ok(())
    }

    // Registered functions must appear in the worker's function registry.
    #[tokio::test]
    async fn test_function_registration() -> TorshResult<()> {
        reset_rpc();

        let options = RpcBackendOptions::default();
        init_rpc("test_worker", 0, 1, options).await?;

        register_function("add", add_function).await?;
        register_function("multiply", multiply_function).await?;

        let function_registry = {
            let worker_arc = get_rpc_worker()?;
            let worker_guard = worker_arc.lock().expect("lock should not be poisoned");
            let worker_ref = worker_guard.as_ref().unwrap();
            worker_ref.function_registry.clone()
        };
        let registry = function_registry.read().await;
        assert!(registry.contains_key("add"));
        assert!(registry.contains_key("multiply"));
        // Release the read lock before shutting down.
        drop(registry);
        shutdown().await?;
        reset_rpc();

        Ok(())
    }

    // Round-trips a FunctionCall message through oxicode and compares fields.
    #[tokio::test]
    async fn test_rpc_message_serialization() -> TorshResult<()> {
        let message = RpcMessage::FunctionCall {
            id: "test-123".to_string(),
            function_name: "add".to_string(),
            args: vec![1, 2, 3, 4],
        };

        let serialized =
            oxicode::serde::encode_to_vec(&message, oxicode::config::standard()).unwrap();

        let (deserialized, _): (RpcMessage, usize) =
            oxicode::serde::decode_from_slice(&serialized, oxicode::config::standard()).unwrap();

        match (message, deserialized) {
            (
                RpcMessage::FunctionCall {
                    id: id1,
                    function_name: fn1,
                    args: args1,
                },
                RpcMessage::FunctionCall {
                    id: id2,
                    function_name: fn2,
                    args: args2,
                },
            ) => {
                assert_eq!(id1, id2);
                assert_eq!(fn1, fn2);
                assert_eq!(args1, args2);
            }
            _ => panic!("Message types don't match"),
        }

        Ok(())
    }

    // RRef accessors return what the constructor was given.
    #[tokio::test]
    async fn test_rref_creation() -> TorshResult<()> {
        let rref: RRef<TestResult> = RRef::new("test-id".to_string(), 42);

        assert_eq!(rref.id(), "test-id");
        assert_eq!(rref.owner_rank(), 42);

        Ok(())
    }

    // Default and custom RpcBackendOptions expose the expected values.
    #[tokio::test]
    async fn test_rpc_backend_options() {
        let default_options = RpcBackendOptions::default();
        assert_eq!(default_options.num_worker_threads, 4);
        assert_eq!(default_options.rpc_timeout, Duration::from_secs(60));
        assert_eq!(default_options.init_method, "env://");
        assert_eq!(default_options.buffer_size, 8192);
        assert_eq!(default_options.max_connections, 100);

        let custom_options = RpcBackendOptions {
            num_worker_threads: 8,
            rpc_timeout: Duration::from_secs(30),
            init_method: "file://".to_string(),
            buffer_size: 4096,
            max_connections: 50,
        };

        assert_eq!(custom_options.num_worker_threads, 8);
        assert_eq!(custom_options.rpc_timeout, Duration::from_secs(30));
        assert_eq!(custom_options.init_method, "file://");
        assert_eq!(custom_options.buffer_size, 4096);
        assert_eq!(custom_options.max_connections, 50);
    }

    // Without init_rpc on this thread, the framework reports uninitialized.
    #[test]
    fn test_rpc_not_initialized() {
        assert!(!is_initialized());
    }

    // Shutdown after registering functions must complete without error.
    #[tokio::test]
    async fn test_rpc_shutdown_cleanup() -> TorshResult<()> {
        reset_rpc();

        let options = RpcBackendOptions::default();
        init_rpc("test_worker", 0, 1, options).await?;

        assert!(is_initialized());

        register_function("add", add_function).await?;

        shutdown().await?;
        reset_rpc();

        Ok(())
    }
}