1use std::collections::{BTreeMap, HashMap};
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::{Duration, Instant};
7use std::{panic, panic::AssertUnwindSafe};
8
9use rmpv::Value;
10use rpc_runtime_activation::{
11 ACTIVATION_INSTANCE_ID_VALUE, ActivationMode, CREATE_INSTANCE_METHOD_ID,
12 CreateInstanceResponse, InstanceDescriptor, LIST_INSTANCES_METHOD_ID, ListInstancesResponse,
13 RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID, ReleaseInstanceResponse,
14 ResolveInstanceIdsResponse, activation_instance_id, activation_service_guid,
15 decode_create_instance_request, decode_list_instances_request, decode_release_instance_request,
16 decode_resolve_instance_ids_request, encode_create_instance_response,
17 encode_list_instances_response, encode_release_instance_response,
18 encode_resolve_instance_ids_response,
19};
20use rpc_runtime_core::{
21 CapabilityFlags, Envelope, HelloAck, InstanceId, MethodId, Notification, Options,
22 RUNTIME_PROTOCOL_VERSION, Request, RequestId, ResponseError, ResponseOk, Role, ServiceGuid,
23};
24use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
25pub use rpc_runtime_transport::ConnectionScope;
26use rpc_runtime_transport::{RpcConnection, RpcListener, RpcReceiver, RpcSender, TransportError};
27use tokio::sync::RwLock;
28use tracing::{debug, error, info, trace, warn};
29
/// Boxed, `Send` future returned by [`RpcServiceHandler::call`]; it resolves to the
/// response payload (an `rmpv` [`Value`]) or a [`RuntimeError`].
pub type HandlerFuture = Pin<Box<dyn Future<Output = Result<Value, RuntimeError>> + Send>>;

/// Handles method calls addressed to a registered service instance.
pub trait RpcServiceHandler: Send + Sync {
    /// Invokes `method_id` with the request `payload`.
    ///
    /// For regular (non-activation) instances the server only calls this with method ids
    /// that were listed when the instance was registered or created. An `Ok` value becomes
    /// the `ResponseOk` payload; an `Err` becomes a `ResponseError` envelope.
    fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture;
}

/// Lets any `Fn(RpcCallContext, MethodId, Value) -> HandlerFuture` closure or function be
/// used directly as a handler.
impl<F> RpcServiceHandler for F
where
    F: Send + Sync + 'static,
    F: Fn(RpcCallContext, MethodId, Value) -> HandlerFuture,
{
    fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
        self(ctx, method_id, payload)
    }
}
45
46pub type FactoryFuture =
47 Pin<Box<dyn Future<Output = Result<Arc<dyn RpcServiceHandler>, RuntimeError>> + Send>>;
48
49pub trait RpcServiceFactory: Send + Sync {
50 fn create(
51 &self,
52 ctx: RpcCallContext,
53 create_payload: Option<Vec<u8>>,
54 options: BTreeMap<String, String>,
55 ) -> FactoryFuture;
56}
57
58impl<F> RpcServiceFactory for F
59where
60 F: Send + Sync + 'static,
61 F: Fn(RpcCallContext, Option<Vec<u8>>, BTreeMap<String, String>) -> FactoryFuture,
62{
63 fn create<'a>(
64 &self,
65 ctx: RpcCallContext,
66 create_payload: Option<Vec<u8>>,
67 options: BTreeMap<String, String>,
68 ) -> FactoryFuture {
69 self(ctx, create_payload, options)
70 }
71}
72
/// Per-call context handed to service handlers and factories.
///
/// It identifies the connection and target instance of the current request, and carries a
/// clone of the connection's sender so handlers can push notifications to the client.
#[derive(Clone)]
pub struct RpcCallContext {
    // Server-assigned id of the connection the request arrived on.
    connection_id: u64,
    // Instance the request was addressed to (the activation instance for factory calls).
    instance_id: InstanceId,
    // Outgoing half of the connection; used by `notify`.
    sender: RpcSender,
}
79
80impl RpcCallContext {
81 pub fn connection_id(&self) -> u64 {
82 self.connection_id
83 }
84
85 pub fn instance_id(&self) -> InstanceId {
86 self.instance_id
87 }
88
89 pub async fn notify(
90 &self,
91 instance_id: Option<InstanceId>,
92 notification_id: u32,
93 payload: Value,
94 ) -> Result<(), RuntimeError> {
95 self.sender
96 .send_envelope(&Envelope::Notification(Notification {
97 instance_id,
98 notification_id: rpc_runtime_core::NotificationId::new(notification_id),
99 payload,
100 }))
101 .await
102 .map_err(|err| {
103 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
104 })
105 }
106
107 pub async fn notify_bound(
108 &self,
109 notification_id: u32,
110 payload: Value,
111 ) -> Result<(), RuntimeError> {
112 self.notify(Some(self.instance_id), notification_id, payload)
113 .await
114 }
115}
116
/// A built RPC server. Cloning is cheap, and all clones share the same server state.
#[derive(Clone)]
pub struct RpcServer {
    state: Arc<ServerState>,
}

/// Collects configuration, instance registrations and factories; [`RpcServerBuilder::build`]
/// then turns them into an [`RpcServer`].
pub struct RpcServerBuilder {
    state: ServerState,
}
125
/// Handshake option key that [`RpcServerAuthConfig::token`] uses to look up the client's
/// authentication token in the HELLO options.
pub const DEFAULT_AUTH_TOKEN_OPTION_KEY: &str = "tripley.auth.token";
127
/// Receives server lifecycle and request metric events.
///
/// Implementations are called inline on connection and request tasks. Panics raised by
/// `record` are caught by the server and logged instead of propagated.
pub trait RpcServerMetricsSink: Send + Sync {
    /// Records a single metric event.
    fn record(&self, event: RpcServerMetricEvent);
}

/// Lets any `Fn(RpcServerMetricEvent)` closure act as a metrics sink.
impl<F> RpcServerMetricsSink for F
where
    F: Send + Sync + 'static + Fn(RpcServerMetricEvent),
{
    fn record(&self, event: RpcServerMetricEvent) {
        self(event);
    }
}
140
/// Metric events emitted by [`RpcServer`] to an [`RpcServerMetricsSink`].
#[derive(Debug, Clone, PartialEq)]
pub enum RpcServerMetricEvent {
    /// A connection was accepted and assigned `connection_id`.
    ConnectionStarted {
        connection_id: u64,
    },
    /// A connection finished and its state was cleaned up.
    /// `success` is `false` when the connection ended with an error.
    ConnectionEnded {
        connection_id: u64,
        success: bool,
    },
    /// The HELLO / HELLO_ACK handshake (including authentication) succeeded.
    HandshakeCompleted {
        connection_id: u64,
    },
    /// The handshake failed; `error_code` is the code of the returned error.
    HandshakeFailed {
        connection_id: u64,
        error_code: RuntimeErrorCode,
    },
    /// The listener refused an incoming connection with an access-denied transport error.
    ListenerConnectionRejected {
        error_code: RuntimeErrorCode,
    },
    /// A request was received and is about to be dispatched.
    RequestStarted {
        connection_id: u64,
        request_id: RequestId,
        instance_id: InstanceId,
        method_id: MethodId,
        is_activation: bool,
    },
    /// A request was dispatched successfully; `elapsed` covers dispatch only.
    RequestCompleted {
        connection_id: u64,
        request_id: RequestId,
        instance_id: InstanceId,
        method_id: MethodId,
        is_activation: bool,
        elapsed: Duration,
    },
    /// A request's dispatch returned an error with `error_code`.
    RequestFailed {
        connection_id: u64,
        request_id: RequestId,
        instance_id: InstanceId,
        method_id: MethodId,
        is_activation: bool,
        elapsed: Duration,
        error_code: RuntimeErrorCode,
    },
    /// A successful request took at least `threshold`.
    /// It is emitted in addition to `RequestCompleted`.
    RequestSlow {
        connection_id: u64,
        request_id: RequestId,
        instance_id: InstanceId,
        method_id: MethodId,
        is_activation: bool,
        elapsed: Duration,
        threshold: Duration,
    },
    /// Writing the response envelope for `request_id` back to the client failed.
    ResponseSendFailed {
        connection_id: u64,
        request_id: RequestId,
    },
}
198
/// Point-in-time copy of the counters kept by [`RpcServerMetricsRecorder`].
///
/// Each field is loaded independently with relaxed ordering, so a snapshot taken during
/// concurrent activity is not guaranteed to be mutually consistent.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RpcServerMetricsSnapshot {
    pub connections_started: u64,
    pub connections_ended: u64,
    pub connections_ended_successfully: u64,
    pub handshakes_completed: u64,
    pub handshakes_failed: u64,
    pub listener_connections_rejected: u64,
    pub requests_started: u64,
    pub requests_completed: u64,
    pub requests_failed: u64,
    pub requests_slow: u64,
    pub response_send_failures: u64,
    /// Sum of `elapsed` over completed and failed requests (saturating at `u64::MAX` ns).
    pub request_elapsed_total: Duration,
    /// Largest `elapsed` seen across completed and failed requests.
    pub request_elapsed_max: Duration,
}

/// Built-in [`RpcServerMetricsSink`] that aggregates events into lock-free atomic counters.
#[derive(Debug, Default)]
pub struct RpcServerMetricsRecorder {
    connections_started: AtomicU64,
    connections_ended: AtomicU64,
    connections_ended_successfully: AtomicU64,
    handshakes_completed: AtomicU64,
    handshakes_failed: AtomicU64,
    listener_connections_rejected: AtomicU64,
    requests_started: AtomicU64,
    requests_completed: AtomicU64,
    requests_failed: AtomicU64,
    requests_slow: AtomicU64,
    response_send_failures: AtomicU64,
    // Durations are stored as whole nanoseconds so they fit in an atomic.
    request_elapsed_total_nanos: AtomicU64,
    request_elapsed_max_nanos: AtomicU64,
}
232
233impl RpcServerMetricsRecorder {
234 pub fn new() -> Self {
235 Self::default()
236 }
237
238 pub fn snapshot(&self) -> RpcServerMetricsSnapshot {
239 RpcServerMetricsSnapshot {
240 connections_started: self.connections_started.load(Ordering::Relaxed),
241 connections_ended: self.connections_ended.load(Ordering::Relaxed),
242 connections_ended_successfully: self
243 .connections_ended_successfully
244 .load(Ordering::Relaxed),
245 handshakes_completed: self.handshakes_completed.load(Ordering::Relaxed),
246 handshakes_failed: self.handshakes_failed.load(Ordering::Relaxed),
247 listener_connections_rejected: self
248 .listener_connections_rejected
249 .load(Ordering::Relaxed),
250 requests_started: self.requests_started.load(Ordering::Relaxed),
251 requests_completed: self.requests_completed.load(Ordering::Relaxed),
252 requests_failed: self.requests_failed.load(Ordering::Relaxed),
253 requests_slow: self.requests_slow.load(Ordering::Relaxed),
254 response_send_failures: self.response_send_failures.load(Ordering::Relaxed),
255 request_elapsed_total: Duration::from_nanos(
256 self.request_elapsed_total_nanos.load(Ordering::Relaxed),
257 ),
258 request_elapsed_max: Duration::from_nanos(
259 self.request_elapsed_max_nanos.load(Ordering::Relaxed),
260 ),
261 }
262 }
263
264 fn record_elapsed(&self, elapsed: Duration) {
265 let nanos = duration_nanos_u64(elapsed);
266 saturating_atomic_add(&self.request_elapsed_total_nanos, nanos);
267 update_atomic_max(&self.request_elapsed_max_nanos, nanos);
268 }
269}
270
271impl RpcServerMetricsSink for RpcServerMetricsRecorder {
272 fn record(&self, event: RpcServerMetricEvent) {
273 match event {
274 RpcServerMetricEvent::ConnectionStarted { .. } => {
275 self.connections_started.fetch_add(1, Ordering::Relaxed);
276 }
277 RpcServerMetricEvent::ConnectionEnded { success, .. } => {
278 self.connections_ended.fetch_add(1, Ordering::Relaxed);
279 if success {
280 self.connections_ended_successfully
281 .fetch_add(1, Ordering::Relaxed);
282 }
283 }
284 RpcServerMetricEvent::HandshakeCompleted { .. } => {
285 self.handshakes_completed.fetch_add(1, Ordering::Relaxed);
286 }
287 RpcServerMetricEvent::HandshakeFailed { .. } => {
288 self.handshakes_failed.fetch_add(1, Ordering::Relaxed);
289 }
290 RpcServerMetricEvent::ListenerConnectionRejected { .. } => {
291 self.listener_connections_rejected
292 .fetch_add(1, Ordering::Relaxed);
293 }
294 RpcServerMetricEvent::RequestStarted { .. } => {
295 self.requests_started.fetch_add(1, Ordering::Relaxed);
296 }
297 RpcServerMetricEvent::RequestCompleted { elapsed, .. } => {
298 self.requests_completed.fetch_add(1, Ordering::Relaxed);
299 self.record_elapsed(elapsed);
300 }
301 RpcServerMetricEvent::RequestFailed { elapsed, .. } => {
302 self.requests_failed.fetch_add(1, Ordering::Relaxed);
303 self.record_elapsed(elapsed);
304 }
305 RpcServerMetricEvent::RequestSlow { .. } => {
306 self.requests_slow.fetch_add(1, Ordering::Relaxed);
307 }
308 RpcServerMetricEvent::ResponseSendFailed { .. } => {
309 self.response_send_failures.fetch_add(1, Ordering::Relaxed);
310 }
311 }
312 }
313}
314
/// Logging and slow-call detection settings for the server.
///
/// The `Default` impl uses a 500 ms slow-call threshold and disables payload previews.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpcServerObservabilityConfig {
    /// Successful requests whose dispatch takes at least this long are logged at `warn`
    /// and reported as `RequestSlow`.
    pub slow_call_threshold: Duration,
    /// Maximum length, in bytes, of the `Debug` payload preview logged at `trace`.
    pub payload_preview_bytes: usize,
    /// Whether payload previews are produced at all (also requires a non-zero byte limit).
    pub log_payload_preview: bool,
}
321
/// Connection-scope and handshake-authentication settings for the server.
///
/// The `Default` impl is local-only with authentication disabled. `RpcServerBuilder::build`
/// logs a warning when remote connections are allowed without authentication.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RpcServerSecurityConfig {
    /// Applied to listeners via `set_connection_scope` in `RpcServer::serve_listener`.
    pub connection_scope: ConnectionScope,
    /// Handshake authentication requirement, checked against the HELLO options.
    pub auth: RpcServerAuthConfig,
}
327
328impl RpcServerSecurityConfig {
329 pub fn remote_allowed(mut self) -> Self {
330 self.connection_scope = ConnectionScope::RemoteAllowed;
331 self
332 }
333
334 pub fn local_only(mut self) -> Self {
335 self.connection_scope = ConnectionScope::LocalOnly;
336 self
337 }
338
339 pub fn with_token(mut self, token: impl Into<String>) -> Self {
340 self.auth = RpcServerAuthConfig::token(token);
341 self
342 }
343
344 pub fn with_auth(mut self, auth: RpcServerAuthConfig) -> Self {
345 self.auth = auth;
346 self
347 }
348}
349
350impl Default for RpcServerSecurityConfig {
351 fn default() -> Self {
352 Self {
353 connection_scope: ConnectionScope::LocalOnly,
354 auth: RpcServerAuthConfig::Disabled,
355 }
356 }
357}
358
/// Handshake authentication requirement.
///
/// With `Token`, the client's HELLO options must contain `option_key` whose string value
/// equals `token`.
///
/// `Debug` is implemented by hand so the secret token is never written to logs, panics
/// or debug dumps of the enclosing security configuration.
#[derive(Clone, PartialEq, Eq)]
pub enum RpcServerAuthConfig {
    /// No authentication; every handshake is accepted.
    Disabled,
    /// Require `token` under the HELLO option named `option_key`.
    Token { token: String, option_key: String },
}

impl std::fmt::Debug for RpcServerAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Disabled => f.write_str("Disabled"),
            // The option key is not secret and is useful when diagnosing config issues.
            Self::Token { option_key, .. } => f
                .debug_struct("Token")
                .field("token", &format_args!("<redacted>"))
                .field("option_key", option_key)
                .finish(),
        }
    }
}
364
365impl RpcServerAuthConfig {
366 pub fn token(token: impl Into<String>) -> Self {
367 Self::Token {
368 token: token.into(),
369 option_key: DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
370 }
371 }
372
373 pub fn token_with_option_key(token: impl Into<String>, option_key: impl Into<String>) -> Self {
374 Self::Token {
375 token: token.into(),
376 option_key: option_key.into(),
377 }
378 }
379}
380
381impl RpcServerObservabilityConfig {
382 pub fn with_slow_call_threshold(mut self, threshold: Duration) -> Self {
383 self.slow_call_threshold = threshold;
384 self
385 }
386
387 pub fn with_payload_preview(mut self, bytes: usize) -> Self {
388 self.payload_preview_bytes = bytes;
389 self.log_payload_preview = bytes > 0;
390 self
391 }
392}
393
394impl Default for RpcServerObservabilityConfig {
395 fn default() -> Self {
396 Self {
397 slow_call_threshold: Duration::from_millis(500),
398 payload_preview_bytes: 0,
399 log_payload_preview: false,
400 }
401 }
402}
403
404impl RpcServerBuilder {
405 pub fn new() -> Self {
406 let mut state = ServerState::new();
407 state.insert_activation_instance();
408 Self { state }
409 }
410
411 pub fn observability(mut self, config: RpcServerObservabilityConfig) -> Self {
412 self.state.observability = config;
413 self
414 }
415
416 pub fn set_observability(&mut self, config: RpcServerObservabilityConfig) -> &mut Self {
417 self.state.observability = config;
418 self
419 }
420
421 pub fn metrics_sink(mut self, sink: Arc<dyn RpcServerMetricsSink>) -> Self {
422 self.state.metrics_sink = Some(sink);
423 self
424 }
425
426 pub fn set_metrics_sink(&mut self, sink: Arc<dyn RpcServerMetricsSink>) -> &mut Self {
427 self.state.metrics_sink = Some(sink);
428 self
429 }
430
431 pub fn security(mut self, config: RpcServerSecurityConfig) -> Self {
432 self.state.security = config;
433 self
434 }
435
436 pub fn set_security(&mut self, config: RpcServerSecurityConfig) -> &mut Self {
437 self.state.security = config;
438 self
439 }
440
441 pub fn register_named_instance(
442 &mut self,
443 name: impl Into<String>,
444 service_guid: ServiceGuid,
445 methods: impl IntoIterator<Item = u32>,
446 handler: Arc<dyn RpcServiceHandler>,
447 ) -> InstanceId {
448 self.state.insert_instance(NewInstance {
449 service_guid,
450 name: Some(name.into()),
451 activation_mode: ActivationMode::NamedPrecreated,
452 releasable: false,
453 owner_connection_id: None,
454 methods: methods.into_iter().collect(),
455 handler,
456 })
457 }
458
459 pub fn register_singleton(
460 &mut self,
461 service_guid: ServiceGuid,
462 methods: impl IntoIterator<Item = u32>,
463 handler: Arc<dyn RpcServiceHandler>,
464 ) -> InstanceId {
465 self.state.insert_instance(NewInstance {
466 service_guid,
467 name: None,
468 activation_mode: ActivationMode::Singleton,
469 releasable: false,
470 owner_connection_id: None,
471 methods: methods.into_iter().collect(),
472 handler,
473 })
474 }
475
476 pub fn register_factory(
477 &mut self,
478 service_guid: ServiceGuid,
479 methods: impl IntoIterator<Item = u32>,
480 factory: Arc<dyn RpcServiceFactory>,
481 ) {
482 self.state.factories.insert(
483 service_guid.get(),
484 FactoryEntry {
485 methods: methods.into_iter().collect(),
486 factory,
487 },
488 );
489 }
490
491 pub fn build(self) -> RpcServer {
492 if self.state.security.connection_scope == ConnectionScope::RemoteAllowed
493 && self.state.security.auth == RpcServerAuthConfig::Disabled
494 {
495 warn!("rpc server allows remote connections without token authentication");
496 }
497 RpcServer {
498 state: Arc::new(self.state),
499 }
500 }
501}
502
503impl Default for RpcServerBuilder {
504 fn default() -> Self {
505 Self::new()
506 }
507}
508
509impl RpcServer {
510 pub async fn serve_connection<C>(&self, connection: C) -> Result<(), RuntimeError>
511 where
512 C: Into<RpcConnection>,
513 {
514 let connection_id = self
515 .state
516 .next_connection_id
517 .fetch_add(1, Ordering::Relaxed);
518 self.state
519 .record_metric(RpcServerMetricEvent::ConnectionStarted { connection_id });
520 info!(connection_id, "rpc server connection started");
521 let (sender, mut receiver) = connection.into().split();
522
523 let result = async {
524 if let Err(error) = self
525 .perform_handshake(connection_id, &sender, &mut receiver)
526 .await
527 {
528 self.state
529 .record_metric(RpcServerMetricEvent::HandshakeFailed {
530 connection_id,
531 error_code: error.code,
532 });
533 return Err(error);
534 }
535 self.state
536 .record_metric(RpcServerMetricEvent::HandshakeCompleted { connection_id });
537
538 loop {
539 let envelope = match receiver.recv_envelope().await {
540 Ok(Some(envelope)) => envelope,
541 Ok(None) => {
542 debug!(connection_id, "rpc server connection closed by peer");
543 break;
544 }
545 Err(err) => {
546 let error = RuntimeError::transport(
547 RuntimeErrorCode::InternalRuntimeError,
548 err.to_string(),
549 );
550 warn!(
551 connection_id,
552 error_code = error.code.as_i32(),
553 error_kind = error.kind.as_u8(),
554 error_message = %error.message,
555 "rpc server failed to receive envelope"
556 );
557 return Err(error);
558 }
559 };
560
561 match envelope {
562 Envelope::Request(request) => {
563 let state = Arc::clone(&self.state);
564 let sender = sender.clone();
565 let observability = self.state.observability;
566 tokio::spawn(async move {
567 handle_request(state, sender, connection_id, request, observability)
568 .await;
569 });
570 }
571 Envelope::Goodbye(goodbye) => {
572 info!(
573 connection_id,
574 reason_code = goodbye.reason_code,
575 message = goodbye.message.as_deref().unwrap_or(""),
576 "rpc server received goodbye"
577 );
578 break;
579 }
580 envelope => {
581 let error = RuntimeError::protocol(
582 RuntimeErrorCode::InvalidEnvelope,
583 "server expected request envelope",
584 );
585 warn!(
586 connection_id,
587 envelope_kind = envelope_name(&envelope),
588 error_code = error.code.as_i32(),
589 error_kind = error.kind.as_u8(),
590 error_message = %error.message,
591 "rpc server received invalid envelope"
592 );
593 return Err(error);
594 }
595 }
596 }
597
598 Ok(())
599 }
600 .await;
601
602 self.state.cleanup_connection(connection_id).await;
603 debug!(connection_id, "rpc server connection cleanup completed");
604 self.state
605 .record_metric(RpcServerMetricEvent::ConnectionEnded {
606 connection_id,
607 success: result.is_ok(),
608 });
609 if let Err(error) = &result {
610 warn!(
611 connection_id,
612 error_code = error.code.as_i32(),
613 error_kind = error.kind.as_u8(),
614 error_message = %error.message,
615 "rpc server connection ended with error"
616 );
617 } else {
618 info!(connection_id, "rpc server connection ended");
619 }
620 result
621 }
622
623 pub async fn serve_listener<L>(&self, mut listener: L) -> Result<(), RuntimeError>
624 where
625 L: RpcListener + Send,
626 {
627 listener.set_connection_scope(self.state.security.connection_scope);
628 loop {
629 let connection = match listener.accept().await {
630 Ok(connection) => connection,
631 Err(err) => {
632 let access_denied = is_transport_access_denied(&err);
633 let error = RuntimeError::transport(
634 if access_denied {
635 RuntimeErrorCode::AccessDenied
636 } else {
637 RuntimeErrorCode::InternalRuntimeError
638 },
639 err.to_string(),
640 );
641 if access_denied {
642 self.state.record_metric(
643 RpcServerMetricEvent::ListenerConnectionRejected {
644 error_code: RuntimeErrorCode::AccessDenied,
645 },
646 );
647 warn!(
648 error_code = error.code.as_i32(),
649 error_kind = error.kind.as_u8(),
650 error_message = %error.message,
651 "rpc server listener rejected connection"
652 );
653 continue;
654 }
655 error!(
656 error_code = error.code.as_i32(),
657 error_kind = error.kind.as_u8(),
658 error_message = %error.message,
659 "rpc server listener accept failed"
660 );
661 return Err(error);
662 }
663 };
664 let server = self.clone();
665 tokio::spawn(async move {
666 if let Err(error) = server.serve_connection(connection).await {
667 warn!(
668 error_code = error.code.as_i32(),
669 error_kind = error.kind.as_u8(),
670 error_message = %error.message,
671 "rpc server listener connection task failed"
672 );
673 }
674 });
675 }
676 }
677
678 pub fn spawn_listener<L>(
679 &self,
680 listener: L,
681 ) -> tokio::task::JoinHandle<Result<(), RuntimeError>>
682 where
683 L: RpcListener + Send + 'static,
684 {
685 let server = self.clone();
686 tokio::spawn(async move { server.serve_listener(listener).await })
687 }
688
689 async fn perform_handshake(
690 &self,
691 connection_id: u64,
692 sender: &RpcSender,
693 receiver: &mut RpcReceiver,
694 ) -> Result<(), RuntimeError> {
695 let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
696 let error =
697 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string());
698 warn!(
699 connection_id,
700 error_code = error.code.as_i32(),
701 error_kind = error.kind.as_u8(),
702 error_message = %error.message,
703 "rpc server handshake receive failed"
704 );
705 error
706 })?
707 else {
708 let error = RuntimeError::transport(
709 RuntimeErrorCode::InternalRuntimeError,
710 "client disconnected during handshake",
711 );
712 warn!(
713 connection_id,
714 error_code = error.code.as_i32(),
715 error_kind = error.kind.as_u8(),
716 error_message = %error.message,
717 "rpc server handshake disconnected"
718 );
719 return Err(error);
720 };
721 let Envelope::Hello(hello) = envelope else {
722 let error = RuntimeError::protocol(
723 RuntimeErrorCode::InvalidEnvelope,
724 "expected HELLO during handshake",
725 );
726 warn!(
727 connection_id,
728 envelope_kind = envelope_name(&envelope),
729 error_code = error.code.as_i32(),
730 error_kind = error.kind.as_u8(),
731 error_message = %error.message,
732 "rpc server handshake received invalid envelope"
733 );
734 return Err(error);
735 };
736 if hello.protocol_version != RUNTIME_PROTOCOL_VERSION || hello.role != Role::Client {
737 let error = RuntimeError::protocol(
738 RuntimeErrorCode::UnsupportedProtocolVersion,
739 "unsupported client handshake",
740 );
741 warn!(
742 connection_id,
743 protocol_version = hello.protocol_version,
744 role = ?hello.role,
745 capability_bits = hello.capability_bits.bits(),
746 max_message_size = hello.max_message_size,
747 error_code = error.code.as_i32(),
748 error_kind = error.kind.as_u8(),
749 error_message = %error.message,
750 "rpc server handshake rejected"
751 );
752 return Err(error);
753 }
754 self.validate_handshake_auth(connection_id, &hello.options)?;
755 sender
756 .send_envelope(&Envelope::HelloAck(HelloAck {
757 protocol_version: RUNTIME_PROTOCOL_VERSION,
758 accepted_capability_bits: server_capabilities() & hello.capability_bits,
759 max_message_size: hello.max_message_size,
760 options: Vec::new(),
761 }))
762 .await
763 .map_err(|err| {
764 let error = RuntimeError::transport(
765 RuntimeErrorCode::InternalRuntimeError,
766 err.to_string(),
767 );
768 warn!(
769 connection_id,
770 error_code = error.code.as_i32(),
771 error_kind = error.kind.as_u8(),
772 error_message = %error.message,
773 "rpc server handshake ack send failed"
774 );
775 error
776 })?;
777 info!(
778 connection_id,
779 protocol_version = hello.protocol_version,
780 accepted_capability_bits = (server_capabilities() & hello.capability_bits).bits(),
781 max_message_size = hello.max_message_size,
782 "rpc server handshake completed"
783 );
784 Ok(())
785 }
786
787 fn validate_handshake_auth(
788 &self,
789 connection_id: u64,
790 options: &Options,
791 ) -> Result<(), RuntimeError> {
792 let RpcServerAuthConfig::Token { token, option_key } = &self.state.security.auth else {
793 return Ok(());
794 };
795
796 let value = options
797 .iter()
798 .rev()
799 .find_map(|(key, value)| (key == option_key).then_some(value));
800 let Some(value) = value else {
801 let error = RuntimeError::protocol(
802 RuntimeErrorCode::AccessDenied,
803 "missing handshake authentication token",
804 );
805 warn!(
806 connection_id,
807 auth_option_key = %option_key,
808 error_code = error.code.as_i32(),
809 error_kind = error.kind.as_u8(),
810 error_message = %error.message,
811 "rpc server handshake rejected authentication"
812 );
813 return Err(error);
814 };
815 let Some(received) = value.as_str() else {
816 let error = RuntimeError::protocol(
817 RuntimeErrorCode::AccessDenied,
818 "handshake authentication token must be a string",
819 );
820 warn!(
821 connection_id,
822 auth_option_key = %option_key,
823 error_code = error.code.as_i32(),
824 error_kind = error.kind.as_u8(),
825 error_message = %error.message,
826 "rpc server handshake rejected authentication"
827 );
828 return Err(error);
829 };
830 if received != token {
831 let error = RuntimeError::protocol(
832 RuntimeErrorCode::AccessDenied,
833 "invalid handshake authentication token",
834 );
835 warn!(
836 connection_id,
837 auth_option_key = %option_key,
838 error_code = error.code.as_i32(),
839 error_kind = error.kind.as_u8(),
840 error_message = %error.message,
841 "rpc server handshake rejected authentication"
842 );
843 return Err(error);
844 }
845 debug!(
846 connection_id,
847 auth_option_key = %option_key,
848 "rpc server handshake authentication accepted"
849 );
850 Ok(())
851 }
852
853 pub async fn list_instances(&self) -> Vec<InstanceDescriptor> {
854 self.state.list_instances(None).await
855 }
856}
857
858async fn handle_request(
859 state: Arc<ServerState>,
860 sender: RpcSender,
861 connection_id: u64,
862 request: Request,
863 observability: RpcServerObservabilityConfig,
864) {
865 let request_id = request.request_id;
866 let instance_id = request.instance_id;
867 let method_id = request.method_id;
868 let is_activation = instance_id.get() == ACTIVATION_INSTANCE_ID_VALUE;
869 let payload_preview = payload_preview(&request.payload, observability);
870
871 debug!(
872 connection_id,
873 request_id = request_id.get(),
874 instance_id = instance_id.get(),
875 method_id = method_id.get(),
876 is_activation,
877 "rpc server request received"
878 );
879 state.record_metric(RpcServerMetricEvent::RequestStarted {
880 connection_id,
881 request_id,
882 instance_id,
883 method_id,
884 is_activation,
885 });
886 if let Some(payload_preview) = payload_preview {
887 trace!(
888 connection_id,
889 request_id = request_id.get(),
890 payload_preview,
891 "rpc server request payload preview"
892 );
893 }
894
895 let started = Instant::now();
896 let response = dispatch_request(state.clone(), sender.clone(), connection_id, request).await;
897 let elapsed = started.elapsed();
898 let elapsed_ms = elapsed.as_secs_f64() * 1000.0;
899
900 let envelope = match response {
901 Ok(payload) => {
902 if elapsed >= observability.slow_call_threshold {
903 state.record_metric(RpcServerMetricEvent::RequestSlow {
904 connection_id,
905 request_id,
906 instance_id,
907 method_id,
908 is_activation,
909 elapsed,
910 threshold: observability.slow_call_threshold,
911 });
912 warn!(
913 connection_id,
914 request_id = request_id.get(),
915 instance_id = instance_id.get(),
916 method_id = method_id.get(),
917 is_activation,
918 elapsed_ms,
919 slow_call_threshold_ms =
920 observability.slow_call_threshold.as_secs_f64() * 1000.0,
921 "rpc server request completed slowly"
922 );
923 } else {
924 info!(
925 connection_id,
926 request_id = request_id.get(),
927 instance_id = instance_id.get(),
928 method_id = method_id.get(),
929 is_activation,
930 elapsed_ms,
931 "rpc server request completed"
932 );
933 }
934 state.record_metric(RpcServerMetricEvent::RequestCompleted {
935 connection_id,
936 request_id,
937 instance_id,
938 method_id,
939 is_activation,
940 elapsed,
941 });
942 Envelope::ResponseOk(ResponseOk {
943 request_id,
944 payload,
945 })
946 }
947 Err(error) => {
948 state.record_metric(RpcServerMetricEvent::RequestFailed {
949 connection_id,
950 request_id,
951 instance_id,
952 method_id,
953 is_activation,
954 elapsed,
955 error_code: error.code,
956 });
957 warn!(
958 connection_id,
959 request_id = request_id.get(),
960 instance_id = instance_id.get(),
961 method_id = method_id.get(),
962 is_activation,
963 elapsed_ms,
964 error_code = error.code.as_i32(),
965 error_kind = error.kind.as_u8(),
966 error_message = %error.message,
967 "rpc server request failed"
968 );
969 runtime_error_response(request_id, error)
970 }
971 };
972
973 if let Err(err) = sender.send_envelope(&envelope).await {
974 state.record_metric(RpcServerMetricEvent::ResponseSendFailed {
975 connection_id,
976 request_id,
977 });
978 error!(
979 connection_id,
980 request_id = request_id.get(),
981 error = %err,
982 "rpc server failed to send response"
983 );
984 } else {
985 trace!(
986 connection_id,
987 request_id = request_id.get(),
988 response_kind = envelope_name(&envelope),
989 "rpc server response sent"
990 );
991 }
992}
993
994async fn dispatch_request(
995 state: Arc<ServerState>,
996 sender: RpcSender,
997 connection_id: u64,
998 request: Request,
999) -> Result<Value, RuntimeError> {
1000 if request.instance_id.get() == ACTIVATION_INSTANCE_ID_VALUE {
1001 return dispatch_activation(state, sender, connection_id, request).await;
1002 }
1003
1004 let instance = state.get_instance(request.instance_id).await?;
1005 if !instance.methods.contains(&request.method_id.get()) {
1006 return Err(RuntimeError::runtime(
1007 RuntimeErrorCode::MethodNotFound,
1008 format!("method id `{}` was not found", request.method_id.get()),
1009 ));
1010 }
1011 let ctx = RpcCallContext {
1012 connection_id,
1013 instance_id: request.instance_id,
1014 sender,
1015 };
1016 instance
1017 .handler
1018 .call(ctx, request.method_id, request.payload)
1019 .await
1020}
1021
1022async fn dispatch_activation(
1023 state: Arc<ServerState>,
1024 sender: RpcSender,
1025 connection_id: u64,
1026 request: Request,
1027) -> Result<Value, RuntimeError> {
1028 let ctx = RpcCallContext {
1029 connection_id,
1030 instance_id: request.instance_id,
1031 sender,
1032 };
1033 match request.method_id.get() {
1034 RESOLVE_INSTANCE_IDS_METHOD_ID => {
1035 let request = decode_resolve_instance_ids_request(&request.payload)?;
1036 let ids = state.resolve_instance_ids(&request.instance_names).await;
1037 Ok(encode_resolve_instance_ids_response(
1038 &ResolveInstanceIdsResponse { instance_ids: ids },
1039 ))
1040 }
1041 CREATE_INSTANCE_METHOD_ID => {
1042 let request = decode_create_instance_request(&request.payload)?;
1043 let factory = state.get_factory(request.service_guid).ok_or_else(|| {
1044 RuntimeError::runtime(
1045 RuntimeErrorCode::ServiceGuidNotFound,
1046 "service factory was not found",
1047 )
1048 })?;
1049 let handler = factory
1050 .factory
1051 .create(ctx, request.create_payload, request.options)
1052 .await?;
1053 let instance_id = state
1054 .insert_client_instance(
1055 request.service_guid,
1056 connection_id,
1057 factory.methods.clone(),
1058 handler,
1059 )
1060 .await;
1061 Ok(encode_create_instance_response(&CreateInstanceResponse {
1062 instance_id,
1063 }))
1064 }
1065 RELEASE_INSTANCE_METHOD_ID => {
1066 let request = decode_release_instance_request(&request.payload)?;
1067 state
1068 .release_instance(connection_id, request.instance_id)
1069 .await?;
1070 Ok(encode_release_instance_response(&ReleaseInstanceResponse))
1071 }
1072 LIST_INSTANCES_METHOD_ID => {
1073 let request = decode_list_instances_request(&request.payload)?;
1074 let instances = state.list_instances(request.service_guid).await;
1075 Ok(encode_list_instances_response(&ListInstancesResponse {
1076 instances,
1077 }))
1078 }
1079 _ => Err(RuntimeError::runtime(
1080 RuntimeErrorCode::MethodNotFound,
1081 "activation method was not found",
1082 )),
1083 }
1084}
1085
1086fn runtime_error_response(request_id: RequestId, error: RuntimeError) -> Envelope {
1087 Envelope::ResponseError(ResponseError {
1088 request_id,
1089 error_code: error.code.as_i32(),
1090 error_kind: error.kind.as_u8(),
1091 error_message: Some(error.message),
1092 error_details: Value::Nil,
1093 })
1094}
1095
/// Capability flags this server supports.
///
/// During the handshake they are intersected with the client's advertised bits to form
/// `accepted_capability_bits` in HELLO_ACK.
fn server_capabilities() -> CapabilityFlags {
    CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
        | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
        | CapabilityFlags::SERVICE_ACTIVATION
        | CapabilityFlags::GOODBYE
}
1102
1103fn envelope_name(envelope: &Envelope) -> &'static str {
1104 match envelope {
1105 Envelope::Hello(_) => "hello",
1106 Envelope::HelloAck(_) => "hello_ack",
1107 Envelope::Request(_) => "request",
1108 Envelope::ResponseOk(_) => "response_ok",
1109 Envelope::ResponseError(_) => "response_error",
1110 Envelope::Notification(_) => "notification",
1111 Envelope::Goodbye(_) => "goodbye",
1112 }
1113}
1114
1115fn payload_preview(payload: &Value, config: RpcServerObservabilityConfig) -> Option<String> {
1116 if !config.log_payload_preview || config.payload_preview_bytes == 0 {
1117 return None;
1118 }
1119 let mut preview = format!("{payload:?}");
1120 if preview.len() > config.payload_preview_bytes {
1121 preview.truncate(config.payload_preview_bytes);
1122 preview.push_str("...");
1123 }
1124 Some(preview)
1125}
1126
1127fn is_transport_access_denied(error: &TransportError) -> bool {
1128 matches!(
1129 error,
1130 TransportError::Runtime(error) if error.code == RuntimeErrorCode::AccessDenied
1131 )
1132}
1133
/// Converts `duration` to whole nanoseconds, saturating at `u64::MAX`.
///
/// `Duration::as_nanos` yields a `u128`. Durations longer than about 584 years
/// do not fit in 64 bits and are clamped rather than truncated.
fn duration_nanos_u64(duration: Duration) -> u64 {
    u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX)
}
1137
/// Raises `value` to `candidate` when `candidate` is larger; otherwise leaves it unchanged.
///
/// Uses `AtomicU64::fetch_max`, which performs the compare-and-raise as one
/// atomic read-modify-write instead of a hand-written compare-exchange retry
/// loop. It keeps the `Relaxed` ordering this helper has always used.
fn update_atomic_max(value: &AtomicU64, candidate: u64) {
    value.fetch_max(candidate, Ordering::Relaxed);
}
1148
/// Atomically adds `increment` to `value`, clamping at `u64::MAX` instead of wrapping.
///
/// `fetch_update` retries the compare-exchange until it wins any race with
/// concurrent writers. The closure always returns `Some`, so the update always
/// succeeds, and the returned previous value is not needed.
fn saturating_atomic_add(value: &AtomicU64, increment: u64) {
    let _ = value.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |observed| {
        Some(observed.saturating_add(increment))
    });
}
1159
/// Server-wide state: configuration, id counters, the live instance table,
/// the instance-name index and the registered service factories.
struct ServerState {
    /// Source of connection ids (starts at 1 in `ServerState::new`).
    next_connection_id: AtomicU64,
    /// Source of instance ids. `ServerState::new` starts it at 2; presumably id 1
    /// is `ACTIVATION_INSTANCE_ID_VALUE` (TODO confirm).
    next_instance_id: AtomicU64,
    /// Slow-call threshold and payload-preview settings.
    observability: RpcServerObservabilityConfig,
    /// Connection scope and authentication settings.
    security: RpcServerSecurityConfig,
    /// Optional metrics sink; `record_metric` does nothing when this is `None`.
    metrics_sink: Option<Arc<dyn RpcServerMetricsSink>>,
    /// Live instances keyed by their raw instance id.
    instances: RwLock<HashMap<u64, InstanceEntry>>,
    /// Instance name -> raw instance id, consulted by `resolve_instance_ids`.
    names: RwLock<HashMap<String, u64>>,
    /// Registered factories keyed by the service GUID's `Uuid`.
    factories: HashMap<uuid::Uuid, FactoryEntry>,
}
1170
1171impl ServerState {
1172 fn new() -> Self {
1173 Self {
1174 next_connection_id: AtomicU64::new(1),
1175 next_instance_id: AtomicU64::new(2),
1176 observability: RpcServerObservabilityConfig::default(),
1177 security: RpcServerSecurityConfig::default(),
1178 metrics_sink: None,
1179 instances: RwLock::new(HashMap::new()),
1180 names: RwLock::new(HashMap::new()),
1181 factories: HashMap::new(),
1182 }
1183 }
1184
1185 fn record_metric(&self, event: RpcServerMetricEvent) {
1186 let Some(sink) = &self.metrics_sink else {
1187 return;
1188 };
1189 let result = panic::catch_unwind(AssertUnwindSafe(|| sink.record(event)));
1190 if result.is_err() {
1191 error!("rpc server metrics sink panicked while recording event");
1192 }
1193 }
1194
1195 fn insert_activation_instance(&mut self) {
1196 self.instances.get_mut().insert(
1197 ACTIVATION_INSTANCE_ID_VALUE,
1198 InstanceEntry {
1199 instance_id: activation_instance_id(),
1200 service_guid: activation_service_guid(),
1201 instance_name: Some("rpc.runtime.Activation".to_string()),
1202 activation_mode: ActivationMode::Singleton,
1203 releasable: false,
1204 owner_connection_id: None,
1205 methods: vec![
1206 RESOLVE_INSTANCE_IDS_METHOD_ID,
1207 CREATE_INSTANCE_METHOD_ID,
1208 RELEASE_INSTANCE_METHOD_ID,
1209 LIST_INSTANCES_METHOD_ID,
1210 ],
1211 handler: Arc::new(ActivationMarker),
1212 },
1213 );
1214 }
1215
1216 fn insert_instance(&mut self, instance: NewInstance) -> InstanceId {
1217 let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
1218 let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
1219 if let Some(name) = &instance.name {
1220 self.names.get_mut().insert(name.clone(), id);
1221 }
1222 self.instances.get_mut().insert(
1223 id,
1224 InstanceEntry {
1225 instance_id,
1226 service_guid: instance.service_guid,
1227 instance_name: instance.name,
1228 activation_mode: instance.activation_mode,
1229 releasable: instance.releasable,
1230 owner_connection_id: instance.owner_connection_id,
1231 methods: instance.methods,
1232 handler: instance.handler,
1233 },
1234 );
1235 instance_id
1236 }
1237
1238 async fn insert_client_instance(
1239 &self,
1240 service_guid: ServiceGuid,
1241 connection_id: u64,
1242 methods: Vec<u32>,
1243 handler: Arc<dyn RpcServiceHandler>,
1244 ) -> InstanceId {
1245 let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
1246 let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
1247 self.instances.write().await.insert(
1248 id,
1249 InstanceEntry {
1250 instance_id,
1251 service_guid,
1252 instance_name: None,
1253 activation_mode: ActivationMode::Instantiable,
1254 releasable: true,
1255 owner_connection_id: Some(connection_id),
1256 methods,
1257 handler,
1258 },
1259 );
1260 instance_id
1261 }
1262
1263 async fn get_instance(&self, instance_id: InstanceId) -> Result<InstanceEntry, RuntimeError> {
1264 self.instances
1265 .read()
1266 .await
1267 .get(&instance_id.get())
1268 .cloned()
1269 .ok_or_else(|| {
1270 RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
1271 })
1272 }
1273
1274 fn get_factory(&self, service_guid: ServiceGuid) -> Option<FactoryEntry> {
1275 self.factories.get(&service_guid.get()).cloned()
1276 }
1277
1278 async fn resolve_instance_ids(&self, names: &[String]) -> Vec<u64> {
1279 let index = self.names.read().await;
1280 names
1281 .iter()
1282 .map(|name| index.get(name).copied().unwrap_or(0))
1283 .collect()
1284 }
1285
1286 async fn release_instance(
1287 &self,
1288 connection_id: u64,
1289 instance_id: InstanceId,
1290 ) -> Result<(), RuntimeError> {
1291 let mut instances = self.instances.write().await;
1292 let entry = instances.get(&instance_id.get()).ok_or_else(|| {
1293 RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
1294 })?;
1295 if !entry.releasable {
1296 return Err(RuntimeError::runtime(
1297 RuntimeErrorCode::InstanceReleaseNotAllowed,
1298 "instance is not releasable",
1299 ));
1300 }
1301 if entry.owner_connection_id != Some(connection_id) {
1302 return Err(RuntimeError::runtime(
1303 RuntimeErrorCode::AccessDenied,
1304 "instance is owned by another connection",
1305 ));
1306 }
1307 instances.remove(&instance_id.get());
1308 Ok(())
1309 }
1310
1311 async fn cleanup_connection(&self, connection_id: u64) {
1312 self.instances
1313 .write()
1314 .await
1315 .retain(|_, entry| entry.owner_connection_id != Some(connection_id));
1316 }
1317
1318 async fn list_instances(&self, service_guid: Option<ServiceGuid>) -> Vec<InstanceDescriptor> {
1319 let mut values = self
1320 .instances
1321 .read()
1322 .await
1323 .values()
1324 .filter(|entry| service_guid.is_none_or(|guid| guid == entry.service_guid))
1325 .map(InstanceEntry::descriptor)
1326 .collect::<Vec<_>>();
1327 values.sort_by_key(|entry| entry.instance_id.get());
1328 values
1329 }
1330}
1331
/// Everything needed to register an instance through `ServerState::insert_instance`.
struct NewInstance {
    /// Service the instance belongs to.
    service_guid: ServiceGuid,
    /// Optional name. When set, `insert_instance` also records it in the name index.
    name: Option<String>,
    /// Activation mode reported in the instance's descriptor.
    activation_mode: ActivationMode,
    /// Whether `release_instance` is allowed to remove the instance.
    releasable: bool,
    /// Owning connection, if any. `cleanup_connection` removes owned instances.
    owner_connection_id: Option<u64>,
    /// Method ids registered for the instance (presumably the methods it serves; TODO confirm).
    methods: Vec<u32>,
    /// Handler invoked for calls to the instance.
    handler: Arc<dyn RpcServiceHandler>,
}
1341
/// A live instance as stored in `ServerState::instances`.
///
/// Cloning is relatively cheap: the handler is shared through an `Arc`, and
/// only the `methods` vector and the optional name are copied.
#[derive(Clone)]
struct InstanceEntry {
    /// Public id of the instance; also its key (via `get()`) in the instance table.
    instance_id: InstanceId,
    /// Service the instance belongs to; used to filter `list_instances`.
    service_guid: ServiceGuid,
    /// Optional human-readable name.
    instance_name: Option<String>,
    /// Activation mode reported in the descriptor.
    activation_mode: ActivationMode,
    /// Whether `release_instance` may remove this entry.
    releasable: bool,
    /// Owning connection, if any; only the owner may release the instance.
    owner_connection_id: Option<u64>,
    /// Method ids registered for the instance.
    methods: Vec<u32>,
    /// Handler invoked for calls to the instance.
    handler: Arc<dyn RpcServiceHandler>,
}
1353
1354impl InstanceEntry {
1355 fn descriptor(&self) -> InstanceDescriptor {
1356 InstanceDescriptor {
1357 instance_id: self.instance_id,
1358 instance_name: self.instance_name.clone(),
1359 service_guid: self.service_guid,
1360 activation_mode: self.activation_mode,
1361 releasable: self.releasable,
1362 }
1363 }
1364}
1365
/// A registered service factory, stored in `ServerState::factories` and keyed
/// by service GUID.
#[derive(Clone)]
struct FactoryEntry {
    /// Method ids associated with the service. Presumably these are attached to
    /// instances the factory creates (TODO confirm).
    methods: Vec<u32>,
    /// Shared factory object used to produce service handlers.
    factory: Arc<dyn RpcServiceFactory>,
}
1371
1372struct ActivationMarker;
1373
1374impl RpcServiceHandler for ActivationMarker {
1375 fn call(&self, _: RpcCallContext, _: MethodId, _: Value) -> HandlerFuture {
1376 Box::pin(async {
1377 Err(RuntimeError::runtime(
1378 RuntimeErrorCode::InternalRuntimeError,
1379 "activation marker should not be dispatched directly",
1380 ))
1381 })
1382 }
1383}
1384
// Tests for the server.
// - Synchronous unit tests cover configuration defaults, payload previews and
//   the metrics recorder.
// - Async tests run `serve_connection` against an in-memory client connected
//   through a `tokio::io::duplex` pipe wrapped in `IpcConnection`s.
#[cfg(test)]
mod tests {
    use super::*;
    use rpc_runtime_core::{Goodbye, Hello, Request, Role};
    use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection};
    use tokio::io::duplex;

    // Default observability: 500ms slow-call threshold, payload previews off.
    #[test]
    fn observability_defaults_are_safe() {
        let config = RpcServerObservabilityConfig::default();

        assert_eq!(config.slow_call_threshold, Duration::from_millis(500));
        assert_eq!(config.payload_preview_bytes, 0);
        assert!(!config.log_payload_preview);
    }

    // Previews are `None` by default. With a 5-byte budget the result is at
    // most 5 bytes plus the 3-byte "..." suffix.
    #[test]
    fn payload_preview_is_opt_in_and_bounded() {
        let payload = Value::from("1234567890");

        assert_eq!(
            payload_preview(&payload, RpcServerObservabilityConfig::default()),
            None
        );
        let preview = payload_preview(
            &payload,
            RpcServerObservabilityConfig::default().with_payload_preview(5),
        )
        .expect("preview");
        assert!(preview.len() <= 8);
        assert!(preview.ends_with("..."));
    }

    // Feeds events directly into the recorder (no server involved). The elapsed
    // total is 3ms + 5ms and the maximum is the larger 5ms.
    #[test]
    fn metrics_recorder_counts_events_and_latency() {
        let recorder = RpcServerMetricsRecorder::new();
        recorder.record(RpcServerMetricEvent::ConnectionStarted { connection_id: 1 });
        recorder.record(RpcServerMetricEvent::ConnectionEnded {
            connection_id: 1,
            success: true,
        });
        recorder.record(RpcServerMetricEvent::RequestCompleted {
            connection_id: 1,
            request_id: RequestId::new(7),
            instance_id: activation_instance_id(),
            method_id: MethodId::new(1),
            is_activation: true,
            elapsed: Duration::from_millis(3),
        });
        recorder.record(RpcServerMetricEvent::RequestFailed {
            connection_id: 1,
            request_id: RequestId::new(8),
            instance_id: activation_instance_id(),
            method_id: MethodId::new(2),
            is_activation: true,
            elapsed: Duration::from_millis(5),
            error_code: RuntimeErrorCode::InternalRuntimeError,
        });

        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.connections_started, 1);
        assert_eq!(snapshot.connections_ended, 1);
        assert_eq!(snapshot.connections_ended_successfully, 1);
        assert_eq!(snapshot.requests_completed, 1);
        assert_eq!(snapshot.requests_failed, 1);
        assert_eq!(snapshot.request_elapsed_total, Duration::from_millis(8));
        assert_eq!(snapshot.request_elapsed_max, Duration::from_millis(5));
    }

    // Default security: `ConnectionScope::LocalOnly` with authentication disabled.
    #[test]
    fn security_defaults_are_local_auth_disabled() {
        let config = RpcServerSecurityConfig::default();

        assert_eq!(config.connection_scope, ConnectionScope::LocalOnly);
        assert_eq!(config.auth, RpcServerAuthConfig::Disabled);
    }

    // A Hello carrying the configured token is answered with a HelloAck.
    #[tokio::test]
    async fn token_auth_accepts_matching_token() {
        let server = RpcServerBuilder::new()
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let ack = run_handshake(server, vec![auth_option("secret")])
            .await
            .expect("handshake");

        assert!(matches!(ack, Envelope::HelloAck(_)));
    }

    // No token option: the handshake fails with AccessDenied.
    #[tokio::test]
    async fn token_auth_rejects_missing_token() {
        let server = RpcServerBuilder::new()
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let err = run_handshake(server, Vec::new())
            .await
            .expect_err("must reject");

        assert_eq!(err.code, RuntimeErrorCode::AccessDenied);
    }

    // A token that does not match: the handshake fails with AccessDenied.
    #[tokio::test]
    async fn token_auth_rejects_wrong_token() {
        let server = RpcServerBuilder::new()
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let err = run_handshake(server, vec![auth_option("wrong")])
            .await
            .expect_err("must reject");

        assert_eq!(err.code, RuntimeErrorCode::AccessDenied);
    }

    // The token option must be a string; an integer value is rejected.
    #[tokio::test]
    async fn token_auth_rejects_non_string_token() {
        let server = RpcServerBuilder::new()
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let err = run_handshake(
            server,
            vec![(
                DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
                Value::from(123_u64),
            )],
        )
        .await
        .expect_err("must reject");

        assert_eq!(err.code, RuntimeErrorCode::AccessDenied);
    }

    // A rejected handshake is recorded as one started connection, one
    // unsuccessfully ended connection and one failed handshake.
    #[tokio::test]
    async fn metrics_recorder_observes_handshake_failure() {
        let recorder = Arc::new(RpcServerMetricsRecorder::new());
        let server = RpcServerBuilder::new()
            .metrics_sink(recorder.clone())
            .security(RpcServerSecurityConfig::default().with_token("secret"))
            .build();

        let err = run_handshake(server, Vec::new())
            .await
            .expect_err("must reject");
        assert_eq!(err.code, RuntimeErrorCode::AccessDenied);

        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.connections_started, 1);
        assert_eq!(snapshot.connections_ended, 1);
        assert_eq!(snapshot.connections_ended_successfully, 0);
        assert_eq!(snapshot.handshakes_completed, 0);
        assert_eq!(snapshot.handshakes_failed, 1);
    }

    // End-to-end metrics over a full session:
    // handshake -> method 1 (echo, succeeds) -> method 2 (fails) -> Goodbye.
    // The slow-call threshold is zero. NOTE(review): the test still expects
    // only one slow request; confirm which outcomes the slow-call logic counts.
    #[tokio::test]
    async fn metrics_recorder_observes_success_failure_and_slow_requests() {
        let recorder = Arc::new(RpcServerMetricsRecorder::new());
        let mut builder = RpcServerBuilder::new()
            .metrics_sink(recorder.clone())
            .observability(
                RpcServerObservabilityConfig::default()
                    .with_slow_call_threshold(Duration::from_nanos(0)),
            );
        // The test handler is registered under the activation service GUID.
        let instance_id = builder.register_named_instance(
            "metrics",
            activation_service_guid(),
            [1, 2],
            Arc::new(MetricsTestHandler),
        );
        let server = builder.build();
        let (client_stream, server_stream) = duplex(4096);
        let server_connection = IpcConnection::from_stream(server_stream, FrameConfig::default());
        let server_task =
            tokio::spawn(async move { server.serve_connection(server_connection).await });
        let client_connection = IpcConnection::from_stream(client_stream, FrameConfig::default());
        let (sender, mut receiver) = client_connection.split();

        send_hello(&sender).await;
        assert!(matches!(
            receiver.recv_envelope().await.expect("recv ack"),
            Some(Envelope::HelloAck(_))
        ));
        // Method 1 echoes its payload back.
        sender
            .send_envelope(&Envelope::Request(Request {
                request_id: RequestId::new(11),
                instance_id,
                method_id: MethodId::new(1),
                payload: Value::from("ok"),
            }))
            .await
            .expect("send success request");
        assert!(matches!(
            receiver.recv_envelope().await.expect("recv response"),
            Some(Envelope::ResponseOk(_))
        ));
        // Any other method id makes the test handler return an error.
        sender
            .send_envelope(&Envelope::Request(Request {
                request_id: RequestId::new(12),
                instance_id,
                method_id: MethodId::new(2),
                payload: Value::Nil,
            }))
            .await
            .expect("send failing request");
        assert!(matches!(
            receiver.recv_envelope().await.expect("recv error"),
            Some(Envelope::ResponseError(_))
        ));
        sender
            .send_envelope(&Envelope::Goodbye(Goodbye {
                reason_code: 0,
                message: Some("done".to_string()),
            }))
            .await
            .expect("send goodbye");
        // Closing both halves lets `serve_connection` finish so its result can be checked.
        drop(sender);
        drop(receiver);
        server_task.await.expect("server task").expect("serve");

        let snapshot = recorder.snapshot();
        assert_eq!(snapshot.connections_started, 1);
        assert_eq!(snapshot.connections_ended_successfully, 1);
        assert_eq!(snapshot.handshakes_completed, 1);
        assert_eq!(snapshot.requests_started, 2);
        assert_eq!(snapshot.requests_completed, 1);
        assert_eq!(snapshot.requests_failed, 1);
        assert_eq!(snapshot.requests_slow, 1);
        assert!(snapshot.request_elapsed_total > Duration::ZERO);
    }

    // Serves one connection, sends a Hello with `options`, and returns the first
    // envelope the server sends back. If the server closes the stream without
    // replying (`Ok(None)`), returns the error `serve_connection` ended with.
    async fn run_handshake(server: RpcServer, options: Options) -> Result<Envelope, RuntimeError> {
        let (client_stream, server_stream) = duplex(4096);
        let server_connection = IpcConnection::from_stream(server_stream, FrameConfig::default());
        let server_task =
            tokio::spawn(async move { server.serve_connection(server_connection).await });

        let client_connection = IpcConnection::from_stream(client_stream, FrameConfig::default());
        let (sender, mut receiver) = client_connection.split();
        sender
            .send_envelope(&hello_envelope(options))
            .await
            .expect("send hello");

        let envelope = receiver.recv_envelope().await;
        // Drop the client side before awaiting so the server task can finish.
        drop(sender);
        drop(receiver);
        let server_result = server_task.await.expect("server task");
        match envelope.expect("recv hello ack") {
            Some(envelope) => Ok(envelope),
            None => Err(server_result.expect_err("server should return handshake error")),
        }
    }

    // Hello option carrying `token` under the default auth-token key.
    fn auth_option(token: &str) -> (String, Value) {
        (
            DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
            Value::from(token),
        )
    }

    // Sends a Hello with no options.
    async fn send_hello(sender: &RpcSender) {
        sender
            .send_envelope(&hello_envelope(Vec::new()))
            .await
            .expect("send hello");
    }

    // Client Hello at the current protocol version, no capabilities, 16 MiB max message size.
    fn hello_envelope(options: Options) -> Envelope {
        Envelope::Hello(Hello {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            role: Role::Client,
            capability_bits: CapabilityFlags::empty(),
            max_message_size: 16 * 1024 * 1024,
            options,
        })
    }

    // Test handler: method 1 echoes the payload; every other method fails
    // with `InternalRuntimeError`.
    struct MetricsTestHandler;

    impl RpcServiceHandler for MetricsTestHandler {
        fn call(&self, _: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
            Box::pin(async move {
                match method_id.get() {
                    1 => Ok(payload),
                    _ => Err(RuntimeError::runtime(
                        RuntimeErrorCode::InternalRuntimeError,
                        "test failure",
                    )),
                }
            })
        }
    }
}