Skip to main content

rpc_runtime_client/
lib.rs

1use std::collections::{BTreeMap, HashMap};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Duration;
5
6use rmpv::Value;
7use rpc_runtime_activation::{
8    CREATE_INSTANCE_METHOD_ID, CreateInstanceRequest, LIST_INSTANCES_METHOD_ID,
9    ListInstancesRequest, RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID,
10    ReleaseInstanceRequest, ResolveInstanceIdsRequest, activation_instance_id,
11    decode_create_instance_response, decode_list_instances_response,
12    decode_release_instance_response, decode_resolve_instance_ids_response,
13    encode_create_instance_request, encode_list_instances_request, encode_release_instance_request,
14    encode_resolve_instance_ids_request,
15};
16use rpc_runtime_core::{
17    CapabilityFlags, Envelope, Hello, InstanceId, MethodId, Notification, RUNTIME_PROTOCOL_VERSION,
18    Request, RequestId, Role, ServiceGuid,
19};
20use rpc_runtime_errors::{ErrorKind, RuntimeError, RuntimeErrorCode};
21use rpc_runtime_transport::{RpcConnection, RpcReceiver, RpcSender};
22use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection, IpcEndpoint};
23use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
24
/// Cloneable client handle for an RPC runtime connection.
///
/// Cloning is cheap: every clone shares the same inner state behind an [`Arc`], i.e. the same
/// connection, request-id counter, pending-response table and notification fan-out.
#[derive(Clone)]
pub struct RpcClient {
    inner: Arc<ClientInner>,
}
29
/// Connection state shared by all clones of `RpcClient` and by the background receive loop.
struct ClientInner {
    /// Write half of the connection; requests and GOODBYE are sent through it.
    sender: RpcSender,
    /// Source of request ids; initialised to 1 when the client is created.
    next_request_id: AtomicU64,
    /// Response slots for in-flight requests, keyed by raw request id. The receive loop
    /// completes individual slots (`complete_pending`) or fails all of them when it stops
    /// (`fail_pending`).
    pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, RuntimeError>>>>,
    /// Fan-out of server notifications: the receive loop publishes into it and
    /// `RpcClient::subscribe_notifications` subscribes to it.
    notifications: broadcast::Sender<Notification>,
}
36
37impl RpcClient {
38    pub async fn connect(endpoint: IpcEndpoint, config: FrameConfig) -> Result<Self, RuntimeError> {
39        let connection = IpcConnection::connect(endpoint, config)
40            .await
41            .map_err(|err| {
42                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
43            })?;
44        Self::from_connection(connection).await
45    }
46
47    pub async fn from_connection<C>(connection: C) -> Result<Self, RuntimeError>
48    where
49        C: Into<RpcConnection>,
50    {
51        let (sender, mut receiver) = connection.into().split();
52        sender
53            .send_envelope(&Envelope::Hello(Hello {
54                protocol_version: RUNTIME_PROTOCOL_VERSION,
55                role: Role::Client,
56                capability_bits: client_capabilities(),
57                max_message_size: rpc_runtime_codec_msgpack::DEFAULT_MAX_MESSAGE_SIZE as u64,
58                options: Vec::new(),
59            }))
60            .await
61            .map_err(|err| {
62                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
63            })?;
64
65        let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
66            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
67        })?
68        else {
69            return Err(RuntimeError::transport(
70                RuntimeErrorCode::InternalRuntimeError,
71                "server disconnected during handshake",
72            ));
73        };
74        let Envelope::HelloAck(ack) = envelope else {
75            return Err(RuntimeError::protocol(
76                RuntimeErrorCode::InvalidEnvelope,
77                "expected HELLO_ACK during handshake",
78            ));
79        };
80        if ack.protocol_version != RUNTIME_PROTOCOL_VERSION {
81            return Err(RuntimeError::protocol(
82                RuntimeErrorCode::UnsupportedProtocolVersion,
83                "server returned unsupported protocol version",
84            ));
85        }
86
87        let (notifications, _) = broadcast::channel(128);
88        let inner = Arc::new(ClientInner {
89            sender,
90            next_request_id: AtomicU64::new(1),
91            pending: Mutex::new(HashMap::new()),
92            notifications,
93        });
94        spawn_receive_loop(Arc::clone(&inner), receiver);
95        Ok(Self { inner })
96    }
97
98    pub async fn call(
99        &self,
100        instance_id: InstanceId,
101        method_id: MethodId,
102        payload: Value,
103    ) -> Result<Value, RuntimeError> {
104        let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
105        let (tx, rx) = oneshot::channel();
106        self.inner.pending.lock().await.insert(request_id, tx);
107
108        let send_result = self
109            .inner
110            .sender
111            .send_envelope(&Envelope::Request(Request {
112                request_id: RequestId::new(request_id),
113                instance_id,
114                method_id,
115                payload,
116            }))
117            .await;
118        if let Err(err) = send_result {
119            self.inner.pending.lock().await.remove(&request_id);
120            return Err(RuntimeError::transport(
121                RuntimeErrorCode::InternalRuntimeError,
122                err.to_string(),
123            ));
124        }
125
126        rx.await.map_err(|_| {
127            RuntimeError::transport(
128                RuntimeErrorCode::InternalRuntimeError,
129                "response channel closed before request completed",
130            )
131        })?
132    }
133
134    pub async fn call_timeout(
135        &self,
136        instance_id: InstanceId,
137        method_id: MethodId,
138        payload: Value,
139        timeout: Duration,
140    ) -> Result<Value, RuntimeError> {
141        tokio::time::timeout(timeout, self.call(instance_id, method_id, payload))
142            .await
143            .map_err(|_| {
144                RuntimeError::runtime(RuntimeErrorCode::RequestTimeout, "request timed out")
145            })?
146    }
147
148    pub async fn resolve_instance_ids(&self, names: Vec<String>) -> Result<Vec<u64>, RuntimeError> {
149        let response = self
150            .call(
151                activation_instance_id(),
152                MethodId::new(RESOLVE_INSTANCE_IDS_METHOD_ID),
153                encode_resolve_instance_ids_request(&ResolveInstanceIdsRequest {
154                    instance_names: names,
155                }),
156            )
157            .await?;
158        Ok(decode_resolve_instance_ids_response(&response)?.instance_ids)
159    }
160
161    pub async fn create_instance(
162        &self,
163        service_guid: ServiceGuid,
164        create_payload: Option<Vec<u8>>,
165        options: BTreeMap<String, String>,
166    ) -> Result<InstanceId, RuntimeError> {
167        let response = self
168            .call(
169                activation_instance_id(),
170                MethodId::new(CREATE_INSTANCE_METHOD_ID),
171                encode_create_instance_request(&CreateInstanceRequest {
172                    service_guid,
173                    create_payload,
174                    options,
175                }),
176            )
177            .await?;
178        Ok(decode_create_instance_response(&response)?.instance_id)
179    }
180
181    pub async fn release_instance(&self, instance_id: InstanceId) -> Result<(), RuntimeError> {
182        let response = self
183            .call(
184                activation_instance_id(),
185                MethodId::new(RELEASE_INSTANCE_METHOD_ID),
186                encode_release_instance_request(&ReleaseInstanceRequest { instance_id }),
187            )
188            .await?;
189        decode_release_instance_response(&response)?;
190        Ok(())
191    }
192
193    pub async fn list_instances(
194        &self,
195        service_guid: Option<ServiceGuid>,
196    ) -> Result<Vec<rpc_runtime_activation::InstanceDescriptor>, RuntimeError> {
197        let response = self
198            .call(
199                activation_instance_id(),
200                MethodId::new(LIST_INSTANCES_METHOD_ID),
201                encode_list_instances_request(&ListInstancesRequest { service_guid }),
202            )
203            .await?;
204        Ok(decode_list_instances_response(&response)?.instances)
205    }
206
207    pub fn subscribe_notifications(
208        &self,
209        instance_id_filter: Option<InstanceId>,
210        notification_id_filter: Option<u32>,
211    ) -> mpsc::UnboundedReceiver<Notification> {
212        let mut source = self.inner.notifications.subscribe();
213        let (tx, rx) = mpsc::unbounded_channel();
214        tokio::spawn(async move {
215            loop {
216                let Ok(notification) = source.recv().await else {
217                    break;
218                };
219                let instance_matches = instance_id_filter
220                    .is_none_or(|expected| notification.instance_id == Some(expected));
221                let notification_matches = notification_id_filter
222                    .is_none_or(|expected| notification.notification_id.get() == expected);
223                if instance_matches && notification_matches && tx.send(notification).is_err() {
224                    break;
225                }
226            }
227        });
228        rx
229    }
230
231    pub async fn goodbye(&self, message: impl Into<String>) -> Result<(), RuntimeError> {
232        self.inner
233            .sender
234            .send_envelope(&Envelope::Goodbye(rpc_runtime_core::Goodbye {
235                reason_code: 0,
236                message: Some(message.into()),
237            }))
238            .await
239            .map_err(|err| {
240                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
241            })
242    }
243}
244
245fn spawn_receive_loop(inner: Arc<ClientInner>, mut receiver: RpcReceiver) {
246    tokio::spawn(async move {
247        loop {
248            let envelope = match receiver.recv_envelope().await {
249                Ok(Some(envelope)) => envelope,
250                Ok(None) => {
251                    fail_pending(
252                        &inner,
253                        RuntimeError::transport(
254                            RuntimeErrorCode::InternalRuntimeError,
255                            "server disconnected",
256                        ),
257                    )
258                    .await;
259                    break;
260                }
261                Err(err) => {
262                    fail_pending(
263                        &inner,
264                        RuntimeError::transport(
265                            RuntimeErrorCode::InternalRuntimeError,
266                            err.to_string(),
267                        ),
268                    )
269                    .await;
270                    break;
271                }
272            };
273            match envelope {
274                Envelope::ResponseOk(response) => {
275                    complete_pending(&inner, response.request_id.get(), Ok(response.payload)).await;
276                }
277                Envelope::ResponseError(response) => {
278                    complete_pending(
279                        &inner,
280                        response.request_id.get(),
281                        Err(RuntimeError::new(
282                            runtime_error_code(response.error_code),
283                            error_kind(response.error_kind),
284                            response.error_message.unwrap_or_default(),
285                        )),
286                    )
287                    .await;
288                }
289                Envelope::Notification(notification) => {
290                    let _ = inner.notifications.send(notification);
291                }
292                _ => {
293                    fail_pending(
294                        &inner,
295                        RuntimeError::protocol(
296                            RuntimeErrorCode::InvalidEnvelope,
297                            "client received invalid envelope kind",
298                        ),
299                    )
300                    .await;
301                    break;
302                }
303            }
304        }
305    });
306}
307
308async fn complete_pending(
309    inner: &ClientInner,
310    request_id: u64,
311    result: Result<Value, RuntimeError>,
312) {
313    if let Some(sender) = inner.pending.lock().await.remove(&request_id) {
314        let _ = sender.send(result);
315    }
316}
317
318async fn fail_pending(inner: &ClientInner, error: RuntimeError) {
319    let pending = std::mem::take(&mut *inner.pending.lock().await);
320    for (_, sender) in pending {
321        let _ = sender.send(Err(error.clone()));
322    }
323}
324
325fn client_capabilities() -> CapabilityFlags {
326    CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
327        | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
328        | CapabilityFlags::SERVICE_ACTIVATION
329        | CapabilityFlags::GOODBYE
330}
331
332fn runtime_error_code(value: i32) -> RuntimeErrorCode {
333    match value {
334        1001 => RuntimeErrorCode::UnknownMessageKind,
335        1002 => RuntimeErrorCode::UnsupportedProtocolVersion,
336        1003 => RuntimeErrorCode::InvalidEnvelope,
337        1004 => RuntimeErrorCode::InvalidRequestId,
338        1005 => RuntimeErrorCode::InvalidInstanceId,
339        1006 => RuntimeErrorCode::InstanceNotFound,
340        1007 => RuntimeErrorCode::MethodNotFound,
341        1008 => RuntimeErrorCode::NotificationNotFound,
342        1009 => RuntimeErrorCode::PayloadDecodeFailed,
343        1010 => RuntimeErrorCode::PayloadEncodeFailed,
344        1011 => RuntimeErrorCode::ServiceActivationNotSupported,
345        1012 => RuntimeErrorCode::ServiceGuidNotFound,
346        1013 => RuntimeErrorCode::InstanceReleaseNotAllowed,
347        1014 => RuntimeErrorCode::RequestTimeout,
348        1015 => RuntimeErrorCode::UnsupportedCapability,
349        1016 => RuntimeErrorCode::BusinessErrorDeclared,
350        1017 => RuntimeErrorCode::DuplicateRequestId,
351        1018 => RuntimeErrorCode::RequestCancelUnsupported,
352        1019 => RuntimeErrorCode::AccessDenied,
353        _ => RuntimeErrorCode::InternalRuntimeError,
354    }
355}
356
357fn error_kind(value: u8) -> ErrorKind {
358    match value {
359        1 => ErrorKind::Transport,
360        2 => ErrorKind::Protocol,
361        3 => ErrorKind::Runtime,
362        4 => ErrorKind::Business,
363        5 => ErrorKind::Timeout,
364        6 => ErrorKind::Cancelled,
365        _ => ErrorKind::Runtime,
366    }
367}