Skip to main content

rpc_runtime_server/
lib.rs

1use std::collections::{BTreeMap, HashMap};
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use rmpv::Value;
8use rpc_runtime_activation::{
9    ACTIVATION_INSTANCE_ID_VALUE, ActivationMode, CREATE_INSTANCE_METHOD_ID,
10    CreateInstanceResponse, InstanceDescriptor, LIST_INSTANCES_METHOD_ID, ListInstancesResponse,
11    RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID, ReleaseInstanceResponse,
12    ResolveInstanceIdsResponse, activation_instance_id, activation_service_guid,
13    decode_create_instance_request, decode_list_instances_request, decode_release_instance_request,
14    decode_resolve_instance_ids_request, encode_create_instance_response,
15    encode_list_instances_response, encode_release_instance_response,
16    encode_resolve_instance_ids_response,
17};
18use rpc_runtime_core::{
19    CapabilityFlags, Envelope, HelloAck, InstanceId, MethodId, Notification,
20    RUNTIME_PROTOCOL_VERSION, Request, RequestId, ResponseError, ResponseOk, Role, ServiceGuid,
21};
22use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
23use rpc_runtime_transport::{RpcConnection, RpcListener, RpcReceiver, RpcSender};
24use tokio::sync::RwLock;
25
26pub type HandlerFuture = Pin<Box<dyn Future<Output = Result<Value, RuntimeError>> + Send>>;
27
28pub trait RpcServiceHandler: Send + Sync {
29    fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture;
30}
31
32impl<F> RpcServiceHandler for F
33where
34    F: Send + Sync + 'static,
35    F: Fn(RpcCallContext, MethodId, Value) -> HandlerFuture,
36{
37    fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
38        self(ctx, method_id, payload)
39    }
40}
41
42pub type FactoryFuture =
43    Pin<Box<dyn Future<Output = Result<Arc<dyn RpcServiceHandler>, RuntimeError>> + Send>>;
44
45pub trait RpcServiceFactory: Send + Sync {
46    fn create(
47        &self,
48        ctx: RpcCallContext,
49        create_payload: Option<Vec<u8>>,
50        options: BTreeMap<String, String>,
51    ) -> FactoryFuture;
52}
53
54impl<F> RpcServiceFactory for F
55where
56    F: Send + Sync + 'static,
57    F: Fn(RpcCallContext, Option<Vec<u8>>, BTreeMap<String, String>) -> FactoryFuture,
58{
59    fn create<'a>(
60        &self,
61        ctx: RpcCallContext,
62        create_payload: Option<Vec<u8>>,
63        options: BTreeMap<String, String>,
64    ) -> FactoryFuture {
65        self(ctx, create_payload, options)
66    }
67}
68
69#[derive(Clone)]
70pub struct RpcCallContext {
71    connection_id: u64,
72    instance_id: InstanceId,
73    sender: RpcSender,
74}
75
76impl RpcCallContext {
77    pub fn connection_id(&self) -> u64 {
78        self.connection_id
79    }
80
81    pub fn instance_id(&self) -> InstanceId {
82        self.instance_id
83    }
84
85    pub async fn notify(
86        &self,
87        instance_id: Option<InstanceId>,
88        notification_id: u32,
89        payload: Value,
90    ) -> Result<(), RuntimeError> {
91        self.sender
92            .send_envelope(&Envelope::Notification(Notification {
93                instance_id,
94                notification_id: rpc_runtime_core::NotificationId::new(notification_id),
95                payload,
96            }))
97            .await
98            .map_err(|err| {
99                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
100            })
101    }
102
103    pub async fn notify_bound(
104        &self,
105        notification_id: u32,
106        payload: Value,
107    ) -> Result<(), RuntimeError> {
108        self.notify(Some(self.instance_id), notification_id, payload)
109            .await
110    }
111}
112
113#[derive(Clone)]
114pub struct RpcServer {
115    state: Arc<ServerState>,
116}
117
118pub struct RpcServerBuilder {
119    state: ServerState,
120}
121
122impl RpcServerBuilder {
123    pub fn new() -> Self {
124        let mut state = ServerState::new();
125        state.insert_activation_instance();
126        Self { state }
127    }
128
129    pub fn register_named_instance(
130        &mut self,
131        name: impl Into<String>,
132        service_guid: ServiceGuid,
133        methods: impl IntoIterator<Item = u32>,
134        handler: Arc<dyn RpcServiceHandler>,
135    ) -> InstanceId {
136        self.state.insert_instance(NewInstance {
137            service_guid,
138            name: Some(name.into()),
139            activation_mode: ActivationMode::NamedPrecreated,
140            releasable: false,
141            owner_connection_id: None,
142            methods: methods.into_iter().collect(),
143            handler,
144        })
145    }
146
147    pub fn register_singleton(
148        &mut self,
149        service_guid: ServiceGuid,
150        methods: impl IntoIterator<Item = u32>,
151        handler: Arc<dyn RpcServiceHandler>,
152    ) -> InstanceId {
153        self.state.insert_instance(NewInstance {
154            service_guid,
155            name: None,
156            activation_mode: ActivationMode::Singleton,
157            releasable: false,
158            owner_connection_id: None,
159            methods: methods.into_iter().collect(),
160            handler,
161        })
162    }
163
164    pub fn register_factory(
165        &mut self,
166        service_guid: ServiceGuid,
167        methods: impl IntoIterator<Item = u32>,
168        factory: Arc<dyn RpcServiceFactory>,
169    ) {
170        self.state.factories.insert(
171            service_guid.get(),
172            FactoryEntry {
173                methods: methods.into_iter().collect(),
174                factory,
175            },
176        );
177    }
178
179    pub fn build(self) -> RpcServer {
180        RpcServer {
181            state: Arc::new(self.state),
182        }
183    }
184}
185
186impl Default for RpcServerBuilder {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
192impl RpcServer {
193    pub async fn serve_connection<C>(&self, connection: C) -> Result<(), RuntimeError>
194    where
195        C: Into<RpcConnection>,
196    {
197        let connection_id = self
198            .state
199            .next_connection_id
200            .fetch_add(1, Ordering::Relaxed);
201        let (sender, mut receiver) = connection.into().split();
202        self.perform_handshake(&sender, &mut receiver).await?;
203
204        while let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
205            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
206        })? {
207            match envelope {
208                Envelope::Request(request) => {
209                    let state = Arc::clone(&self.state);
210                    let sender = sender.clone();
211                    tokio::spawn(async move {
212                        let request_id = request.request_id;
213                        let response =
214                            dispatch_request(state, sender.clone(), connection_id, request).await;
215                        let envelope = match response {
216                            Ok(payload) => Envelope::ResponseOk(ResponseOk {
217                                request_id,
218                                payload,
219                            }),
220                            Err(error) => runtime_error_response(request_id, error),
221                        };
222                        let _ = sender.send_envelope(&envelope).await;
223                    });
224                }
225                Envelope::Goodbye(_) => break,
226                _ => {
227                    return Err(RuntimeError::protocol(
228                        RuntimeErrorCode::InvalidEnvelope,
229                        "server expected request envelope",
230                    ));
231                }
232            }
233        }
234
235        self.state.cleanup_connection(connection_id).await;
236        Ok(())
237    }
238
239    pub async fn serve_listener<L>(&self, mut listener: L) -> Result<(), RuntimeError>
240    where
241        L: RpcListener + Send,
242    {
243        loop {
244            let connection = listener.accept().await.map_err(|err| {
245                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
246            })?;
247            let server = self.clone();
248            tokio::spawn(async move {
249                let _ = server.serve_connection(connection).await;
250            });
251        }
252    }
253
254    pub fn spawn_listener<L>(
255        &self,
256        listener: L,
257    ) -> tokio::task::JoinHandle<Result<(), RuntimeError>>
258    where
259        L: RpcListener + Send + 'static,
260    {
261        let server = self.clone();
262        tokio::spawn(async move { server.serve_listener(listener).await })
263    }
264
265    async fn perform_handshake(
266        &self,
267        sender: &RpcSender,
268        receiver: &mut RpcReceiver,
269    ) -> Result<(), RuntimeError> {
270        let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
271            RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
272        })?
273        else {
274            return Err(RuntimeError::transport(
275                RuntimeErrorCode::InternalRuntimeError,
276                "client disconnected during handshake",
277            ));
278        };
279        let Envelope::Hello(hello) = envelope else {
280            return Err(RuntimeError::protocol(
281                RuntimeErrorCode::InvalidEnvelope,
282                "expected HELLO during handshake",
283            ));
284        };
285        if hello.protocol_version != RUNTIME_PROTOCOL_VERSION || hello.role != Role::Client {
286            return Err(RuntimeError::protocol(
287                RuntimeErrorCode::UnsupportedProtocolVersion,
288                "unsupported client handshake",
289            ));
290        }
291        sender
292            .send_envelope(&Envelope::HelloAck(HelloAck {
293                protocol_version: RUNTIME_PROTOCOL_VERSION,
294                accepted_capability_bits: server_capabilities() & hello.capability_bits,
295                max_message_size: hello.max_message_size,
296                options: Vec::new(),
297            }))
298            .await
299            .map_err(|err| {
300                RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
301            })
302    }
303
304    pub async fn list_instances(&self) -> Vec<InstanceDescriptor> {
305        self.state.list_instances(None).await
306    }
307}
308
309async fn dispatch_request(
310    state: Arc<ServerState>,
311    sender: RpcSender,
312    connection_id: u64,
313    request: Request,
314) -> Result<Value, RuntimeError> {
315    if request.instance_id.get() == ACTIVATION_INSTANCE_ID_VALUE {
316        return dispatch_activation(state, sender, connection_id, request).await;
317    }
318
319    let instance = state.get_instance(request.instance_id).await?;
320    if !instance.methods.contains(&request.method_id.get()) {
321        return Err(RuntimeError::runtime(
322            RuntimeErrorCode::MethodNotFound,
323            format!("method id `{}` was not found", request.method_id.get()),
324        ));
325    }
326    let ctx = RpcCallContext {
327        connection_id,
328        instance_id: request.instance_id,
329        sender,
330    };
331    instance
332        .handler
333        .call(ctx, request.method_id, request.payload)
334        .await
335}
336
337async fn dispatch_activation(
338    state: Arc<ServerState>,
339    sender: RpcSender,
340    connection_id: u64,
341    request: Request,
342) -> Result<Value, RuntimeError> {
343    let ctx = RpcCallContext {
344        connection_id,
345        instance_id: request.instance_id,
346        sender,
347    };
348    match request.method_id.get() {
349        RESOLVE_INSTANCE_IDS_METHOD_ID => {
350            let request = decode_resolve_instance_ids_request(&request.payload)?;
351            let ids = state.resolve_instance_ids(&request.instance_names).await;
352            Ok(encode_resolve_instance_ids_response(
353                &ResolveInstanceIdsResponse { instance_ids: ids },
354            ))
355        }
356        CREATE_INSTANCE_METHOD_ID => {
357            let request = decode_create_instance_request(&request.payload)?;
358            let factory = state.get_factory(request.service_guid).ok_or_else(|| {
359                RuntimeError::runtime(
360                    RuntimeErrorCode::ServiceGuidNotFound,
361                    "service factory was not found",
362                )
363            })?;
364            let handler = factory
365                .factory
366                .create(ctx, request.create_payload, request.options)
367                .await?;
368            let instance_id = state
369                .insert_client_instance(
370                    request.service_guid,
371                    connection_id,
372                    factory.methods.clone(),
373                    handler,
374                )
375                .await;
376            Ok(encode_create_instance_response(&CreateInstanceResponse {
377                instance_id,
378            }))
379        }
380        RELEASE_INSTANCE_METHOD_ID => {
381            let request = decode_release_instance_request(&request.payload)?;
382            state
383                .release_instance(connection_id, request.instance_id)
384                .await?;
385            Ok(encode_release_instance_response(&ReleaseInstanceResponse))
386        }
387        LIST_INSTANCES_METHOD_ID => {
388            let request = decode_list_instances_request(&request.payload)?;
389            let instances = state.list_instances(request.service_guid).await;
390            Ok(encode_list_instances_response(&ListInstancesResponse {
391                instances,
392            }))
393        }
394        _ => Err(RuntimeError::runtime(
395            RuntimeErrorCode::MethodNotFound,
396            "activation method was not found",
397        )),
398    }
399}
400
401fn runtime_error_response(request_id: RequestId, error: RuntimeError) -> Envelope {
402    Envelope::ResponseError(ResponseError {
403        request_id,
404        error_code: error.code.as_i32(),
405        error_kind: error.kind.as_u8(),
406        error_message: Some(error.message),
407        error_details: Value::Nil,
408    })
409}
410
411fn server_capabilities() -> CapabilityFlags {
412    CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
413        | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
414        | CapabilityFlags::SERVICE_ACTIVATION
415        | CapabilityFlags::GOODBYE
416}
417
418struct ServerState {
419    next_connection_id: AtomicU64,
420    next_instance_id: AtomicU64,
421    instances: RwLock<HashMap<u64, InstanceEntry>>,
422    names: RwLock<HashMap<String, u64>>,
423    factories: HashMap<uuid::Uuid, FactoryEntry>,
424}
425
426impl ServerState {
427    fn new() -> Self {
428        Self {
429            next_connection_id: AtomicU64::new(1),
430            next_instance_id: AtomicU64::new(2),
431            instances: RwLock::new(HashMap::new()),
432            names: RwLock::new(HashMap::new()),
433            factories: HashMap::new(),
434        }
435    }
436
437    fn insert_activation_instance(&mut self) {
438        self.instances.get_mut().insert(
439            ACTIVATION_INSTANCE_ID_VALUE,
440            InstanceEntry {
441                instance_id: activation_instance_id(),
442                service_guid: activation_service_guid(),
443                instance_name: Some("rpc.runtime.Activation".to_string()),
444                activation_mode: ActivationMode::Singleton,
445                releasable: false,
446                owner_connection_id: None,
447                methods: vec![
448                    RESOLVE_INSTANCE_IDS_METHOD_ID,
449                    CREATE_INSTANCE_METHOD_ID,
450                    RELEASE_INSTANCE_METHOD_ID,
451                    LIST_INSTANCES_METHOD_ID,
452                ],
453                handler: Arc::new(ActivationMarker),
454            },
455        );
456    }
457
458    fn insert_instance(&mut self, instance: NewInstance) -> InstanceId {
459        let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
460        let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
461        if let Some(name) = &instance.name {
462            self.names.get_mut().insert(name.clone(), id);
463        }
464        self.instances.get_mut().insert(
465            id,
466            InstanceEntry {
467                instance_id,
468                service_guid: instance.service_guid,
469                instance_name: instance.name,
470                activation_mode: instance.activation_mode,
471                releasable: instance.releasable,
472                owner_connection_id: instance.owner_connection_id,
473                methods: instance.methods,
474                handler: instance.handler,
475            },
476        );
477        instance_id
478    }
479
480    async fn insert_client_instance(
481        &self,
482        service_guid: ServiceGuid,
483        connection_id: u64,
484        methods: Vec<u32>,
485        handler: Arc<dyn RpcServiceHandler>,
486    ) -> InstanceId {
487        let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
488        let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
489        self.instances.write().await.insert(
490            id,
491            InstanceEntry {
492                instance_id,
493                service_guid,
494                instance_name: None,
495                activation_mode: ActivationMode::Instantiable,
496                releasable: true,
497                owner_connection_id: Some(connection_id),
498                methods,
499                handler,
500            },
501        );
502        instance_id
503    }
504
505    async fn get_instance(&self, instance_id: InstanceId) -> Result<InstanceEntry, RuntimeError> {
506        self.instances
507            .read()
508            .await
509            .get(&instance_id.get())
510            .cloned()
511            .ok_or_else(|| {
512                RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
513            })
514    }
515
516    fn get_factory(&self, service_guid: ServiceGuid) -> Option<FactoryEntry> {
517        self.factories.get(&service_guid.get()).cloned()
518    }
519
520    async fn resolve_instance_ids(&self, names: &[String]) -> Vec<u64> {
521        let index = self.names.read().await;
522        names
523            .iter()
524            .map(|name| index.get(name).copied().unwrap_or(0))
525            .collect()
526    }
527
528    async fn release_instance(
529        &self,
530        connection_id: u64,
531        instance_id: InstanceId,
532    ) -> Result<(), RuntimeError> {
533        let mut instances = self.instances.write().await;
534        let entry = instances.get(&instance_id.get()).ok_or_else(|| {
535            RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
536        })?;
537        if !entry.releasable {
538            return Err(RuntimeError::runtime(
539                RuntimeErrorCode::InstanceReleaseNotAllowed,
540                "instance is not releasable",
541            ));
542        }
543        if entry.owner_connection_id != Some(connection_id) {
544            return Err(RuntimeError::runtime(
545                RuntimeErrorCode::AccessDenied,
546                "instance is owned by another connection",
547            ));
548        }
549        instances.remove(&instance_id.get());
550        Ok(())
551    }
552
553    async fn cleanup_connection(&self, connection_id: u64) {
554        self.instances
555            .write()
556            .await
557            .retain(|_, entry| entry.owner_connection_id != Some(connection_id));
558    }
559
560    async fn list_instances(&self, service_guid: Option<ServiceGuid>) -> Vec<InstanceDescriptor> {
561        let mut values = self
562            .instances
563            .read()
564            .await
565            .values()
566            .filter(|entry| service_guid.is_none_or(|guid| guid == entry.service_guid))
567            .map(InstanceEntry::descriptor)
568            .collect::<Vec<_>>();
569        values.sort_by_key(|entry| entry.instance_id.get());
570        values
571    }
572}
573
574struct NewInstance {
575    service_guid: ServiceGuid,
576    name: Option<String>,
577    activation_mode: ActivationMode,
578    releasable: bool,
579    owner_connection_id: Option<u64>,
580    methods: Vec<u32>,
581    handler: Arc<dyn RpcServiceHandler>,
582}
583
584#[derive(Clone)]
585struct InstanceEntry {
586    instance_id: InstanceId,
587    service_guid: ServiceGuid,
588    instance_name: Option<String>,
589    activation_mode: ActivationMode,
590    releasable: bool,
591    owner_connection_id: Option<u64>,
592    methods: Vec<u32>,
593    handler: Arc<dyn RpcServiceHandler>,
594}
595
596impl InstanceEntry {
597    fn descriptor(&self) -> InstanceDescriptor {
598        InstanceDescriptor {
599            instance_id: self.instance_id,
600            instance_name: self.instance_name.clone(),
601            service_guid: self.service_guid,
602            activation_mode: self.activation_mode,
603            releasable: self.releasable,
604        }
605    }
606}
607
608#[derive(Clone)]
609struct FactoryEntry {
610    methods: Vec<u32>,
611    factory: Arc<dyn RpcServiceFactory>,
612}
613
614struct ActivationMarker;
615
616impl RpcServiceHandler for ActivationMarker {
617    fn call(&self, _: RpcCallContext, _: MethodId, _: Value) -> HandlerFuture {
618        Box::pin(async {
619            Err(RuntimeError::runtime(
620                RuntimeErrorCode::InternalRuntimeError,
621                "activation marker should not be dispatched directly",
622            ))
623        })
624    }
625}