Skip to main content

rpc_runtime_client/
lib.rs

1use std::collections::{BTreeMap, HashMap};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Duration;
5
6use rmpv::Value;
7use rpc_runtime_activation::{
8    CREATE_INSTANCE_METHOD_ID, CreateInstanceRequest, LIST_INSTANCES_METHOD_ID,
9    ListInstancesRequest, RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID,
10    ReleaseInstanceRequest, ResolveInstanceIdsRequest, activation_instance_id,
11    decode_create_instance_response, decode_list_instances_response,
12    decode_release_instance_response, decode_resolve_instance_ids_response,
13    encode_create_instance_request, encode_list_instances_request, encode_release_instance_request,
14    encode_resolve_instance_ids_request,
15};
16use rpc_runtime_core::{
17    CapabilityFlags, Envelope, Hello, InstanceId, MethodId, Notification, Options,
18    RUNTIME_PROTOCOL_VERSION, Request, RequestId, Role, ServiceGuid,
19};
20use rpc_runtime_errors::{ErrorKind, RuntimeError, RuntimeErrorCode};
21use rpc_runtime_transport::{RpcConnection, RpcReceiver, RpcSender};
22use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection, IpcEndpoint};
23use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
24
25#[derive(Clone)]
26pub struct RpcClient {
27    inner: Arc<ClientInner>,
28}
29
/// Default `Hello` option key under which the client sends its auth token.
///
/// Override it per connection with `RpcClientHandshakeConfig::with_auth_option_key`.
pub const DEFAULT_AUTH_TOKEN_OPTION_KEY: &str = "tripley.auth.token";
31
/// Client-side settings used to build the `Hello` envelope during the handshake.
///
/// The `Default` value sends no credentials and stores any token under
/// `DEFAULT_AUTH_TOKEN_OPTION_KEY`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RpcClientHandshakeConfig {
    /// Token placed into the `Hello` options; `None` sends no auth option at all.
    pub auth_token: Option<String>,
    /// Option key the token is sent under.
    pub auth_option_key: String,
}
37
38impl RpcClientHandshakeConfig {
39    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
40        self.auth_token = Some(token.into());
41        self
42    }
43
44    pub fn with_auth_option_key(mut self, key: impl Into<String>) -> Self {
45        self.auth_option_key = key.into();
46        self
47    }
48
49    fn hello_options(&self) -> Options {
50        self.auth_token
51            .as_ref()
52            .map(|token| vec![(self.auth_option_key.clone(), Value::from(token.as_str()))])
53            .unwrap_or_default()
54    }
55}
56
57impl Default for RpcClientHandshakeConfig {
58    fn default() -> Self {
59        Self {
60            auth_token: None,
61            auth_option_key: DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
62        }
63    }
64}
65
66struct ClientInner {
67    sender: RpcSender,
68    next_request_id: AtomicU64,
69    pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, RuntimeError>>>>,
70    notifications: broadcast::Sender<Notification>,
71}
72
73impl RpcClient {
74    pub async fn connect(endpoint: IpcEndpoint, config: FrameConfig) -> Result<Self, RuntimeError> {
75        Self::connect_with_handshake_config(endpoint, config, RpcClientHandshakeConfig::default())
76            .await
77    }
78
79    pub async fn connect_with_handshake_config(
80        endpoint: IpcEndpoint,
81        config: FrameConfig,
82        handshake: RpcClientHandshakeConfig,
83    ) -> Result<Self, RuntimeError> {
84        let connection = IpcConnection::connect(endpoint, config)
85            .await
86            .map_err(|err| {
87                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
88            })?;
89        Self::from_connection_with_handshake_config(connection, handshake).await
90    }
91
92    pub async fn from_connection<C>(connection: C) -> Result<Self, RuntimeError>
93    where
94        C: Into<RpcConnection>,
95    {
96        Self::from_connection_with_handshake_config(connection, RpcClientHandshakeConfig::default())
97            .await
98    }
99
100    pub async fn from_connection_with_handshake_config<C>(
101        connection: C,
102        handshake: RpcClientHandshakeConfig,
103    ) -> Result<Self, RuntimeError>
104    where
105        C: Into<RpcConnection>,
106    {
107        let (sender, mut receiver) = connection.into().split();
108        sender
109            .send_envelope(&Envelope::Hello(Hello {
110                protocol_version: RUNTIME_PROTOCOL_VERSION,
111                role: Role::Client,
112                capability_bits: client_capabilities(),
113                max_message_size: rpc_runtime_codec_msgpack::DEFAULT_MAX_MESSAGE_SIZE as u64,
114                options: handshake.hello_options(),
115            }))
116            .await
117            .map_err(|err| {
118                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
119            })?;
120
121        let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
122            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
123        })?
124        else {
125            return Err(RuntimeError::transport(
126                RuntimeErrorCode::InternalRuntimeError,
127                "server disconnected during handshake",
128            ));
129        };
130        let Envelope::HelloAck(ack) = envelope else {
131            return Err(RuntimeError::protocol(
132                RuntimeErrorCode::InvalidEnvelope,
133                "expected HELLO_ACK during handshake",
134            ));
135        };
136        if ack.protocol_version != RUNTIME_PROTOCOL_VERSION {
137            return Err(RuntimeError::protocol(
138                RuntimeErrorCode::UnsupportedProtocolVersion,
139                "server returned unsupported protocol version",
140            ));
141        }
142
143        let (notifications, _) = broadcast::channel(128);
144        let inner = Arc::new(ClientInner {
145            sender,
146            next_request_id: AtomicU64::new(1),
147            pending: Mutex::new(HashMap::new()),
148            notifications,
149        });
150        spawn_receive_loop(Arc::clone(&inner), receiver);
151        Ok(Self { inner })
152    }
153
154    pub async fn call(
155        &self,
156        instance_id: InstanceId,
157        method_id: MethodId,
158        payload: Value,
159    ) -> Result<Value, RuntimeError> {
160        let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
161        let (tx, rx) = oneshot::channel();
162        self.inner.pending.lock().await.insert(request_id, tx);
163
164        let send_result = self
165            .inner
166            .sender
167            .send_envelope(&Envelope::Request(Request {
168                request_id: RequestId::new(request_id),
169                instance_id,
170                method_id,
171                payload,
172            }))
173            .await;
174        if let Err(err) = send_result {
175            self.inner.pending.lock().await.remove(&request_id);
176            return Err(RuntimeError::transport(
177                RuntimeErrorCode::InternalRuntimeError,
178                err.to_string(),
179            ));
180        }
181
182        rx.await.map_err(|_| {
183            RuntimeError::transport(
184                RuntimeErrorCode::InternalRuntimeError,
185                "response channel closed before request completed",
186            )
187        })?
188    }
189
190    pub async fn call_timeout(
191        &self,
192        instance_id: InstanceId,
193        method_id: MethodId,
194        payload: Value,
195        timeout: Duration,
196    ) -> Result<Value, RuntimeError> {
197        tokio::time::timeout(timeout, self.call(instance_id, method_id, payload))
198            .await
199            .map_err(|_| {
200                RuntimeError::runtime(RuntimeErrorCode::RequestTimeout, "request timed out")
201            })?
202    }
203
204    pub async fn resolve_instance_ids(&self, names: Vec<String>) -> Result<Vec<u64>, RuntimeError> {
205        let response = self
206            .call(
207                activation_instance_id(),
208                MethodId::new(RESOLVE_INSTANCE_IDS_METHOD_ID),
209                encode_resolve_instance_ids_request(&ResolveInstanceIdsRequest {
210                    instance_names: names,
211                }),
212            )
213            .await?;
214        Ok(decode_resolve_instance_ids_response(&response)?.instance_ids)
215    }
216
217    pub async fn create_instance(
218        &self,
219        service_guid: ServiceGuid,
220        create_payload: Option<Vec<u8>>,
221        options: BTreeMap<String, String>,
222    ) -> Result<InstanceId, RuntimeError> {
223        let response = self
224            .call(
225                activation_instance_id(),
226                MethodId::new(CREATE_INSTANCE_METHOD_ID),
227                encode_create_instance_request(&CreateInstanceRequest {
228                    service_guid,
229                    create_payload,
230                    options,
231                }),
232            )
233            .await?;
234        Ok(decode_create_instance_response(&response)?.instance_id)
235    }
236
237    pub async fn release_instance(&self, instance_id: InstanceId) -> Result<(), RuntimeError> {
238        let response = self
239            .call(
240                activation_instance_id(),
241                MethodId::new(RELEASE_INSTANCE_METHOD_ID),
242                encode_release_instance_request(&ReleaseInstanceRequest { instance_id }),
243            )
244            .await?;
245        decode_release_instance_response(&response)?;
246        Ok(())
247    }
248
249    pub async fn list_instances(
250        &self,
251        service_guid: Option<ServiceGuid>,
252    ) -> Result<Vec<rpc_runtime_activation::InstanceDescriptor>, RuntimeError> {
253        let response = self
254            .call(
255                activation_instance_id(),
256                MethodId::new(LIST_INSTANCES_METHOD_ID),
257                encode_list_instances_request(&ListInstancesRequest { service_guid }),
258            )
259            .await?;
260        Ok(decode_list_instances_response(&response)?.instances)
261    }
262
263    pub fn subscribe_notifications(
264        &self,
265        instance_id_filter: Option<InstanceId>,
266        notification_id_filter: Option<u32>,
267    ) -> mpsc::UnboundedReceiver<Notification> {
268        let mut source = self.inner.notifications.subscribe();
269        let (tx, rx) = mpsc::unbounded_channel();
270        tokio::spawn(async move {
271            loop {
272                let Ok(notification) = source.recv().await else {
273                    break;
274                };
275                let instance_matches = instance_id_filter
276                    .is_none_or(|expected| notification.instance_id == Some(expected));
277                let notification_matches = notification_id_filter
278                    .is_none_or(|expected| notification.notification_id.get() == expected);
279                if instance_matches && notification_matches && tx.send(notification).is_err() {
280                    break;
281                }
282            }
283        });
284        rx
285    }
286
287    pub async fn goodbye(&self, message: impl Into<String>) -> Result<(), RuntimeError> {
288        self.inner
289            .sender
290            .send_envelope(&Envelope::Goodbye(rpc_runtime_core::Goodbye {
291                reason_code: 0,
292                message: Some(message.into()),
293            }))
294            .await
295            .map_err(|err| {
296                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
297            })
298    }
299}
300
301fn spawn_receive_loop(inner: Arc<ClientInner>, mut receiver: RpcReceiver) {
302    tokio::spawn(async move {
303        loop {
304            let envelope = match receiver.recv_envelope().await {
305                Ok(Some(envelope)) => envelope,
306                Ok(None) => {
307                    fail_pending(
308                        &inner,
309                        RuntimeError::transport(
310                            RuntimeErrorCode::InternalRuntimeError,
311                            "server disconnected",
312                        ),
313                    )
314                    .await;
315                    break;
316                }
317                Err(err) => {
318                    fail_pending(
319                        &inner,
320                        RuntimeError::transport(
321                            RuntimeErrorCode::InternalRuntimeError,
322                            err.to_string(),
323                        ),
324                    )
325                    .await;
326                    break;
327                }
328            };
329            match envelope {
330                Envelope::ResponseOk(response) => {
331                    complete_pending(&inner, response.request_id.get(), Ok(response.payload)).await;
332                }
333                Envelope::ResponseError(response) => {
334                    complete_pending(
335                        &inner,
336                        response.request_id.get(),
337                        Err(RuntimeError::new(
338                            runtime_error_code(response.error_code),
339                            error_kind(response.error_kind),
340                            response.error_message.unwrap_or_default(),
341                        )),
342                    )
343                    .await;
344                }
345                Envelope::Notification(notification) => {
346                    let _ = inner.notifications.send(notification);
347                }
348                _ => {
349                    fail_pending(
350                        &inner,
351                        RuntimeError::protocol(
352                            RuntimeErrorCode::InvalidEnvelope,
353                            "client received invalid envelope kind",
354                        ),
355                    )
356                    .await;
357                    break;
358                }
359            }
360        }
361    });
362}
363
364async fn complete_pending(
365    inner: &ClientInner,
366    request_id: u64,
367    result: Result<Value, RuntimeError>,
368) {
369    if let Some(sender) = inner.pending.lock().await.remove(&request_id) {
370        let _ = sender.send(result);
371    }
372}
373
374async fn fail_pending(inner: &ClientInner, error: RuntimeError) {
375    let pending = std::mem::take(&mut *inner.pending.lock().await);
376    for (_, sender) in pending {
377        let _ = sender.send(Err(error.clone()));
378    }
379}
380
381fn client_capabilities() -> CapabilityFlags {
382    CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
383        | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
384        | CapabilityFlags::SERVICE_ACTIVATION
385        | CapabilityFlags::GOODBYE
386}
387
388fn runtime_error_code(value: i32) -> RuntimeErrorCode {
389    match value {
390        1001 => RuntimeErrorCode::UnknownMessageKind,
391        1002 => RuntimeErrorCode::UnsupportedProtocolVersion,
392        1003 => RuntimeErrorCode::InvalidEnvelope,
393        1004 => RuntimeErrorCode::InvalidRequestId,
394        1005 => RuntimeErrorCode::InvalidInstanceId,
395        1006 => RuntimeErrorCode::InstanceNotFound,
396        1007 => RuntimeErrorCode::MethodNotFound,
397        1008 => RuntimeErrorCode::NotificationNotFound,
398        1009 => RuntimeErrorCode::PayloadDecodeFailed,
399        1010 => RuntimeErrorCode::PayloadEncodeFailed,
400        1011 => RuntimeErrorCode::ServiceActivationNotSupported,
401        1012 => RuntimeErrorCode::ServiceGuidNotFound,
402        1013 => RuntimeErrorCode::InstanceReleaseNotAllowed,
403        1014 => RuntimeErrorCode::RequestTimeout,
404        1015 => RuntimeErrorCode::UnsupportedCapability,
405        1016 => RuntimeErrorCode::BusinessErrorDeclared,
406        1017 => RuntimeErrorCode::DuplicateRequestId,
407        1018 => RuntimeErrorCode::RequestCancelUnsupported,
408        1019 => RuntimeErrorCode::AccessDenied,
409        _ => RuntimeErrorCode::InternalRuntimeError,
410    }
411}
412
413fn error_kind(value: u8) -> ErrorKind {
414    match value {
415        1 => ErrorKind::Transport,
416        2 => ErrorKind::Protocol,
417        3 => ErrorKind::Runtime,
418        4 => ErrorKind::Business,
419        5 => ErrorKind::Timeout,
420        6 => ErrorKind::Cancelled,
421        _ => ErrorKind::Runtime,
422    }
423}