1use std::collections::HashMap;
4use std::io::{Read, Write};
5use std::sync::{Arc, Mutex};
6
7use arrow_array::RecordBatch;
8use arrow_cast::cast_with_options;
9use arrow_schema::{Schema, SchemaRef};
10
11use crate::errors::{Result, RpcError};
12use crate::log::{LogLevel, LogMessage};
13use crate::metadata::{
14 CANCEL_KEY, LOG_EXTRA_KEY, LOG_LEVEL_KEY, LOG_MESSAGE_KEY, REQUEST_ID_KEY, REQUEST_VERSION,
15 REQUEST_VERSION_KEY, RPC_METHOD_KEY, SERVER_ID_KEY,
16};
17#[cfg(feature = "shm")]
18use crate::metadata::{SHM_SEGMENT_NAME_KEY, SHM_SEGMENT_SIZE_KEY};
19#[cfg(feature = "shm")]
20use crate::shm::{maybe_write_to_shm, resolve_shm_batch, ShmSegment};
21
/// Stand-in for the shared-memory segment type when the `shm` feature is off.
///
/// It lets `Option<&ShmSegment>` parameters keep the same shape in both
/// configurations. No value of it is ever produced in this build: the non-shm
/// `maybe_attach_shm` always returns `None`.
#[cfg(not(feature = "shm"))]
pub(crate) struct ShmSegment;
25
/// Attach to the shared-memory segment advertised in the request metadata.
///
/// Returns `None` when either `SHM_SEGMENT_NAME_KEY` or `SHM_SEGMENT_SIZE_KEY`
/// is absent (the client did not offer SHM). Metadata that is present but
/// unusable is logged and ignored rather than failing the request. That covers
/// a size that does not parse as `usize`, or a segment that cannot be attached.
/// The call then proceeds without SHM.
#[cfg(feature = "shm")]
fn maybe_attach_shm(req_md: &Metadata) -> Option<ShmSegment> {
    let name = req_md.get(SHM_SEGMENT_NAME_KEY)?;
    let raw_size = req_md.get(SHM_SEGMENT_SIZE_KEY)?;
    // A present-but-garbled size is malformed metadata just like a failed
    // attach. Report it instead of silently dropping the SHM offer.
    let size: usize = match raw_size.parse() {
        Ok(n) => n,
        Err(e) => {
            tracing::warn!(
                target: "vgi_rpc.shm",
                "ignoring malformed SHM metadata (invalid segment size {raw_size:?}: {e})"
            );
            return None;
        }
    };
    match ShmSegment::attach(name, size, false) {
        Ok(seg) => Some(seg),
        Err(e) => {
            tracing::warn!(target: "vgi_rpc.shm", "ignoring malformed SHM metadata ({e})");
            None
        }
    }
}
40
41#[cfg(not(feature = "shm"))]
42#[inline]
43fn maybe_attach_shm(_req_md: &Metadata) -> Option<ShmSegment> {
44 None
45}
46use crate::stream::{empty_schema, Emitted, OutputCollector, StreamResult, StreamStateKind};
47use crate::wire::{empty_batch, md_get, Metadata, StreamReader, StreamWriter};
48
49fn serialize_request_batch(batch: &RecordBatch) -> std::io::Result<Vec<u8>> {
53 let mut buf = Vec::new();
54 {
55 let mut w = arrow_ipc::writer::StreamWriter::try_new(&mut buf, batch.schema_ref())
56 .map_err(|e| std::io::Error::other(e.to_string()))?;
57 w.write(batch)
58 .map_err(|e| std::io::Error::other(e.to_string()))?;
59 w.finish()
60 .map_err(|e| std::io::Error::other(e.to_string()))?;
61 }
62 Ok(buf)
63}
64
/// Lock `m`, recovering the guard even if a previous holder panicked.
///
/// Poisoning is deliberately ignored. The data guarded here (log buffers,
/// tick metadata, statistics, transport state) stays usable after a handler
/// panic, and handler panics are already converted to errors by `call_guard`.
fn lock_ok<T>(m: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
    match m.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
73
74pub(crate) fn call_guard<T>(f: impl FnOnce() -> T) -> Result<T> {
79 std::panic::catch_unwind(std::panic::AssertUnwindSafe(f))
80 .map_err(|_| RpcError::new("RuntimeError", "handler panicked"))
81}
82
83#[derive(Clone)]
85pub struct CallContext {
86 pub server_id: String,
87 pub method: String,
88 pub request_id: String,
89 pub transport_metadata: Arc<Metadata>,
90 pub auth: crate::auth::AuthContext,
93 pub cookies: std::collections::BTreeMap<String, String>,
95 pub kind: Option<crate::transport::TransportKind>,
99 pub(crate) log_sink: Arc<Mutex<Vec<LogMessage>>>,
100 pub(crate) tick_metadata: Arc<Mutex<Metadata>>,
103 pub(crate) sticky: Option<Arc<dyn StickySink>>,
108}
109
110pub trait StickySink: Send + Sync {
115 fn accept_opens(&self) -> bool;
117 fn current_state(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>>;
119 fn current_session_id(&self) -> Option<String>;
121 fn open(
123 &self,
124 state: Arc<dyn std::any::Any + Send + Sync>,
125 ttl: Option<std::time::Duration>,
126 ) -> Result<()>;
127 fn close(&self) -> Result<bool>;
129}
130
131impl CallContext {
132 pub fn client_log(&self, level: LogLevel, message: impl Into<String>) {
133 lock_ok(&self.log_sink).push(LogMessage::new(level, message));
134 }
135
136 pub fn client_log_with(&self, msg: LogMessage) {
137 lock_ok(&self.log_sink).push(msg);
138 }
139
140 pub(crate) fn drain_logs(&self) -> Vec<LogMessage> {
141 std::mem::take(&mut *lock_ok(&self.log_sink))
142 }
143
144 pub fn tick_metadata(&self, key: &str) -> Option<String> {
147 lock_ok(&self.tick_metadata).get(key).cloned()
148 }
149
150 pub(crate) fn for_request(server: &RpcServer, req: &Request) -> Self {
155 Self {
156 server_id: server.server_id.clone(),
157 method: req.method.clone(),
158 request_id: req.request_id.clone(),
159 transport_metadata: Arc::new(req.metadata.clone()),
160 auth: crate::auth::AuthContext::anonymous(),
161 cookies: std::collections::BTreeMap::new(),
162 kind: server.transport_kind(),
163 log_sink: Arc::new(Mutex::new(Vec::new())),
164 tick_metadata: Arc::new(Mutex::new(Metadata::default())),
165 sticky: None,
166 }
167 }
168
169 #[cfg(feature = "http")]
174 pub(crate) fn with_auth_cookies(
175 server: &RpcServer,
176 req: &Request,
177 auth: crate::auth::AuthContext,
178 cookies: std::collections::BTreeMap<String, String>,
179 ) -> Self {
180 Self {
181 server_id: server.server_id.clone(),
182 method: req.method.clone(),
183 request_id: req.request_id.clone(),
184 transport_metadata: Arc::new(req.metadata.clone()),
185 auth,
186 cookies,
187 kind: server.transport_kind(),
188 log_sink: Arc::new(Mutex::new(Vec::new())),
189 tick_metadata: Arc::new(Mutex::new(Metadata::default())),
190 sticky: None,
191 }
192 }
193
194 #[cfg(feature = "http")]
199 pub(crate) fn set_sticky(&mut self, sink: Arc<dyn StickySink>) {
200 self.sticky = Some(sink);
201 }
202
203 pub fn session<T: std::any::Any + Send + Sync>(&self) -> Option<Arc<T>> {
211 let state = self.sticky.as_ref()?.current_state()?;
212 state.downcast::<T>().ok()
213 }
214
215 pub fn session_id(&self) -> Option<String> {
218 self.sticky.as_ref()?.current_session_id()
219 }
220
221 pub fn open_session(
232 &self,
233 state: Arc<dyn std::any::Any + Send + Sync>,
234 ttl: Option<std::time::Duration>,
235 ) -> Result<()> {
236 let sink = self.sticky.as_ref().ok_or_else(|| {
237 RpcError::runtime_error("sticky sessions not available on this transport")
238 })?;
239 if !sink.accept_opens() {
240 return Err(RpcError::runtime_error(
241 "client did not opt in to sticky sessions \
242 (missing VGI-Session-Accept: true header — open the call inside \
243 an HttpConnection.with_session_token() block)",
244 ));
245 }
246 if sink.current_state().is_some() {
247 return Err(RpcError::runtime_error(
248 "a sticky session is already active for this request",
249 ));
250 }
251 sink.open(state, ttl)
252 }
253
254 pub fn close_session(&self) -> Result<()> {
257 let sink = self.sticky.as_ref().ok_or_else(|| {
258 RpcError::runtime_error("sticky sessions not available on this transport")
259 })?;
260 sink.close()?;
261 Ok(())
262 }
263}
264
265pub struct Request {
267 pub method: String,
268 pub request_id: String,
269 pub batch: RecordBatch,
270 pub metadata: Metadata,
271}
272
273impl Request {
274 pub fn column(&self, name: &str) -> Option<&dyn arrow_array::Array> {
275 let idx = self.batch.schema().index_of(name).ok()?;
276 Some(self.batch.column(idx).as_ref())
277 }
278
279 pub(crate) fn from_read_batch(
287 batch: RecordBatch,
288 metadata: Metadata,
289 require_method: bool,
290 ) -> Result<Self> {
291 let method = if require_method {
292 md_get(&metadata, RPC_METHOD_KEY)
293 .ok_or_else(|| {
294 RpcError::protocol_error(
295 "Missing 'vgi_rpc.method' in request batch custom_metadata.",
296 )
297 })?
298 .to_string()
299 } else {
300 md_get(&metadata, RPC_METHOD_KEY).unwrap_or("").to_string()
301 };
302 let version = md_get(&metadata, REQUEST_VERSION_KEY).ok_or_else(|| {
303 RpcError::version_error(format!(
304 "Missing 'vgi_rpc.request_version' in request batch custom_metadata. Set it to {:?}.",
305 REQUEST_VERSION
306 ))
307 })?;
308 if version != REQUEST_VERSION {
309 return Err(RpcError::version_error(format!(
310 "Unsupported request version {:?}, expected {:?}.",
311 version, REQUEST_VERSION
312 )));
313 }
314 if require_method && !batch.schema().fields().is_empty() && batch.num_rows() != 1 {
315 return Err(RpcError::protocol_error(format!(
316 "Expected 1 row in request batch, got {}",
317 batch.num_rows()
318 )));
319 }
320 let request_id = md_get(&metadata, REQUEST_ID_KEY).unwrap_or("").to_string();
321 Ok(Request {
322 method,
323 request_id,
324 batch,
325 metadata,
326 })
327 }
328}
329
/// How a registered method exchanges data with the client.
///
/// `Unary` methods are served by a unary handler. The remaining variants are
/// all served by a stream handler.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MethodType {
    /// One request batch in, one response stream out.
    Unary,
    /// Streaming method driven by `produce` on a producer state.
    Producer,
    /// Streaming method driven by `exchange` on each input batch.
    Exchange,
    /// Streaming method whose kind is decided by the handler's returned state.
    Dynamic,
}
339
340pub type UnaryHandler =
342 Arc<dyn Fn(&Request, &CallContext) -> Result<Option<RecordBatch>> + Send + Sync>;
343
344pub type StreamHandler = Arc<dyn Fn(&Request, &CallContext) -> Result<StreamResult> + Send + Sync>;
346
347#[derive(Default)]
349pub struct RpcServerBuilder {
350 server_id: Option<String>,
351 server_version: Option<String>,
352 protocol_name: Option<String>,
353 protocol_version: Option<String>,
354 enable_describe: bool,
355 dispatch_hook: Option<Arc<dyn crate::hooks::DispatchHook>>,
356 on_serve_start: Option<crate::transport::ServeStartHook>,
357 #[cfg(feature = "http")]
358 external_config: Option<Arc<crate::external::ExternalLocationConfig>>,
359}
360
361impl RpcServerBuilder {
362 pub fn server_id(mut self, id: impl Into<String>) -> Self {
363 self.server_id = Some(id.into());
364 self
365 }
366
367 pub fn server_version(mut self, v: impl Into<String>) -> Self {
368 self.server_version = Some(v.into());
369 self
370 }
371
372 pub fn protocol_name(mut self, name: impl Into<String>) -> Self {
373 self.protocol_name = Some(name.into());
374 self
375 }
376
377 pub fn protocol_version(mut self, v: impl Into<String>) -> Self {
381 self.protocol_version = Some(v.into());
382 self
383 }
384
385 pub fn enable_describe(mut self, enabled: bool) -> Self {
386 self.enable_describe = enabled;
387 self
388 }
389
390 pub fn with_hook(mut self, hook: Arc<dyn crate::hooks::DispatchHook>) -> Self {
391 self.dispatch_hook = Some(hook);
392 self
393 }
394
395 pub fn on_serve_start(mut self, hook: crate::transport::ServeStartHook) -> Self {
405 self.on_serve_start = Some(hook);
406 self
407 }
408
409 #[cfg(feature = "http")]
413 pub fn with_external_location(mut self, cfg: crate::external::ExternalLocationConfig) -> Self {
414 self.external_config = Some(Arc::new(cfg));
415 self
416 }
417
418 pub fn build(self) -> RpcServer {
419 RpcServer {
420 methods: HashMap::new(),
421 server_id: self.server_id.unwrap_or_else(crate::util::short_random_id),
422 server_version: self.server_version.unwrap_or_default(),
423 protocol_name: self.protocol_name.unwrap_or_default(),
424 protocol_version: self.protocol_version.unwrap_or_default(),
425 protocol_hash: std::sync::OnceLock::new(),
426 describe_enabled: self.enable_describe,
427 dispatch_hook: self.dispatch_hook,
428 on_serve_start: self.on_serve_start,
429 transport_state: Mutex::new(None),
430 #[cfg(feature = "http")]
431 external_config: self.external_config,
432 }
433 }
434}
435
436pub struct MethodInfo {
443 pub name: String,
444 pub method_type: MethodType,
445 pub params_schema: SchemaRef,
447 pub result_schema: SchemaRef,
449 pub header_schema: Option<SchemaRef>,
451 pub doc: Option<String>,
453 pub param_types: Vec<(String, String)>,
456 pub param_defaults: Vec<(String, serde_json::Value)>,
458 pub param_docs: Vec<(String, String)>,
460 pub has_return: bool,
462 pub unary: Option<UnaryHandler>,
463 pub stream: Option<StreamHandler>,
464 pub state_decoder: Option<StateDecoder>,
470}
471
472pub type StateDecoder = Arc<dyn Fn(&[u8]) -> Result<crate::stream::StreamStateKind> + Send + Sync>;
475
476impl MethodInfo {
477 pub fn unary(
479 name: impl Into<String>,
480 params_schema: SchemaRef,
481 result_schema: SchemaRef,
482 handler: impl Fn(&Request, &CallContext) -> Result<Option<RecordBatch>> + Send + Sync + 'static,
483 ) -> Self {
484 let has_return = !result_schema.fields().is_empty();
485 Self {
486 name: name.into(),
487 method_type: MethodType::Unary,
488 params_schema,
489 result_schema,
490 header_schema: None,
491 doc: None,
492 param_types: Vec::new(),
493 param_defaults: Vec::new(),
494 param_docs: Vec::new(),
495 has_return,
496 unary: Some(Arc::new(handler)),
497 stream: None,
498 state_decoder: None,
499 }
500 }
501
502 pub fn stream(
510 name: impl Into<String>,
511 method_type: MethodType,
512 params_schema: SchemaRef,
513 handler: impl Fn(&Request, &CallContext) -> Result<StreamResult> + Send + Sync + 'static,
514 ) -> Self {
515 assert!(
516 matches!(
517 method_type,
518 MethodType::Producer | MethodType::Exchange | MethodType::Dynamic
519 ),
520 "stream methods must be Producer / Exchange / Dynamic"
521 );
522 Self {
523 name: name.into(),
524 method_type,
525 params_schema,
526 result_schema: empty_schema(),
527 header_schema: None,
528 doc: None,
529 param_types: Vec::new(),
530 param_defaults: Vec::new(),
531 param_docs: Vec::new(),
532 has_return: false,
533 unary: None,
534 stream: Some(Arc::new(handler)),
535 state_decoder: None,
536 }
537 }
538
539 pub fn with_state_decoder(mut self, decoder: StateDecoder) -> Self {
541 self.state_decoder = Some(decoder);
542 self
543 }
544
545 pub fn doc(mut self, s: impl Into<String>) -> Self {
546 self.doc = Some(s.into());
547 self
548 }
549
550 pub fn param_type(mut self, param: impl Into<String>, ty: impl Into<String>) -> Self {
551 self.param_types.push((param.into(), ty.into()));
552 self
553 }
554
555 pub fn param_default(mut self, param: impl Into<String>, value: serde_json::Value) -> Self {
556 self.param_defaults.push((param.into(), value));
557 self
558 }
559
560 pub fn param_doc(mut self, param: impl Into<String>, doc: impl Into<String>) -> Self {
561 self.param_docs.push((param.into(), doc.into()));
562 self
563 }
564
565 pub fn header_schema(mut self, schema: SchemaRef) -> Self {
566 self.header_schema = Some(schema);
567 self
568 }
569}
570
571pub struct RpcServer {
573 methods: HashMap<String, MethodInfo>,
574 pub server_id: String,
575 pub(crate) server_version: String,
576 pub(crate) protocol_name: String,
577 pub(crate) protocol_version: String,
578 pub(crate) protocol_hash: std::sync::OnceLock<String>,
579 pub(crate) describe_enabled: bool,
580 pub(crate) dispatch_hook: Option<Arc<dyn crate::hooks::DispatchHook>>,
581 on_serve_start: Option<crate::transport::ServeStartHook>,
584 transport_state: Mutex<
587 Option<(
588 crate::transport::TransportKind,
589 crate::transport::TransportCapabilities,
590 )>,
591 >,
592 #[cfg(feature = "http")]
593 pub(crate) external_config: Option<Arc<crate::external::ExternalLocationConfig>>,
594}
595
596impl RpcServer {
597 pub fn new(server_id: impl Into<String>) -> Self {
599 Self::builder().server_id(server_id).build()
600 }
601
602 pub fn builder() -> RpcServerBuilder {
604 RpcServerBuilder::default()
605 }
606
607 pub fn protocol_name(&self) -> &str {
608 &self.protocol_name
609 }
610
611 pub fn describe_enabled(&self) -> bool {
612 self.describe_enabled
613 }
614
615 pub fn server_version(&self) -> &str {
616 &self.server_version
617 }
618
619 pub fn protocol_version(&self) -> &str {
620 &self.protocol_version
621 }
622
623 pub fn protocol_hash(&self) -> &str {
626 self.protocol_hash.get_or_init(|| {
627 match crate::introspect::build_describe(
628 &self.protocol_name,
629 &self.methods,
630 &self.server_id,
631 &self.protocol_version,
632 ) {
633 Ok((_, md)) => md
634 .get(crate::metadata::PROTOCOL_HASH_KEY)
635 .cloned()
636 .unwrap_or_default(),
637 Err(_) => String::new(),
638 }
639 })
640 }
641
642 #[cfg(feature = "http")]
643 pub fn external_config(&self) -> Option<&Arc<crate::external::ExternalLocationConfig>> {
644 self.external_config.as_ref()
645 }
646
647 pub fn transport_kind(&self) -> Option<crate::transport::TransportKind> {
651 lock_ok(&self.transport_state).as_ref().map(|(k, _)| *k)
652 }
653
654 pub fn transport_capabilities(&self) -> crate::transport::TransportCapabilities {
658 lock_ok(&self.transport_state)
659 .as_ref()
660 .map(|(_, c)| *c)
661 .unwrap_or_default()
662 }
663
664 pub fn notify_transport(
676 &self,
677 kind: crate::transport::TransportKind,
678 caps: crate::transport::TransportCapabilities,
679 ) {
680 let hook = {
681 let mut guard = lock_ok(&self.transport_state);
682 if let Some((cur_kind, cur_caps)) = guard.as_ref() {
683 if *cur_kind == kind && *cur_caps == caps {
684 return;
685 }
686 }
687 *guard = Some((kind, caps));
688 self.on_serve_start.clone()
689 };
690 if let Some(h) = hook {
691 h(kind, &caps);
692 }
693 }
694
695 pub fn register(&mut self, info: MethodInfo) {
697 self.methods.insert(info.name.clone(), info);
698 }
699
700 pub fn register_unary(
704 &mut self,
705 name: impl Into<String>,
706 result_schema: SchemaRef,
707 handler: impl Fn(&Request, &CallContext) -> Result<Option<RecordBatch>> + Send + Sync + 'static,
708 ) {
709 self.register(MethodInfo::unary(
710 name,
711 empty_schema(),
712 result_schema,
713 handler,
714 ));
715 }
716
717 pub fn register_stream(
721 &mut self,
722 name: impl Into<String>,
723 method_type: MethodType,
724 handler: impl Fn(&Request, &CallContext) -> Result<StreamResult> + Send + Sync + 'static,
725 ) {
726 self.register(MethodInfo::stream(
727 name,
728 method_type,
729 empty_schema(),
730 handler,
731 ));
732 }
733
734 pub fn method(&self, name: &str) -> Option<&MethodInfo> {
735 self.methods.get(name)
736 }
737
738 pub fn methods(&self) -> &HashMap<String, MethodInfo> {
739 &self.methods
740 }
741
742 pub fn method_names(&self) -> Vec<&str> {
743 self.sorted_method_names()
744 }
745
746 pub fn sorted_method_names(&self) -> Vec<&str> {
749 let mut names: Vec<_> = self.methods.keys().map(String::as_str).collect();
750 names.sort();
751 names
752 }
753
754 pub fn serve<R: Read, W: Write>(&self, mut r: R, mut w: W) {
765 loop {
766 match self.serve_one(&mut r, &mut w) {
767 Ok(keep_going) => {
768 if !keep_going {
769 return;
770 }
771 }
772 Err(e) => {
773 tracing::warn!(
779 target: "vgi_rpc.server",
780 error = %e,
781 "serve loop terminating connection on error"
782 );
783 return;
784 }
785 }
786 }
787 }
788
789 pub fn serve_with_shutdown<R, W, F>(&self, mut r: R, mut w: W, shutdown: F)
795 where
796 R: Read,
797 W: Write,
798 F: Fn() -> bool,
799 {
800 loop {
801 if shutdown() {
802 return;
803 }
804 match self.serve_one(&mut r, &mut w) {
805 Ok(true) => {}
806 _ => return,
807 }
808 }
809 }
810
811 pub fn serve_one<R: Read, W: Write>(&self, r: &mut R, w: &mut W) -> Result<bool> {
813 let result = self._serve_one(r, w);
814 let _ = w.flush();
815 result
816 }
817
818 fn _serve_one<R: Read, W: Write>(&self, r: &mut R, w: &mut W) -> Result<bool> {
819 let req = match self.read_request(r)? {
820 Some(rq) => rq,
821 None => return Ok(false),
822 };
823
824 if req.method == crate::transport_options::TRANSPORT_OPTIONS_METHOD_NAME {
831 let mut md = crate::transport_options::worker_transport_metadata();
832 md.insert(REQUEST_VERSION_KEY.to_string(), REQUEST_VERSION.to_string());
833 md.insert(SERVER_ID_KEY.to_string(), self.server_id.clone());
834 let schema = empty_schema();
835 let batch = empty_batch(&schema)?;
836 let mut sw = StreamWriter::new(w, &schema)?;
837 sw.write(&batch, Some(&md))?;
838 sw.finish()?;
839 return Ok(true);
840 }
841
842 if !self.protocol_version.is_empty() {
846 if let Some(client_v) = md_get(&req.metadata, crate::metadata::PROTOCOL_VERSION_KEY) {
847 let major = |v: &str| v.split('.').next().unwrap_or("").to_string();
848 if major(client_v) != major(&self.protocol_version) {
849 let err = RpcError::version_error(format!(
850 "protocol_version mismatch: client {:?} is incompatible with server {:?}",
851 client_v, self.protocol_version
852 ));
853 write_error_stream(w, &empty_schema(), &err, &self.server_id, &req.request_id)?;
854 return Ok(true);
855 }
856 }
857 }
858
859 let ctx = CallContext::for_request(self, &req);
860
861 let stats = Arc::new(Mutex::new(crate::hooks::CallStatistics::default()));
862 {
864 let mut s = lock_ok(&stats);
865 s.input_batches = 1;
866 s.input_rows = req.batch.num_rows() as u64;
867 }
868
869 if self.describe_enabled && req.method == crate::introspect::DESCRIBE_METHOD_NAME {
871 match crate::introspect::build_describe(
872 &self.protocol_name,
873 &self.methods,
874 &self.server_id,
875 &self.protocol_version,
876 ) {
877 Ok((batch, md)) => {
878 crate::introspect::write_describe_response(w, &batch, &md)?;
879 }
880 Err(err) => {
881 write_error_stream(w, &empty_schema(), &err, &self.server_id, &req.request_id)?;
882 }
883 }
884 return Ok(true);
885 }
886
887 let Some(info) = self.methods.get(&req.method) else {
888 let names = self.sorted_method_names();
889 let msg = format!(
890 "Unknown method: '{}'. Available methods: {:?}",
891 req.method, names
892 );
893 write_error_stream(
894 w,
895 &empty_schema(),
896 &RpcError::attribute_error(msg),
897 &self.server_id,
898 &req.request_id,
899 )?;
900 return Ok(true);
901 };
902
903 let method_type = match info.method_type {
904 MethodType::Unary => "unary",
905 _ => "stream",
906 };
907 let mut dispatch_info =
908 crate::hooks::DispatchInfo::from_request(self, &req, method_type, &ctx.auth);
909 if let Ok(bytes) = serialize_request_batch(&req.batch) {
913 dispatch_info.request_data = bytes;
914 }
915 if method_type == "stream" {
916 dispatch_info.stream_id = crate::access_log::random_stream_id();
917 }
918 let hook_token = self
919 .dispatch_hook
920 .as_ref()
921 .map(|h| h.on_dispatch_start(&dispatch_info));
922
923 let mut app_err: Option<RpcError> = None;
924 let shm = maybe_attach_shm(&req.metadata);
925 let shm_ref = shm.as_ref();
926 match info.method_type {
927 MethodType::Unary => {
928 self.serve_unary(w, &req, info, &ctx, &stats, &mut app_err, shm_ref)?
929 }
930 MethodType::Producer | MethodType::Exchange | MethodType::Dynamic => {
931 self.serve_stream(r, w, &req, info, &ctx, &stats, &mut app_err, shm_ref)?
932 }
933 }
934 let _ = shm;
937
938 if let Some(hook) = self.dispatch_hook.as_ref() {
939 let token = hook_token.unwrap_or(0);
940 let final_stats = lock_ok(&stats).clone();
941 hook.on_dispatch_end(token, &dispatch_info, app_err.as_ref(), &final_stats);
942 }
943 Ok(true)
944 }
945
946 fn read_request<R: Read>(&self, r: &mut R) -> Result<Option<Request>> {
947 let mut reader = match StreamReader::new(r) {
948 Ok(r) => r,
949 Err(e) => {
950 let msg = e.message.to_lowercase();
952 if msg.contains("empty ipc stream") || msg.contains("eof") {
953 return Ok(None);
954 }
955 return Err(e);
956 }
957 };
958 let (batch, metadata) = match reader.read_next()? {
959 Some(b) => b,
960 None => return Ok(None),
961 };
962 reader.drain()?;
963 Ok(Some(Request::from_read_batch(batch, metadata, true)?))
964 }
965
966 #[allow(clippy::too_many_arguments)]
967 fn serve_unary<W: Write>(
968 &self,
969 w: &mut W,
970 req: &Request,
971 info: &MethodInfo,
972 ctx: &CallContext,
973 stats: &Arc<Mutex<crate::hooks::CallStatistics>>,
974 app_err: &mut Option<RpcError>,
975 #[cfg_attr(not(feature = "shm"), allow(unused_variables))] shm: Option<&ShmSegment>,
976 ) -> Result<()> {
977 let result = call_guard(|| (info.unary.as_ref().unwrap())(req, ctx)).and_then(|r| r);
981 let logs = ctx.drain_logs();
982 match result {
983 Ok(maybe_batch) => {
984 let mut sw = StreamWriter::new(w, &info.result_schema)?;
985 for log in logs {
986 let md = build_log_metadata(&log, &self.server_id, &req.request_id);
987 sw.write(&empty_batch(&info.result_schema)?, Some(&md))?;
988 }
989 let out_batch = match maybe_batch {
990 Some(b) => b,
991 None => empty_batch(&info.result_schema)?,
992 };
993 {
994 let mut s = lock_ok(stats);
995 s.output_batches = 1;
996 s.output_rows = out_batch.num_rows() as u64;
997 }
998 #[cfg(feature = "shm")]
999 if let Some(seg) = shm {
1000 let (written, written_md) =
1001 maybe_write_to_shm(out_batch.clone(), Metadata::new(), Some(seg))?;
1002 if written_md.contains_key(crate::metadata::SHM_OFFSET_KEY) {
1003 sw.write(&written, Some(&written_md))?;
1004 sw.finish()?;
1005 return Ok(());
1006 }
1007 }
1008 #[cfg(feature = "http")]
1009 if let Some(cfg) = self.external_config.as_ref() {
1010 if let Ok(Some((ptr, md))) =
1011 crate::external::maybe_externalize_batch(&out_batch, None, cfg)
1012 {
1013 sw.write(&ptr, Some(&md))?;
1014 sw.finish()?;
1015 return Ok(());
1016 }
1017 }
1018 #[cfg(not(feature = "shm"))]
1019 let _ = shm;
1020 sw.write(&out_batch, None)?;
1021 sw.finish()?;
1022 }
1023 Err(err) => {
1024 let mut sw = StreamWriter::new(w, &info.result_schema)?;
1025 for log in logs {
1026 let md = build_log_metadata(&log, &self.server_id, &req.request_id);
1027 sw.write(&empty_batch(&info.result_schema)?, Some(&md))?;
1028 }
1029 let md = build_error_metadata(&err, &self.server_id, &req.request_id);
1030 sw.write(&empty_batch(&info.result_schema)?, Some(&md))?;
1031 sw.finish()?;
1032 *app_err = Some(err);
1033 }
1034 }
1035 Ok(())
1036 }
1037
1038 #[allow(clippy::too_many_arguments)]
1039 #[allow(clippy::too_many_arguments)]
1040 fn serve_stream<R: Read, W: Write>(
1041 &self,
1042 r: &mut R,
1043 w: &mut W,
1044 req: &Request,
1045 info: &MethodInfo,
1046 ctx: &CallContext,
1047 stats: &Arc<Mutex<crate::hooks::CallStatistics>>,
1048 app_err: &mut Option<RpcError>,
1049 #[cfg_attr(not(feature = "shm"), allow(unused_variables))] shm: Option<&ShmSegment>,
1050 ) -> Result<()> {
1051 let init_result = call_guard(|| (info.stream.as_ref().unwrap())(req, ctx)).and_then(|r| r);
1052 let init_logs = ctx.drain_logs();
1053 let stream = match init_result {
1054 Ok(s) => s,
1055 Err(err) => {
1056 let output_schema = info.result_schema.clone();
1058 let mut sw = StreamWriter::new(w, &output_schema)?;
1059 for log in init_logs {
1060 let md = build_log_metadata(&log, &self.server_id, &req.request_id);
1061 sw.write(&empty_batch(&output_schema)?, Some(&md))?;
1062 }
1063 let md = build_error_metadata(&err, &self.server_id, &req.request_id);
1064 sw.write(&empty_batch(&output_schema)?, Some(&md))?;
1065 sw.finish()?;
1066 let _ = drain_input(r);
1069 *app_err = Some(err);
1070 return Ok(());
1071 }
1072 };
1073
1074 let StreamResult {
1075 output_schema,
1076 input_schema,
1077 state,
1078 header,
1079 header_metadata,
1080 } = stream;
1081
1082 let wrote_header = header.is_some();
1084 if let Some(header_batch) = header {
1085 let mut hw = StreamWriter::new(&mut *w, header_batch.schema().as_ref())?;
1086 for log in &init_logs {
1087 let md = build_log_metadata(log, &self.server_id, &req.request_id);
1088 hw.write(&empty_batch(header_batch.schema().as_ref())?, Some(&md))?;
1089 }
1090 hw.write(&header_batch, header_metadata.as_ref())?;
1091 hw.finish()?;
1092 }
1093 let _ = w.flush();
1094
1095 let mut out_writer = StreamWriter::new(&mut *w, output_schema.as_ref())?;
1099 out_writer.flush()?;
1100
1101 let mut input_reader = StreamReader::new(&mut *r)?;
1103
1104 if !wrote_header {
1106 for log in &init_logs {
1107 let md = build_log_metadata(log, &self.server_id, &req.request_id);
1108 out_writer.write(&empty_batch(output_schema.as_ref())?, Some(&md))?;
1109 }
1110 }
1111 let _ = header_metadata;
1112
1113 let mut state = state;
1114 let mut cancelled = false;
1115
1116 'lockstep: loop {
1117 let read = match input_reader.read_next() {
1118 Ok(x) => x,
1119 Err(_) => break,
1120 };
1121 let Some((input_batch, input_md)) = read else {
1122 break;
1123 };
1124
1125 #[cfg(feature = "shm")]
1130 let (input_batch, input_md) = {
1131 let resolved = resolve_shm_batch(input_batch, input_md, shm)?;
1132 if let (Some(off), Some(seg)) = (resolved.release_offset, shm) {
1133 let _ = seg.free(off);
1134 }
1135 (resolved.batch, resolved.metadata)
1136 };
1137
1138 {
1139 let mut s = lock_ok(stats);
1140 s.input_batches += 1;
1141 s.input_rows += input_batch.num_rows() as u64;
1142 }
1143
1144 *lock_ok(&ctx.tick_metadata) = input_md.clone();
1147
1148 if md_get(&input_md, CANCEL_KEY).is_some() {
1150 cancelled = true;
1151 match &mut state {
1152 StreamStateKind::Producer(p) => p.on_cancel(ctx),
1153 StreamStateKind::Exchange(e) => e.on_cancel(ctx),
1154 }
1155 break;
1156 }
1157
1158 let casted = match &input_schema {
1160 Some(expected) if input_batch.schema() != *expected => {
1161 match cast_batch(&input_batch, expected) {
1162 Ok(b) => b,
1163 Err(e) => {
1164 let md = build_error_metadata(&e, &self.server_id, &req.request_id);
1165 out_writer.write(&empty_batch(output_schema.as_ref())?, Some(&md))?;
1166 break 'lockstep;
1167 }
1168 }
1169 }
1170 _ => input_batch,
1171 };
1172
1173 let mut out = OutputCollector::new(output_schema.clone(), input_schema.is_none());
1174
1175 let iter_result = call_guard(|| match &mut state {
1176 StreamStateKind::Producer(p) => p.produce(&mut out, ctx),
1177 StreamStateKind::Exchange(e) => e.exchange(&casted, &mut out, ctx),
1178 })
1179 .and_then(|r| r);
1180
1181 let iter_logs = ctx.drain_logs();
1183 for log in iter_logs {
1184 let md = build_log_metadata(&log, &self.server_id, &req.request_id);
1185 out_writer.write(&empty_batch(output_schema.as_ref())?, Some(&md))?;
1186 }
1187
1188 if let Err(err) = iter_result {
1189 let md = build_error_metadata(&err, &self.server_id, &req.request_id);
1190 out_writer.write(&empty_batch(output_schema.as_ref())?, Some(&md))?;
1191 *app_err = Some(err);
1192 break;
1193 }
1194
1195 let finished = out.finished();
1196
1197 for item in out.items.drain(..) {
1199 match item {
1200 Emitted::Log(log) => {
1201 let md = build_log_metadata(&log, &self.server_id, &req.request_id);
1202 out_writer.write(&empty_batch(output_schema.as_ref())?, Some(&md))?;
1203 }
1204 Emitted::Batch { batch, metadata } => {
1205 {
1206 let mut s = lock_ok(stats);
1207 s.output_batches += 1;
1208 s.output_rows += batch.num_rows() as u64;
1209 }
1210 #[cfg(feature = "shm")]
1211 if let Some(seg) = shm {
1212 let md_in = metadata.clone().unwrap_or_default();
1213 let (written, written_md) =
1214 maybe_write_to_shm(batch.clone(), md_in, Some(seg))?;
1215 if written_md.contains_key(crate::metadata::SHM_OFFSET_KEY) {
1216 out_writer.write(&written, Some(&written_md))?;
1217 continue;
1218 }
1219 }
1220 #[cfg(feature = "http")]
1221 if let Some(cfg) = self.external_config.as_ref() {
1222 match crate::external::maybe_externalize_batch(
1223 &batch,
1224 metadata.as_ref(),
1225 cfg,
1226 ) {
1227 Ok(Some((ptr, md))) => {
1228 out_writer.write(&ptr, Some(&md))?;
1229 continue;
1230 }
1231 Ok(None) => {}
1232 Err(e) => {
1233 *app_err = Some(e);
1236 }
1237 }
1238 }
1239 out_writer.write(&batch, metadata.as_ref())?;
1240 }
1241 }
1242 }
1243 out_writer.flush()?;
1246
1247 if finished {
1248 break;
1249 }
1250 }
1251 let _ = cancelled;
1252 out_writer.finish()?;
1253
1254 let _ = input_reader.drain();
1256 Ok(())
1257 }
1258}
1259
1260fn drain_input<R: Read>(r: &mut R) -> Result<()> {
1261 let mut rdr = StreamReader::new(r)?;
1262 rdr.drain()?;
1263 Ok(())
1264}
1265
1266pub(crate) fn cast_batch(batch: &RecordBatch, target: &Schema) -> Result<RecordBatch> {
1267 if batch.num_columns() != target.fields().len() {
1268 return Err(RpcError::type_error(format!(
1269 "Input schema mismatch: expected {} fields, got {}",
1270 target.fields().len(),
1271 batch.num_columns()
1272 )));
1273 }
1274 let src_schema = batch.schema();
1275 for (i, field) in target.fields().iter().enumerate() {
1276 let src_name = src_schema.field(i).name();
1277 if src_name != field.name() {
1278 return Err(RpcError::type_error(format!(
1279 "Input schema mismatch: expected field {:?}, got {:?}",
1280 field.name(),
1281 src_name
1282 )));
1283 }
1284 }
1285 let opts = arrow_cast::CastOptions::default();
1286 let mut cols = Vec::with_capacity(batch.num_columns());
1287 for (i, field) in target.fields().iter().enumerate() {
1288 let src = batch.column(i);
1289 if src.data_type() == field.data_type() {
1290 cols.push(src.clone());
1291 continue;
1292 }
1293 let c = cast_with_options(src.as_ref(), field.data_type(), &opts)
1294 .map_err(|e| RpcError::type_error(format!("cast field {}: {}", field.name(), e)))?;
1295 cols.push(c);
1296 }
1297 RecordBatch::try_new(Arc::new(target.clone()), cols).map_err(RpcError::from)
1298}
1299
1300pub(crate) fn build_log_metadata(msg: &LogMessage, server_id: &str, request_id: &str) -> Metadata {
1301 let mut md = Metadata::new();
1302 md.insert(LOG_LEVEL_KEY.to_string(), msg.level.as_str().to_string());
1303 md.insert(LOG_MESSAGE_KEY.to_string(), msg.message.clone());
1304 if !msg.extras.is_empty() {
1305 md.insert(LOG_EXTRA_KEY.to_string(), msg.extras_json());
1306 }
1307 if !server_id.is_empty() {
1308 md.insert(SERVER_ID_KEY.to_string(), server_id.to_string());
1309 }
1310 if !request_id.is_empty() {
1311 md.insert(REQUEST_ID_KEY.to_string(), request_id.to_string());
1312 }
1313 md
1314}
1315
1316pub(crate) fn build_error_metadata(err: &RpcError, server_id: &str, request_id: &str) -> Metadata {
1317 let extra = serde_json::json!({
1318 "exception_type": err.error_type,
1319 "exception_message": err.message,
1320 "traceback": err.traceback,
1321 })
1322 .to_string();
1323 let mut md = Metadata::new();
1324 md.insert(LOG_LEVEL_KEY.to_string(), "EXCEPTION".to_string());
1325 md.insert(LOG_MESSAGE_KEY.to_string(), err.message.clone());
1326 md.insert(LOG_EXTRA_KEY.to_string(), extra);
1327 if !server_id.is_empty() {
1328 md.insert(SERVER_ID_KEY.to_string(), server_id.to_string());
1329 }
1330 if !request_id.is_empty() {
1331 md.insert(REQUEST_ID_KEY.to_string(), request_id.to_string());
1332 }
1333 md
1334}
1335
1336pub(crate) fn write_error_stream<W: Write>(
1338 w: &mut W,
1339 schema: &Schema,
1340 err: &RpcError,
1341 server_id: &str,
1342 request_id: &str,
1343) -> Result<()> {
1344 let mut sw = StreamWriter::new(w, schema)?;
1345 let md = build_error_metadata(err, server_id, request_id);
1346 sw.write(&empty_batch(schema)?, Some(&md))?;
1347 sw.finish()?;
1348 Ok(())
1349}
1350
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Encode one complete request stream that invokes `method`.
    ///
    /// The stream carries an empty parameter batch and request id `req-<method>`.
    fn request_bytes(method: &str) -> Vec<u8> {
        let schema = empty_schema();
        let batch = empty_batch(&schema).unwrap();

        let mut md = Metadata::new();
        md.insert(RPC_METHOD_KEY.into(), method.into());
        md.insert(REQUEST_VERSION_KEY.into(), REQUEST_VERSION.into());
        md.insert(REQUEST_ID_KEY.into(), format!("req-{method}"));

        let mut encoded = Vec::new();
        {
            let mut writer = StreamWriter::new(&mut encoded, &schema).unwrap();
            writer.write(&batch, Some(&md)).unwrap();
            writer.finish().unwrap();
        }
        encoded
    }

    #[test]
    fn panicking_handler_yields_error_envelope_and_loop_survives() {
        let ran_second = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&ran_second);

        let mut server = RpcServer::new("test-srv");
        server.register(MethodInfo::unary(
            "boom",
            empty_schema(),
            empty_schema(),
            |_req, _ctx| panic!("handler exploded"),
        ));
        server.register(MethodInfo::unary(
            "ok",
            empty_schema(),
            empty_schema(),
            move |_req, _ctx| {
                flag.store(true, Ordering::SeqCst);
                Ok(None)
            },
        ));

        // Two requests on one connection: the panic must not end the loop.
        let input: Vec<u8> = [request_bytes("boom"), request_bytes("ok")].concat();
        let mut output: Vec<u8> = Vec::new();
        server.serve(Cursor::new(input), &mut output);

        assert!(
            ran_second.load(Ordering::SeqCst),
            "serve loop aborted after a handler panic"
        );

        // The first response stream must open with the error envelope.
        let mut reader = StreamReader::new(output.as_slice()).unwrap();
        let (_batch, md) = reader.read_next().unwrap().expect("error batch");
        assert_eq!(md_get(&md, LOG_LEVEL_KEY), Some("EXCEPTION"));
    }

    #[test]
    fn transport_options_reports_shm_capability_unregistered() {
        use crate::metadata::TRANSPORT_SHM_KEY;
        use crate::transport_options::{shm_available, TRANSPORT_OPTIONS_METHOD_NAME};

        let mut server = RpcServer::new("test-srv");
        server.register(MethodInfo::unary(
            "noop",
            empty_schema(),
            empty_schema(),
            |_req, _ctx| Ok(None),
        ));
        // The probe is answered without ever being registered.
        assert!(!server.methods.contains_key(TRANSPORT_OPTIONS_METHOD_NAME));

        let mut output: Vec<u8> = Vec::new();
        server.serve(
            Cursor::new(request_bytes(TRANSPORT_OPTIONS_METHOD_NAME)),
            &mut output,
        );

        let mut reader = StreamReader::new(output.as_slice()).unwrap();
        let (_batch, md) = reader
            .read_next()
            .unwrap()
            .expect("transport options batch");
        let expected = if shm_available() { "true" } else { "false" };
        assert_eq!(md_get(&md, TRANSPORT_SHM_KEY), Some(expected));
        assert_eq!(md_get(&md, REQUEST_VERSION_KEY), Some(REQUEST_VERSION));
        assert_eq!(md_get(&md, SERVER_ID_KEY), Some("test-srv"));
    }
}