1#![allow(clippy::type_complexity)]
2use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::{Arc, Mutex};
12
13use crate::encoding::STANDARD as B64;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use sha2::{Digest, Sha256};
17use tokio::sync::{mpsc, oneshot};
18
19use crate::session::SessionFrame;
20
/// Machine-readable error category carried inside [`RpcError`].
///
/// Serialized in `snake_case` (e.g. `InvalidArgument` as `"invalid_argument"`),
/// which matches the strings returned by [`RpcErrorCode::as_str`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RpcErrorCode {
    InvalidArgument,
    Unauthenticated,
    PermissionDenied,
    NotFound,
    Internal,
}
30
31impl RpcErrorCode {
32 pub fn as_str(&self) -> &'static str {
33 match self {
34 RpcErrorCode::InvalidArgument => "invalid_argument",
35 RpcErrorCode::Unauthenticated => "unauthenticated",
36 RpcErrorCode::PermissionDenied => "permission_denied",
37 RpcErrorCode::NotFound => "not_found",
38 RpcErrorCode::Internal => "internal",
39 }
40 }
41}
42
/// Error payload as it travels on the wire (inside `RpcResponse` /
/// `RpcStream` / `RpcClientStream` frames and in stream items).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcError {
    pub code: RpcErrorCode,
    // Human-readable description of the failure.
    pub message: String,
}
48
/// Error returned to callers of the `RpcClient` call methods that resolve to
/// a single result (`call_raw`, `client_stream_raw`, `bulk_transfer_raw`,
/// `telemetry_raw`).
///
/// Its `Display` output is `"{code:?}: {message}"`. It uses the `Debug` form of
/// the code (e.g. `NotFound`), not the wire name from `as_str` (`not_found`).
#[derive(Debug, thiserror::Error)]
#[error("{code:?}: {message}")]
pub struct RpcCallError {
    pub code: RpcErrorCode,
    pub message: String,
}
55
56impl From<RpcError> for RpcCallError {
57 fn from(e: RpcError) -> Self {
58 RpcCallError {
59 code: e.code,
60 message: e.message,
61 }
62 }
63}
64
/// Interaction pattern of an RPC method, sent in `RpcFrameExt::method_kind`.
///
/// Serialized in kebab-case (e.g. `ServerStreaming` as `"server-streaming"`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "kebab-case")]
pub enum RpcMethodKind {
    Unary,
    ServerStreaming,
    ClientStreaming,
    BidiStreaming,
    Subscribe,
    CommandChannel,
    BulkTransfer,
    Telemetry,
    RemoteShell,
    AgentSession,
    HttpBridge,
}
85
/// One message of an HTTP-bridge exchange (see `HttpBridgeHandler`).
///
/// Internally tagged: the variant name is written to a `"kind"` field in
/// kebab-case (`"request-headers"`, `"response-headers"`, `"body-chunk"`,
/// `"trailers"`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "kind", rename_all = "kebab-case")]
pub enum HttpFrame {
    RequestHeaders {
        method: String,
        path: String,
        headers: HashMap<String, String>,
    },
    ResponseHeaders {
        status: u16,
        headers: HashMap<String, String>,
    },
    BodyChunk {
        // NOTE(review): presumably an encoded (e.g. base64) body slice, like
        // other binary payloads in this module — confirm.
        data: String,
    },
    Trailers {
        headers: HashMap<String, String>,
    },
}
105
/// Which standard stream a remote-shell chunk belongs to.
///
/// Serialized lowercase (`"stdin"`, `"stdout"`, `"stderr"`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum RemoteShellStream {
    Stdin,
    Stdout,
    Stderr,
}
113
/// Priority level attached to telemetry streams.
///
/// No serde rename is applied, so the variants serialize as `"P0"` ... `"P5"`.
/// NOTE(review): the ordering semantics (whether `P0` is highest) are not
/// established here — confirm.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum StreamingPriority {
    P0,
    P1,
    P2,
    P3,
    P4,
    P5,
}
123
/// Optional per-frame extension fields. Each field is omitted when `None`
/// and defaults to `None` when missing on input.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcFrameExt {
    // Interaction pattern of the call this frame belongs to.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub method_kind: Option<RpcMethodKind>,
    // Priority for telemetry calls (set by `RpcClient::telemetry_raw`).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub streaming_priority: Option<StreamingPriority>,
    // Topic for subscribe/bidi calls.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub subscribe_topic: Option<String>,
    // NOTE(review): presumably flow-control credit (cf. the command-channel
    // `initial_credit`) — confirm.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub credit: Option<u32>,
    // Bulk-transfer bookkeeping (chunk index / totals / hash).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub bulk: Option<RpcBulkExt>,
    // Stream tag for remote-shell chunks.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub shell_stream: Option<RemoteShellStream>,
    // Responsibility chain for agent-session frames.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub responsibility_chain: Option<Vec<String>>,
    // NOTE(review): meaning not visible in this file — confirm.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub ack: Option<String>,
}
145
/// Bulk-transfer metadata carried in `RpcFrameExt::bulk`.
///
/// `RpcClient::bulk_transfer_raw` sends `total_chunks` and `expected_hash` on
/// the initial call and `chunk_index` on each chunk frame.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct RpcBulkExt {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub chunk_index: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub total_chunks: Option<u32>,
    // Hash of all chunks concatenated, formatted as `sha256:<lowercase hex>`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub expected_hash: Option<String>,
}
157
/// Envelope for every RPC message, carried as the JSON payload of a
/// `SessionFrame::Data` (see `encode_rpc` / `decode_rpc`).
///
/// Internally tagged: the variant name is written to a `"kind"` field in
/// kebab-case (`"rpc-call"`, `"rpc-response"`, `"rpc-stream"`,
/// `"rpc-client-stream"`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "kind", rename_all = "kebab-case")]
pub enum RpcFrame {
    /// Starts a call; `ext` carries the method kind and per-kind options.
    RpcCall {
        call_id: String,
        method: String,
        request: Value,
        #[serde(skip_serializing_if = "Option::is_none", default)]
        ext: Option<RpcFrameExt>,
    },
    /// Final outcome of a call: `response` accompanies `Ok`, `error` accompanies `Error`.
    RpcResponse {
        call_id: String,
        status: ResponseStatus,
        #[serde(skip_serializing_if = "Option::is_none")]
        response: Option<Value>,
        #[serde(skip_serializing_if = "Option::is_none")]
        error: Option<RpcError>,
        #[serde(skip_serializing_if = "Option::is_none", default)]
        ext: Option<RpcFrameExt>,
    },
    /// One item of a server-to-client stream; `more == false` ends it.
    /// `RpcClient` expects `seq` to count up from 0 and does not deliver
    /// frames whose `seq` is negative.
    RpcStream {
        call_id: String,
        seq: i64,
        more: bool,
        #[serde(skip_serializing_if = "Option::is_none")]
        value: Option<Value>,
        #[serde(skip_serializing_if = "Option::is_none")]
        error: Option<RpcError>,
        #[serde(skip_serializing_if = "Option::is_none", default)]
        ext: Option<RpcFrameExt>,
    },
    /// One item of a client-to-server stream; `more == false` ends it.
    /// The `RpcClient` pumps number these from 0.
    RpcClientStream {
        call_id: String,
        seq: u64,
        more: bool,
        #[serde(skip_serializing_if = "Option::is_none")]
        value: Option<Value>,
        #[serde(skip_serializing_if = "Option::is_none")]
        error: Option<RpcError>,
        #[serde(skip_serializing_if = "Option::is_none", default)]
        ext: Option<RpcFrameExt>,
    },
}
205
/// Outcome flag of an `RpcResponse`; serialized as `"ok"` / `"error"`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
    Ok,
    Error,
}
212
/// Serializable record describing one RPC invocation and its result.
///
/// NOTE(review): presumably an audit/proof event emitted by the server — the
/// producer is not visible in this chunk; confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RpcProofEventStub {
    // Serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub kind: String,
    pub method: String,
    pub call_id: String,
    pub caller: String,
    pub result: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_code: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub method_kind: Option<RpcMethodKind>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub streaming_priority: Option<StreamingPriority>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub bulk_hash_verified: Option<bool>,
}
232
/// Policy hook deciding whether `caller` may invoke `method`, which requires
/// `capability` (the capability string a handler is registered with).
pub trait CapabilityEnforcer: Send + Sync {
    fn check(&self, caller: &str, method: &str, capability: &str) -> CapabilityDecision;
}
236
/// Result of a [`CapabilityEnforcer::check`]; `Deny` carries the reason.
pub enum CapabilityDecision {
    Allow,
    Deny(String),
}
241
242pub struct AllowAllEnforcer;
243
244impl CapabilityEnforcer for AllowAllEnforcer {
245 fn check(&self, _: &str, _: &str, _: &str) -> CapabilityDecision {
246 CapabilityDecision::Allow
247 }
248}
249
250pub struct DenyAllEnforcer;
251
252impl CapabilityEnforcer for DenyAllEnforcer {
253 fn check(&self, _: &str, _: &str, _: &str) -> CapabilityDecision {
254 CapabilityDecision::Deny("capability enforcement denied all".into())
255 }
256}
257
/// Session-level transport the RPC client and server run over.
pub trait RpcTransport: Send + Sync {
    // Sends one frame. Returns nothing, so delivery failures are not
    // reported to the RPC layer.
    fn send(&self, frame: SessionFrame);
    // Registers a callback for incoming frames. `RpcClient::new` and
    // `RpcServer::new` each register one. NOTE(review): a transport shared by
    // both presumably has to keep every listener — confirm.
    fn on_frame(&self, listener: Arc<dyn Fn(SessionFrame) + Send + Sync>);
}
267
268pub fn new_call_id() -> String {
269 use rand::RngCore;
270 let mut bytes = [0u8; 16];
271 rand::thread_rng().fill_bytes(&mut bytes);
272 B64.encode(bytes)
273}
274
275fn encode_rpc(frame: RpcFrame) -> SessionFrame {
276 let payload = serde_json::to_value(frame).expect("serialize rpc frame");
277 SessionFrame::Data { payload }
278}
279
280fn decode_rpc(frame: SessionFrame) -> Option<RpcFrame> {
281 match frame {
282 SessionFrame::Data { payload } => serde_json::from_value(payload).ok(),
283 _ => None,
284 }
285}
286
287fn sha256_of_chunks(chunks: &[Vec<u8>]) -> String {
290 let mut hasher = Sha256::new();
291 for c in chunks {
292 hasher.update(c);
293 }
294 let digest = hasher.finalize();
295 let mut hex = String::with_capacity(7 + digest.len() * 2);
296 hex.push_str("sha256:");
297 for b in digest.iter() {
298 use std::fmt::Write;
299 let _ = write!(hex, "{:02x}", b);
300 }
301 hex
302}
303
304fn decode_bulk_chunk(v: &Value) -> Vec<u8> {
308 match v {
309 Value::String(s) => B64.decode(s.as_bytes()).unwrap_or_default(),
310 Value::Array(arr) => arr
311 .iter()
312 .filter_map(|n| n.as_u64().map(|x| x as u8))
313 .collect(),
314 _ => Vec::new(),
315 }
316}
317
/// Completion channel for calls that finish with a single `RpcResponse`.
type UnaryResp = oneshot::Sender<Result<Value, RpcError>>;
321
/// Client-side state of one in-flight call, keyed by call id in
/// `RpcClient::pending`.
enum Pending {
    /// Call completed by a single `RpcResponse`. Used by `call_raw`,
    /// `client_stream_raw`, `bulk_transfer_raw` and `telemetry_raw`.
    Unary(UnaryResp),
    /// Generic server stream. Used by `server_stream_raw`, `subscribe_raw` and
    /// `bidi_with_kind` (bidi and command-channel calls).
    Stream {
        tx: mpsc::UnboundedSender<Result<Value, RpcError>>,
        // Next `seq` expected from the server.
        next_seq: u64,
        // Latest `shell_stream` / `responsibility_chain` seen in a frame's ext.
        // Recorded by the listener but not read by the code visible here.
        last_shell_stream: Option<RemoteShellStream>,
        last_chain: Option<Vec<String>>,
    },
    /// Remote-shell output stream (`remote_shell_raw`).
    RemoteShellStream {
        tx: mpsc::UnboundedSender<Result<RemoteShellOut, RpcError>>,
        next_seq: u64,
        // Tag applied to frames whose ext omits `shell_stream`.
        last_stream: RemoteShellStream,
    },
    /// Agent-session stream (`agent_session_raw`).
    AgentSessionStream {
        tx: mpsc::UnboundedSender<Result<AgentSessionFrame, RpcError>>,
        next_seq: u64,
        // Chain applied to frames whose ext omits `responsibility_chain`.
        last_chain: Vec<String>,
    },
}
345
/// Client side of the RPC layer; multiplexes calls over one transport by call id.
pub struct RpcClient<T: RpcTransport + 'static> {
    transport: Arc<T>,
    // Calls awaiting frames, keyed by call id. Shared with the listener that
    // `RpcClient::new` registers on the transport.
    pending: Arc<Mutex<HashMap<String, Pending>>>,
    // Identity string exposed through `caller_actor()`.
    caller_actor: String,
}
351
352impl<T: RpcTransport + 'static> RpcClient<T> {
353 pub fn new(transport: Arc<T>, caller_actor: impl Into<String>) -> Self {
354 let pending: Arc<Mutex<HashMap<String, Pending>>> = Arc::new(Mutex::new(HashMap::new()));
355 let pending_for_listener = pending.clone();
356 transport.on_frame(Arc::new(move |frame| {
357 let rpc = match decode_rpc(frame) {
358 Some(r) => r,
359 None => return,
360 };
361 match rpc {
362 RpcFrame::RpcResponse {
363 call_id,
364 status,
365 response,
366 error,
367 ext: _,
368 } => {
369 let mut map = pending_for_listener.lock().unwrap();
370 if let Some(Pending::Unary(tx)) = map.remove(&call_id) {
371 match status {
372 ResponseStatus::Ok => {
373 let _ = tx.send(Ok(response.unwrap_or(Value::Null)));
374 }
375 ResponseStatus::Error => {
376 let _ = tx.send(Err(error.unwrap_or(RpcError {
377 code: RpcErrorCode::Internal,
378 message: "(no error body)".into(),
379 })));
380 }
381 }
382 }
383 }
384 RpcFrame::RpcStream {
385 call_id,
386 seq,
387 more,
388 value,
389 error,
390 ext,
391 } => {
392 let mut map = pending_for_listener.lock().unwrap();
393 let Some(entry) = map.get_mut(&call_id) else {
394 return;
395 };
396 match entry {
397 Pending::Stream {
398 tx,
399 next_seq,
400 last_shell_stream,
401 last_chain,
402 } => {
403 if seq < 0 {
407 if !more {
408 map.remove(&call_id);
410 }
411 return;
412 }
413 let seq_u = seq as u64;
414 if seq_u != *next_seq {
415 let _ = tx.send(Err(RpcError {
416 code: RpcErrorCode::Internal,
417 message: format!(
418 "stream seq mismatch: expected {}, got {}",
419 next_seq, seq_u
420 ),
421 }));
422 map.remove(&call_id);
423 return;
424 }
425 *next_seq += 1;
426 if let Some(e) = &ext {
428 if let Some(s) = &e.shell_stream {
429 *last_shell_stream = Some(s.clone());
430 }
431 if let Some(c) = &e.responsibility_chain {
432 *last_chain = Some(c.clone());
433 }
434 }
435 if more {
436 if let Some(v) = value {
437 let _ = tx.send(Ok(v));
438 }
439 } else if let Some(err) = error {
440 let _ = tx.send(Err(err));
441 map.remove(&call_id);
442 } else {
443 map.remove(&call_id);
444 }
445 }
446 Pending::RemoteShellStream {
447 tx,
448 next_seq,
449 last_stream,
450 } => {
451 if seq < 0 {
452 if !more {
453 map.remove(&call_id);
454 }
455 return;
456 }
457 let seq_u = seq as u64;
458 if seq_u != *next_seq {
459 let _ = tx.send(Err(RpcError {
460 code: RpcErrorCode::Internal,
461 message: format!(
462 "stream seq mismatch: expected {}, got {}",
463 next_seq, seq_u
464 ),
465 }));
466 map.remove(&call_id);
467 return;
468 }
469 *next_seq += 1;
470 let stream_tag = ext
471 .as_ref()
472 .and_then(|e| e.shell_stream.clone())
473 .unwrap_or_else(|| last_stream.clone());
474 *last_stream = stream_tag.clone();
475 if more {
476 if let Some(v) = value {
477 let bytes = decode_bulk_chunk(&v);
478 let _ = tx.send(Ok(RemoteShellOut {
479 stream: stream_tag,
480 data: bytes,
481 }));
482 }
483 } else if let Some(err) = error {
484 let _ = tx.send(Err(err));
485 map.remove(&call_id);
486 } else {
487 map.remove(&call_id);
488 }
489 }
490 Pending::AgentSessionStream {
491 tx,
492 next_seq,
493 last_chain,
494 } => {
495 if seq < 0 {
496 if !more {
497 map.remove(&call_id);
498 }
499 return;
500 }
501 let seq_u = seq as u64;
502 if seq_u != *next_seq {
503 let _ = tx.send(Err(RpcError {
504 code: RpcErrorCode::Internal,
505 message: format!(
506 "stream seq mismatch: expected {}, got {}",
507 next_seq, seq_u
508 ),
509 }));
510 map.remove(&call_id);
511 return;
512 }
513 *next_seq += 1;
514 let chain = ext
515 .as_ref()
516 .and_then(|e| e.responsibility_chain.clone())
517 .unwrap_or_else(|| last_chain.clone());
518 *last_chain = chain.clone();
519 if more {
520 if let Some(v) = value {
521 let _ = tx.send(Ok(AgentSessionFrame {
522 value: v,
523 responsibility_chain: chain,
524 }));
525 }
526 } else if let Some(err) = error {
527 let _ = tx.send(Err(err));
528 map.remove(&call_id);
529 } else {
530 map.remove(&call_id);
531 }
532 }
533 _ => {}
534 }
535 }
536 RpcFrame::RpcCall { .. } | RpcFrame::RpcClientStream { .. } => {
537 }
539 }
540 }));
541 RpcClient {
542 transport,
543 pending,
544 caller_actor: caller_actor.into(),
545 }
546 }
547
548 pub fn caller_actor(&self) -> &str {
549 &self.caller_actor
550 }
551
552 pub async fn call_raw(&self, method: &str, request: Value) -> Result<Value, RpcCallError> {
553 let call_id = new_call_id();
554 let (tx, rx) = oneshot::channel();
555 self.pending
556 .lock()
557 .unwrap()
558 .insert(call_id.clone(), Pending::Unary(tx));
559 self.transport.send(encode_rpc(RpcFrame::RpcCall {
560 call_id: call_id.clone(),
561 method: method.to_owned(),
562 request,
563 ext: None,
564 }));
565 match rx.await {
566 Ok(Ok(v)) => Ok(v),
567 Ok(Err(err)) => Err(err.into()),
568 Err(_) => Err(RpcCallError {
569 code: RpcErrorCode::Internal,
570 message: "transport dropped the pending call".into(),
571 }),
572 }
573 }
574
575 pub fn server_stream_raw(
576 &self,
577 method: &str,
578 request: Value,
579 ) -> mpsc::UnboundedReceiver<Result<Value, RpcError>> {
580 let (tx, rx) = mpsc::unbounded_channel();
581 let call_id = new_call_id();
582 self.pending.lock().unwrap().insert(
583 call_id.clone(),
584 Pending::Stream {
585 tx,
586 next_seq: 0,
587 last_shell_stream: None,
588 last_chain: None,
589 },
590 );
591 self.transport.send(encode_rpc(RpcFrame::RpcCall {
592 call_id,
593 method: method.to_owned(),
594 request,
595 ext: None,
596 }));
597 rx
598 }
599
600 pub fn subscribe_raw(
605 &self,
606 method: &str,
607 request: Value,
608 topic: Option<String>,
609 ) -> mpsc::UnboundedReceiver<Result<Value, RpcError>> {
610 let (tx, rx) = mpsc::unbounded_channel();
611 let call_id = new_call_id();
612 self.pending.lock().unwrap().insert(
613 call_id.clone(),
614 Pending::Stream {
615 tx,
616 next_seq: 0,
617 last_shell_stream: None,
618 last_chain: None,
619 },
620 );
621 let ext = RpcFrameExt {
622 method_kind: Some(RpcMethodKind::Subscribe),
623 subscribe_topic: topic,
624 ..Default::default()
625 };
626 self.transport.send(encode_rpc(RpcFrame::RpcCall {
627 call_id,
628 method: method.to_owned(),
629 request,
630 ext: Some(ext),
631 }));
632 rx
633 }
634
635 pub fn client_stream_raw(
639 &self,
640 method: &str,
641 request: Value,
642 mut requests_rx: mpsc::UnboundedReceiver<Result<Value, RpcError>>,
643 ) -> Pin<Box<dyn Future<Output = Result<Value, RpcCallError>> + Send>> {
644 let call_id = new_call_id();
645 let (tx, rx) = oneshot::channel();
646 self.pending
647 .lock()
648 .unwrap()
649 .insert(call_id.clone(), Pending::Unary(tx));
650 let transport = self.transport.clone();
651 transport.send(encode_rpc(RpcFrame::RpcCall {
652 call_id: call_id.clone(),
653 method: method.to_owned(),
654 request,
655 ext: Some(RpcFrameExt {
656 method_kind: Some(RpcMethodKind::ClientStreaming),
657 ..Default::default()
658 }),
659 }));
660 let pump_transport = transport.clone();
661 let pump_call_id = call_id.clone();
662 tokio::spawn(async move {
663 let mut seq: u64 = 0;
664 while let Some(item) = requests_rx.recv().await {
665 match item {
666 Ok(v) => {
667 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
668 call_id: pump_call_id.clone(),
669 seq,
670 more: true,
671 value: Some(v),
672 error: None,
673 ext: None,
674 }));
675 seq += 1;
676 }
677 Err(err) => {
678 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
679 call_id: pump_call_id.clone(),
680 seq,
681 more: false,
682 value: None,
683 error: Some(err),
684 ext: None,
685 }));
686 return;
687 }
688 }
689 }
690 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
691 call_id: pump_call_id,
692 seq,
693 more: false,
694 value: None,
695 error: None,
696 ext: None,
697 }));
698 });
699 Box::pin(async move {
700 match rx.await {
701 Ok(Ok(v)) => Ok(v),
702 Ok(Err(err)) => Err(err.into()),
703 Err(_) => Err(RpcCallError {
704 code: RpcErrorCode::Internal,
705 message: "transport dropped the pending call".into(),
706 }),
707 }
708 })
709 }
710
711 pub fn bidi_raw(
715 &self,
716 method: &str,
717 request: Value,
718 ) -> (
719 mpsc::UnboundedSender<Result<Value, RpcError>>,
720 mpsc::UnboundedReceiver<Result<Value, RpcError>>,
721 ) {
722 self.bidi_with_kind(method, request, RpcMethodKind::BidiStreaming, None)
723 }
724
725 pub fn command_channel_raw(
730 &self,
731 method: &str,
732 request: Value,
733 ) -> (
734 mpsc::UnboundedSender<Result<Value, RpcError>>,
735 mpsc::UnboundedReceiver<Result<Value, RpcError>>,
736 ) {
737 self.bidi_with_kind(method, request, RpcMethodKind::CommandChannel, None)
738 }
739
740 fn bidi_with_kind(
741 &self,
742 method: &str,
743 request: Value,
744 kind: RpcMethodKind,
745 topic: Option<String>,
746 ) -> (
747 mpsc::UnboundedSender<Result<Value, RpcError>>,
748 mpsc::UnboundedReceiver<Result<Value, RpcError>>,
749 ) {
750 let call_id = new_call_id();
751 let (server_tx, server_rx) = mpsc::unbounded_channel();
752 self.pending.lock().unwrap().insert(
753 call_id.clone(),
754 Pending::Stream {
755 tx: server_tx,
756 next_seq: 0,
757 last_shell_stream: None,
758 last_chain: None,
759 },
760 );
761 let ext = RpcFrameExt {
762 method_kind: Some(kind.clone()),
763 subscribe_topic: topic,
764 ..Default::default()
765 };
766 self.transport.send(encode_rpc(RpcFrame::RpcCall {
767 call_id: call_id.clone(),
768 method: method.to_owned(),
769 request,
770 ext: Some(ext.clone()),
771 }));
772 let (client_tx, mut client_rx) = mpsc::unbounded_channel::<Result<Value, RpcError>>();
773 let pump_transport = self.transport.clone();
774 let pump_call_id = call_id.clone();
775 tokio::spawn(async move {
776 let mut seq: u64 = 0;
777 while let Some(item) = client_rx.recv().await {
778 match item {
779 Ok(v) => {
780 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
781 call_id: pump_call_id.clone(),
782 seq,
783 more: true,
784 value: Some(v),
785 error: None,
786 ext: Some(ext.clone()),
787 }));
788 seq += 1;
789 }
790 Err(err) => {
791 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
792 call_id: pump_call_id.clone(),
793 seq,
794 more: false,
795 value: None,
796 error: Some(err),
797 ext: Some(ext.clone()),
798 }));
799 return;
800 }
801 }
802 }
803 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
804 call_id: pump_call_id,
805 seq,
806 more: false,
807 value: None,
808 error: None,
809 ext: Some(ext),
810 }));
811 });
812 (client_tx, server_rx)
813 }
814
815 pub fn bulk_transfer_raw(
819 &self,
820 method: &str,
821 request: Value,
822 chunks: &[Vec<u8>],
823 ) -> Pin<Box<dyn Future<Output = Result<Value, RpcCallError>> + Send>> {
824 let expected_hash = sha256_of_chunks(chunks);
825 let call_id = new_call_id();
826 let (tx, rx) = oneshot::channel();
827 self.pending
828 .lock()
829 .unwrap()
830 .insert(call_id.clone(), Pending::Unary(tx));
831 let transport = self.transport.clone();
832 let total_chunks = chunks.len() as u32;
833 transport.send(encode_rpc(RpcFrame::RpcCall {
834 call_id: call_id.clone(),
835 method: method.to_owned(),
836 request,
837 ext: Some(RpcFrameExt {
838 method_kind: Some(RpcMethodKind::BulkTransfer),
839 bulk: Some(RpcBulkExt {
840 chunk_index: None,
841 total_chunks: Some(total_chunks),
842 expected_hash: Some(expected_hash),
843 }),
844 ..Default::default()
845 }),
846 }));
847 let owned_chunks: Vec<Vec<u8>> = chunks.to_vec();
848 let pump_transport = transport.clone();
849 let pump_call_id = call_id.clone();
850 tokio::spawn(async move {
851 tokio::task::yield_now().await;
854 let mut seq: u64 = 0;
855 for chunk in owned_chunks.iter() {
856 let encoded = B64.encode(chunk);
857 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
858 call_id: pump_call_id.clone(),
859 seq,
860 more: true,
861 value: Some(Value::String(encoded)),
862 error: None,
863 ext: Some(RpcFrameExt {
864 method_kind: Some(RpcMethodKind::BulkTransfer),
865 bulk: Some(RpcBulkExt {
866 chunk_index: Some(seq as u32),
867 ..Default::default()
868 }),
869 ..Default::default()
870 }),
871 }));
872 seq += 1;
873 }
874 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
875 call_id: pump_call_id,
876 seq,
877 more: false,
878 value: None,
879 error: None,
880 ext: Some(RpcFrameExt {
881 method_kind: Some(RpcMethodKind::BulkTransfer),
882 ..Default::default()
883 }),
884 }));
885 });
886 Box::pin(async move {
887 match rx.await {
888 Ok(Ok(v)) => Ok(v),
889 Ok(Err(err)) => Err(err.into()),
890 Err(_) => Err(RpcCallError {
891 code: RpcErrorCode::Internal,
892 message: "transport dropped the pending call".into(),
893 }),
894 }
895 })
896 }
897
898 pub fn telemetry_raw(
902 &self,
903 method: &str,
904 request: Value,
905 mut frames_rx: mpsc::UnboundedReceiver<Result<Value, RpcError>>,
906 priority: StreamingPriority,
907 ) -> Pin<Box<dyn Future<Output = Result<(), RpcCallError>> + Send>> {
908 let call_id = new_call_id();
909 let (tx, rx) = oneshot::channel();
910 self.pending
911 .lock()
912 .unwrap()
913 .insert(call_id.clone(), Pending::Unary(tx));
914 let transport = self.transport.clone();
915 let call_ext = RpcFrameExt {
916 method_kind: Some(RpcMethodKind::Telemetry),
917 streaming_priority: Some(priority.clone()),
918 ..Default::default()
919 };
920 transport.send(encode_rpc(RpcFrame::RpcCall {
921 call_id: call_id.clone(),
922 method: method.to_owned(),
923 request,
924 ext: Some(call_ext),
925 }));
926 let pump_transport = transport.clone();
927 let pump_call_id = call_id.clone();
928 let pump_priority = priority;
929 tokio::spawn(async move {
930 let mut seq: u64 = 0;
931 while let Some(item) = frames_rx.recv().await {
932 match item {
933 Ok(v) => {
934 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
935 call_id: pump_call_id.clone(),
936 seq,
937 more: true,
938 value: Some(v),
939 error: None,
940 ext: Some(RpcFrameExt {
941 method_kind: Some(RpcMethodKind::Telemetry),
942 streaming_priority: Some(pump_priority.clone()),
943 ..Default::default()
944 }),
945 }));
946 seq += 1;
947 }
948 Err(err) => {
949 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
950 call_id: pump_call_id.clone(),
951 seq,
952 more: false,
953 value: None,
954 error: Some(err),
955 ext: None,
956 }));
957 return;
958 }
959 }
960 }
961 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
962 call_id: pump_call_id,
963 seq,
964 more: false,
965 value: None,
966 error: None,
967 ext: Some(RpcFrameExt {
968 method_kind: Some(RpcMethodKind::Telemetry),
969 ..Default::default()
970 }),
971 }));
972 });
973 Box::pin(async move {
974 match rx.await {
975 Ok(Ok(_)) => Ok(()),
976 Ok(Err(err)) => Err(err.into()),
977 Err(_) => Err(RpcCallError {
978 code: RpcErrorCode::Internal,
979 message: "transport dropped the pending call".into(),
980 }),
981 }
982 })
983 }
984
985 pub fn remote_shell_raw(
988 &self,
989 method: &str,
990 request: Value,
991 mut stdin_rx: mpsc::UnboundedReceiver<Vec<u8>>,
992 ) -> mpsc::UnboundedReceiver<Result<RemoteShellOut, RpcError>> {
993 let call_id = new_call_id();
994 let (out_tx, out_rx) = mpsc::unbounded_channel();
995 self.pending.lock().unwrap().insert(
996 call_id.clone(),
997 Pending::RemoteShellStream {
998 tx: out_tx,
999 next_seq: 0,
1000 last_stream: RemoteShellStream::Stdout,
1001 },
1002 );
1003 let transport = self.transport.clone();
1004 transport.send(encode_rpc(RpcFrame::RpcCall {
1005 call_id: call_id.clone(),
1006 method: method.to_owned(),
1007 request,
1008 ext: Some(RpcFrameExt {
1009 method_kind: Some(RpcMethodKind::RemoteShell),
1010 ..Default::default()
1011 }),
1012 }));
1013 let pump_transport = transport.clone();
1014 let pump_call_id = call_id.clone();
1015 tokio::spawn(async move {
1016 let mut seq: u64 = 0;
1017 while let Some(chunk) = stdin_rx.recv().await {
1018 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
1019 call_id: pump_call_id.clone(),
1020 seq,
1021 more: true,
1022 value: Some(Value::String(B64.encode(&chunk))),
1023 error: None,
1024 ext: Some(RpcFrameExt {
1025 method_kind: Some(RpcMethodKind::RemoteShell),
1026 shell_stream: Some(RemoteShellStream::Stdin),
1027 ..Default::default()
1028 }),
1029 }));
1030 seq += 1;
1031 }
1032 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
1033 call_id: pump_call_id,
1034 seq,
1035 more: false,
1036 value: None,
1037 error: None,
1038 ext: Some(RpcFrameExt {
1039 method_kind: Some(RpcMethodKind::RemoteShell),
1040 ..Default::default()
1041 }),
1042 }));
1043 });
1044 out_rx
1045 }
1046
1047 pub fn agent_session_raw(
1052 &self,
1053 method: &str,
1054 request: Value,
1055 initial_chain: Vec<String>,
1056 mut frames_rx: mpsc::UnboundedReceiver<AgentSessionFrame>,
1057 ) -> mpsc::UnboundedReceiver<Result<AgentSessionFrame, RpcError>> {
1058 let call_id = new_call_id();
1059 let (out_tx, out_rx) = mpsc::unbounded_channel();
1060 self.pending.lock().unwrap().insert(
1061 call_id.clone(),
1062 Pending::AgentSessionStream {
1063 tx: out_tx,
1064 next_seq: 0,
1065 last_chain: initial_chain.clone(),
1066 },
1067 );
1068 let transport = self.transport.clone();
1069 transport.send(encode_rpc(RpcFrame::RpcCall {
1070 call_id: call_id.clone(),
1071 method: method.to_owned(),
1072 request,
1073 ext: Some(RpcFrameExt {
1074 method_kind: Some(RpcMethodKind::AgentSession),
1075 responsibility_chain: Some(initial_chain),
1076 ..Default::default()
1077 }),
1078 }));
1079 let pump_transport = transport.clone();
1080 let pump_call_id = call_id.clone();
1081 tokio::spawn(async move {
1082 let mut seq: u64 = 0;
1083 while let Some(frame) = frames_rx.recv().await {
1084 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
1085 call_id: pump_call_id.clone(),
1086 seq,
1087 more: true,
1088 value: Some(frame.value),
1089 error: None,
1090 ext: Some(RpcFrameExt {
1091 method_kind: Some(RpcMethodKind::AgentSession),
1092 responsibility_chain: Some(frame.responsibility_chain),
1093 ..Default::default()
1094 }),
1095 }));
1096 seq += 1;
1097 }
1098 pump_transport.send(encode_rpc(RpcFrame::RpcClientStream {
1099 call_id: pump_call_id,
1100 seq,
1101 more: false,
1102 value: None,
1103 error: None,
1104 ext: Some(RpcFrameExt {
1105 method_kind: Some(RpcMethodKind::AgentSession),
1106 ..Default::default()
1107 }),
1108 }));
1109 });
1110 out_rx
1111 }
1112}
1113
/// Per-call information handed to server handlers.
pub struct RpcContext {
    pub caller_actor: String,
    pub method: String,
    pub call_id: String,
    // The server listener derives an initial chain (empty when absent) and a
    // topic from the incoming call's ext. NOTE(review): presumably these
    // populate `initial_chain` / `subscribe_topic` — confirm.
    pub initial_chain: Vec<String>,
    pub subscribe_topic: Option<String>,
}
1127
/// Handler for unary calls: request + context in, one result out.
pub type UnaryHandler = Arc<
    dyn Fn(Value, RpcContext) -> Pin<Box<dyn Future<Output = Result<Value, RpcError>> + Send>>
        + Send
        + Sync,
>;

/// Handler for server-streaming calls; items are pushed into the sender.
pub type StreamHandler = Arc<
    dyn Fn(
            Value,
            RpcContext,
            mpsc::UnboundedSender<Result<Value, RpcError>>,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        + Send
        + Sync,
>;

/// Subscribe handlers share the server-streaming signature.
pub type SubscribeHandler = StreamHandler;

/// Handler for client-streaming calls: consumes the client's items and
/// produces one result.
pub type ClientStreamHandler = Arc<
    dyn Fn(
            Value,
            RpcContext,
            mpsc::UnboundedReceiver<Result<Value, RpcError>>,
        ) -> Pin<Box<dyn Future<Output = Result<Value, RpcError>> + Send>>
        + Send
        + Sync,
>;

/// Handler for bidirectional calls: client items in, server items out.
pub type BidiHandler = Arc<
    dyn Fn(
            Value,
            RpcContext,
            mpsc::UnboundedReceiver<Result<Value, RpcError>>,
            mpsc::UnboundedSender<Result<Value, RpcError>>,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        + Send
        + Sync,
>;

/// Command-channel handlers share the bidirectional signature.
pub type CommandChannelHandler = BidiHandler;

/// Handler for bulk transfers: receives the chunks as raw bytes and produces
/// one result.
pub type BulkTransferHandler = Arc<
    dyn Fn(
            Value,
            RpcContext,
            mpsc::UnboundedReceiver<Vec<u8>>,
        ) -> Pin<Box<dyn Future<Output = Result<Value, RpcError>> + Send>>
        + Send
        + Sync,
>;

/// Handler for telemetry streams; receives the stream priority and the
/// client's items, and reports only success or failure.
pub type TelemetryHandler = Arc<
    dyn Fn(
            Value,
            RpcContext,
            StreamingPriority,
            mpsc::UnboundedReceiver<Result<Value, RpcError>>,
        ) -> Pin<Box<dyn Future<Output = Result<(), RpcError>> + Send>>
        + Send
        + Sync,
>;

/// Handler for remote shells: stdin bytes in, tagged output chunks out.
pub type RemoteShellHandler = Arc<
    dyn Fn(
            Value,
            RpcContext,
            mpsc::UnboundedReceiver<Vec<u8>>,
            mpsc::UnboundedSender<RemoteShellOut>,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        + Send
        + Sync,
>;

/// Handler for agent sessions: frames (with responsibility chains) in and out.
pub type AgentSessionHandler = Arc<
    dyn Fn(
            Value,
            RpcContext,
            mpsc::UnboundedReceiver<AgentSessionFrame>,
            mpsc::UnboundedSender<AgentSessionFrame>,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        + Send
        + Sync,
>;

/// Handler for the HTTP bridge: `HttpFrame`s in and out.
pub type HttpBridgeHandler = Arc<
    dyn Fn(
            Value,
            RpcContext,
            mpsc::UnboundedReceiver<HttpFrame>,
            mpsc::UnboundedSender<HttpFrame>,
        ) -> Pin<Box<dyn Future<Output = ()> + Send>>
        + Send
        + Sync,
>;
1250
/// One chunk of remote-shell output together with the stream it belongs to.
#[derive(Clone, Debug)]
pub struct RemoteShellOut {
    pub stream: RemoteShellStream,
    // Raw bytes; on the client these come from `decode_bulk_chunk`.
    pub data: Vec<u8>,
}
1256
/// One agent-session message and the responsibility chain attached to it.
#[derive(Clone, Debug)]
pub struct AgentSessionFrame {
    pub value: Value,
    pub responsibility_chain: Vec<String>,
}
1262
/// A registered server handler, one variant per method kind. Every variant
/// records the capability string passed to the enforcer for that method.
enum Handler {
    Unary {
        capability: String,
        handler: UnaryHandler,
    },
    Stream {
        capability: String,
        handler: StreamHandler,
    },
    Subscribe {
        capability: String,
        handler: SubscribeHandler,
    },
    ClientStream {
        capability: String,
        handler: ClientStreamHandler,
    },
    Bidi {
        capability: String,
        handler: BidiHandler,
    },
    CommandChannel {
        capability: String,
        handler: CommandChannelHandler,
        // NOTE(review): presumably the flow-control credit granted at call
        // start (cf. `RpcFrameExt::credit`) — confirm.
        initial_credit: u32,
    },
    BulkTransfer {
        capability: String,
        handler: BulkTransferHandler,
    },
    Telemetry {
        capability: String,
        handler: TelemetryHandler,
        // Priority registered for this telemetry method.
        priority: StreamingPriority,
    },
    RemoteShell {
        capability: String,
        handler: RemoteShellHandler,
    },
    AgentSession {
        capability: String,
        handler: AgentSessionHandler,
    },
    HttpBridge {
        capability: String,
        handler: HttpBridgeHandler,
    },
}
1311
1312impl Handler {
1313 fn capability(&self) -> &str {
1314 match self {
1315 Handler::Unary { capability, .. }
1316 | Handler::Stream { capability, .. }
1317 | Handler::Subscribe { capability, .. }
1318 | Handler::ClientStream { capability, .. }
1319 | Handler::Bidi { capability, .. }
1320 | Handler::CommandChannel { capability, .. }
1321 | Handler::BulkTransfer { capability, .. }
1322 | Handler::Telemetry { capability, .. }
1323 | Handler::RemoteShell { capability, .. }
1324 | Handler::AgentSession { capability, .. }
1325 | Handler::HttpBridge { capability, .. } => capability,
1326 }
1327 }
1328
1329 fn is_streaming(&self) -> bool {
1332 matches!(
1333 self,
1334 Handler::Stream { .. }
1335 | Handler::Subscribe { .. }
1336 | Handler::Bidi { .. }
1337 | Handler::CommandChannel { .. }
1338 | Handler::RemoteShell { .. }
1339 | Handler::AgentSession { .. }
1340 | Handler::HttpBridge { .. }
1341 )
1342 }
1343}
1344
/// Server-side routing entry for a call that accepts client-stream frames.
struct InflightCall {
    // Invoked by the transport listener for each `RpcClientStream` frame
    // addressed to this call id.
    push: Arc<dyn Fn(InflightMsg) + Send + Sync>,
}
1350
/// Event derived from one incoming `RpcClientStream` frame.
///
/// The server listener checks `error` first, whatever `more` says. A frame
/// with `more == true` but no value produces no message.
#[derive(Clone)]
enum InflightMsg {
    // Data frame (`more == true`) with its value and ext.
    Value(Value, Option<RpcFrameExt>),
    // Terminal frame without an error.
    Done,
    // Frame carrying an error.
    Error(RpcError),
}
1357
/// Server side of the RPC layer: looks up handlers for incoming `RpcCall`s by
/// method name. NOTE(review): presumably each handler's capability is checked
/// with `enforcer` before invocation — the check is outside the visible code.
pub struct RpcServer<T: RpcTransport + 'static> {
    transport: Arc<T>,
    // Registered handlers keyed by method name.
    handlers: Arc<Mutex<HashMap<String, Arc<Handler>>>>,
    // Calls receiving client-stream frames, keyed by call id. The field itself
    // is never read; the transport listener uses its own clone of the Arc.
    #[allow(dead_code)]
    inflight: Arc<Mutex<HashMap<String, InflightCall>>>,
    caller_actor: String,
    enforcer: Arc<dyn CapabilityEnforcer>,
}
1371
impl<T: RpcTransport + 'static> RpcServer<T> {
    /// Creates a server on `transport` and installs the frame listener that
    /// drives every incoming call.
    ///
    /// `caller_actor` is the identity checked by `enforcer` for each call. It
    /// is also copied into each handler's [`RpcContext`].
    ///
    /// For every frame delivered to the listener registered via
    /// `transport.on_frame`:
    ///
    /// * `RpcClientStream` frames are forwarded to the in-flight call with the
    ///   same `call_id`. Frames for call ids with no in-flight entry are dropped.
    /// * An `RpcCall` for an unregistered method is answered with a `NotFound`
    ///   `RpcResponse`.
    /// * Otherwise `enforcer` is consulted. On `Deny` the call is answered with
    ///   `PermissionDenied`: as an `RpcStream` frame when
    ///   [`Handler::is_streaming`] is true, and as an `RpcResponse` otherwise.
    /// * An allowed call is handed to the matching `dispatch_*` function on a
    ///   task started with `tokio::spawn`. That call panics if no tokio
    ///   runtime is active on the thread delivering the frame.
    /// * Frames `decode_rpc` rejects, and any other RPC frame kind, are ignored.
    ///
    /// Handlers can be registered after construction, because the method is
    /// looked up in the shared registry for every call.
    pub fn new(
        transport: Arc<T>,
        caller_actor: impl Into<String>,
        enforcer: Arc<dyn CapabilityEnforcer>,
    ) -> Self {
        let handlers: Arc<Mutex<HashMap<String, Arc<Handler>>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let inflight: Arc<Mutex<HashMap<String, InflightCall>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let caller_actor: String = caller_actor.into();
        // The listener closure is moved into the transport, so it gets its
        // own handles to every piece of state the server also keeps.
        let caller_for_listener = caller_actor.clone();
        let handlers_for_listener = handlers.clone();
        let inflight_for_listener = inflight.clone();
        let enforcer_for_listener = enforcer.clone();
        let transport_for_listener = transport.clone();
        transport.on_frame(Arc::new(move |frame| {
            // Frames `decode_rpc` cannot decode are ignored.
            let rpc = match decode_rpc(frame) {
                Some(r) => r,
                None => return,
            };

            // Client-side stream data for a call that is already in flight.
            if let RpcFrame::RpcClientStream {
                call_id,
                seq: _,
                more,
                value,
                error,
                ext,
            } = &rpc
            {
                // Copy the callback out so the map lock is released before
                // the callback runs.
                let push = {
                    let map = inflight_for_listener.lock().unwrap();
                    map.get(call_id).map(|c| c.push.clone())
                };
                // NOTE(review): dispatchers register their inflight entry from
                // inside the task spawned below, not synchronously here. A
                // client-stream frame that arrives before that task runs is
                // dropped at this point. Confirm the transport/client ordering
                // rules this out.
                let Some(push) = push else { return };
                // Precedence: error, then data (`more`), then end-of-stream.
                // `seq` is not checked; frames are forwarded in arrival order.
                if let Some(err) = error.clone() {
                    push(InflightMsg::Error(err));
                } else if *more {
                    // A `more: true` frame without a value forwards nothing.
                    if let Some(v) = value.clone() {
                        push(InflightMsg::Value(v, ext.clone()));
                    }
                } else {
                    // NOTE(review): a `value` carried on the closing frame is
                    // discarded. Confirm clients never send one.
                    push(InflightMsg::Done);
                }
                return;
            }

            // Past this point only a new call is handled; any other RPC frame
            // kind is ignored.
            let RpcFrame::RpcCall {
                call_id,
                method,
                request,
                ext: call_ext,
            } = rpc
            else {
                return;
            };

            let initial_chain = call_ext
                .as_ref()
                .and_then(|e| e.responsibility_chain.clone())
                .unwrap_or_default();
            let subscribe_topic = call_ext.as_ref().and_then(|e| e.subscribe_topic.clone());

            // Clone the Arc out so the registry lock is not held while dispatching.
            let handler = {
                let map = handlers_for_listener.lock().unwrap();
                map.get(&method).cloned()
            };

            // Unknown method: its kind is unknown too, so the reply is always
            // a unary-style response.
            let Some(handler) = handler else {
                transport_for_listener.send(encode_rpc(RpcFrame::RpcResponse {
                    call_id: call_id.clone(),
                    status: ResponseStatus::Error,
                    response: None,
                    error: Some(RpcError {
                        code: RpcErrorCode::NotFound,
                        message: format!("unknown method: {}", method),
                    }),
                    ext: None,
                }));
                return;
            };

            // Authorization happens before anything is spawned. A denial is
            // shaped like the reply the client expects for this handler kind.
            let capability = handler.capability().to_owned();
            let is_streaming = handler.is_streaming();
            match enforcer_for_listener.check(&caller_for_listener, &method, &capability) {
                CapabilityDecision::Allow => {}
                CapabilityDecision::Deny(reason) => {
                    if is_streaming {
                        transport_for_listener.send(encode_rpc(RpcFrame::RpcStream {
                            call_id,
                            seq: 0,
                            more: false,
                            value: None,
                            error: Some(RpcError {
                                code: RpcErrorCode::PermissionDenied,
                                message: reason,
                            }),
                            ext: None,
                        }));
                    } else {
                        transport_for_listener.send(encode_rpc(RpcFrame::RpcResponse {
                            call_id,
                            status: ResponseStatus::Error,
                            response: None,
                            error: Some(RpcError {
                                code: RpcErrorCode::PermissionDenied,
                                message: reason,
                            }),
                            ext: None,
                        }));
                    }
                    return;
                }
            }

            let ctx = RpcContext {
                caller_actor: caller_for_listener.clone(),
                method: method.clone(),
                call_id: call_id.clone(),
                initial_chain,
                subscribe_topic,
            };
            let transport = transport_for_listener.clone();
            let inflight = inflight_for_listener.clone();
            // Each dispatcher runs on its own spawned task; the listener does
            // not await it.
            match &*handler {
                Handler::Unary { handler, .. } => {
                    let handler = handler.clone();
                    tokio::spawn(async move {
                        dispatch_unary(transport, ctx, request, handler).await;
                    });
                }
                Handler::Stream { handler, .. } => {
                    let handler = handler.clone();
                    tokio::spawn(async move {
                        dispatch_server_stream(transport, ctx, request, handler).await;
                    });
                }
                Handler::Subscribe { handler, .. } => {
                    let handler = handler.clone();
                    tokio::spawn(async move {
                        dispatch_subscribe(transport, ctx, request, handler).await;
                    });
                }
                Handler::ClientStream { handler, .. } => {
                    let handler = handler.clone();
                    tokio::spawn(async move {
                        dispatch_client_stream(transport, inflight, ctx, request, handler).await;
                    });
                }
                Handler::Bidi { handler, .. } => {
                    let handler = handler.clone();
                    tokio::spawn(async move {
                        dispatch_bidi(transport, inflight, ctx, request, handler).await;
                    });
                }
                Handler::CommandChannel {
                    handler,
                    initial_credit,
                    ..
                } => {
                    let handler = handler.clone();
                    let credit = *initial_credit;
                    tokio::spawn(async move {
                        dispatch_command_channel(
                            transport, inflight, ctx, request, handler, credit,
                        )
                        .await;
                    });
                }
                Handler::BulkTransfer { handler, .. } => {
                    let handler = handler.clone();
                    // A missing hash is rejected by `dispatch_bulk_transfer` itself.
                    let expected_hash = call_ext
                        .as_ref()
                        .and_then(|e| e.bulk.as_ref())
                        .and_then(|b| b.expected_hash.clone());
                    tokio::spawn(async move {
                        dispatch_bulk_transfer(
                            transport,
                            inflight,
                            ctx,
                            request,
                            handler,
                            expected_hash,
                        )
                        .await;
                    });
                }
                Handler::Telemetry {
                    handler, priority, ..
                } => {
                    let handler = handler.clone();
                    let priority = priority.clone();
                    tokio::spawn(async move {
                        dispatch_telemetry(transport, inflight, ctx, request, handler, priority)
                            .await;
                    });
                }
                Handler::RemoteShell { handler, .. } => {
                    let handler = handler.clone();
                    tokio::spawn(async move {
                        dispatch_remote_shell(transport, inflight, ctx, request, handler).await;
                    });
                }
                Handler::AgentSession { handler, .. } => {
                    let handler = handler.clone();
                    tokio::spawn(async move {
                        dispatch_agent_session(transport, inflight, ctx, request, handler).await;
                    });
                }
                Handler::HttpBridge { handler, .. } => {
                    let handler = handler.clone();
                    tokio::spawn(async move {
                        dispatch_http_bridge(transport, inflight, ctx, request, handler).await;
                    });
                }
            }
        }));
        RpcServer {
            transport,
            handlers,
            inflight,
            caller_actor,
            enforcer,
        }
    }

    /// Registers `handler` as the unary handler for `method`, guarded by
    /// `capability`. Any handler previously registered under `method` is
    /// replaced.
    pub fn register_unary(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: UnaryHandler,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::Unary {
                capability: capability.into(),
                handler,
            }),
        );
    }

    /// Registers a server-streaming handler for `method`, guarded by
    /// `capability`. Any previous registration for `method` is replaced.
    pub fn register_stream(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: StreamHandler,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::Stream {
                capability: capability.into(),
                handler,
            }),
        );
    }

    /// Registers a subscribe handler for `method`, guarded by `capability`.
    /// Any previous registration for `method` is replaced.
    pub fn register_subscribe(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: SubscribeHandler,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::Subscribe {
                capability: capability.into(),
                handler,
            }),
        );
    }

    /// Registers a client-streaming handler for `method`, guarded by
    /// `capability`. Any previous registration for `method` is replaced.
    pub fn register_client_stream(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: ClientStreamHandler,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::ClientStream {
                capability: capability.into(),
                handler,
            }),
        );
    }

    /// Registers a bidirectional-streaming handler for `method`, guarded by
    /// `capability`. Any previous registration for `method` is replaced.
    pub fn register_bidi(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: BidiHandler,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::Bidi {
                capability: capability.into(),
                handler,
            }),
        );
    }

    /// Registers a command-channel handler for `method`, guarded by
    /// `capability`.
    ///
    /// `initial_credit` is sent to the client in the `ext.credit` of a seq -1
    /// control frame at the start of each call. Any previous registration for
    /// `method` is replaced.
    pub fn register_command_channel(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: CommandChannelHandler,
        initial_credit: u32,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::CommandChannel {
                capability: capability.into(),
                handler,
                initial_credit,
            }),
        );
    }

    /// Registers a bulk-transfer handler for `method`, guarded by `capability`.
    ///
    /// Each call must carry `ext.bulk.expected_hash` or it is rejected. Any
    /// previous registration for `method` is replaced.
    pub fn register_bulk_transfer(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: BulkTransferHandler,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::BulkTransfer {
                capability: capability.into(),
                handler,
            }),
        );
    }

    /// Registers a telemetry handler for `method`, guarded by `capability`.
    ///
    /// `priority` is passed to the handler and echoed in the response ext.
    /// Any previous registration for `method` is replaced.
    pub fn register_telemetry(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: TelemetryHandler,
        priority: StreamingPriority,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::Telemetry {
                capability: capability.into(),
                handler,
                priority,
            }),
        );
    }

    /// Registers a remote-shell handler for `method`, guarded by `capability`.
    /// Any previous registration for `method` is replaced.
    pub fn register_remote_shell(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: RemoteShellHandler,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::RemoteShell {
                capability: capability.into(),
                handler,
            }),
        );
    }

    /// Registers an agent-session handler for `method`, guarded by
    /// `capability`. Any previous registration for `method` is replaced.
    pub fn register_agent_session(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: AgentSessionHandler,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::AgentSession {
                capability: capability.into(),
                handler,
            }),
        );
    }

    /// Registers an HTTP-bridge handler for `method`, guarded by `capability`.
    /// Any previous registration for `method` is replaced.
    pub fn register_http_bridge(
        &self,
        method: impl Into<String>,
        capability: impl Into<String>,
        handler: HttpBridgeHandler,
    ) {
        self.handlers.lock().unwrap().insert(
            method.into(),
            Arc::new(Handler::HttpBridge {
                capability: capability.into(),
                handler,
            }),
        );
    }

    /// Identity used for capability checks and handler contexts.
    pub fn caller_actor(&self) -> &str {
        &self.caller_actor
    }

    /// Transport this server listens on and replies through.
    pub fn transport(&self) -> &Arc<T> {
        &self.transport
    }

    /// Capability policy consulted before each call is dispatched.
    pub fn enforcer(&self) -> &Arc<dyn CapabilityEnforcer> {
        &self.enforcer
    }
}
1781
1782impl<T: RpcTransport + 'static> RpcServer<T> {
1783 pub fn check_authorization(
1788 &self,
1789 caller: &str,
1790 method: &str,
1791 capability: &str,
1792 ) -> CapabilityDecision {
1793 self.enforcer.check(caller, method, capability)
1794 }
1795}
1796
1797async fn dispatch_unary<T: RpcTransport + 'static>(
1800 transport: Arc<T>,
1801 ctx: RpcContext,
1802 request: Value,
1803 handler: UnaryHandler,
1804) {
1805 let call_id = ctx.call_id.clone();
1806 match handler(request, ctx).await {
1807 Ok(v) => {
1808 transport.send(encode_rpc(RpcFrame::RpcResponse {
1809 call_id,
1810 status: ResponseStatus::Ok,
1811 response: Some(v),
1812 error: None,
1813 ext: None,
1814 }));
1815 }
1816 Err(err) => {
1817 transport.send(encode_rpc(RpcFrame::RpcResponse {
1818 call_id,
1819 status: ResponseStatus::Error,
1820 response: None,
1821 error: Some(err),
1822 ext: None,
1823 }));
1824 }
1825 }
1826}
1827
1828async fn run_server_stream_loop<T: RpcTransport + 'static>(
1829 transport: &Arc<T>,
1830 call_id: &str,
1831 method_kind: Option<RpcMethodKind>,
1832 mut rx: mpsc::UnboundedReceiver<Result<Value, RpcError>>,
1833) {
1834 let ext = method_kind.clone().map(|k| RpcFrameExt {
1835 method_kind: Some(k),
1836 ..Default::default()
1837 });
1838 let mut seq: i64 = 0;
1839 while let Some(item) = rx.recv().await {
1840 match item {
1841 Ok(v) => {
1842 transport.send(encode_rpc(RpcFrame::RpcStream {
1843 call_id: call_id.to_owned(),
1844 seq,
1845 more: true,
1846 value: Some(v),
1847 error: None,
1848 ext: ext.clone(),
1849 }));
1850 seq += 1;
1851 }
1852 Err(err) => {
1853 transport.send(encode_rpc(RpcFrame::RpcStream {
1854 call_id: call_id.to_owned(),
1855 seq,
1856 more: false,
1857 value: None,
1858 error: Some(err),
1859 ext: ext.clone(),
1860 }));
1861 return;
1862 }
1863 }
1864 }
1865 transport.send(encode_rpc(RpcFrame::RpcStream {
1866 call_id: call_id.to_owned(),
1867 seq,
1868 more: false,
1869 value: None,
1870 error: None,
1871 ext: ext.clone(),
1872 }));
1873}
1874
1875async fn dispatch_server_stream<T: RpcTransport + 'static>(
1876 transport: Arc<T>,
1877 ctx: RpcContext,
1878 request: Value,
1879 handler: StreamHandler,
1880) {
1881 let call_id = ctx.call_id.clone();
1882 let (tx, rx) = mpsc::unbounded_channel::<Result<Value, RpcError>>();
1883 let fut = handler(request, ctx, tx);
1884 tokio::spawn(fut);
1885 run_server_stream_loop(
1886 &transport,
1887 &call_id,
1888 Some(RpcMethodKind::ServerStreaming),
1889 rx,
1890 )
1891 .await;
1892}
1893
1894async fn dispatch_subscribe<T: RpcTransport + 'static>(
1895 transport: Arc<T>,
1896 ctx: RpcContext,
1897 request: Value,
1898 handler: SubscribeHandler,
1899) {
1900 let call_id = ctx.call_id.clone();
1901 let topic = ctx.subscribe_topic.clone();
1902 transport.send(encode_rpc(RpcFrame::RpcStream {
1905 call_id: call_id.clone(),
1906 seq: -1,
1907 more: true,
1908 value: None,
1909 error: None,
1910 ext: Some(RpcFrameExt {
1911 method_kind: Some(RpcMethodKind::Subscribe),
1912 ack: Some("subscribed".into()),
1913 subscribe_topic: topic,
1914 ..Default::default()
1915 }),
1916 }));
1917 let (tx, rx) = mpsc::unbounded_channel::<Result<Value, RpcError>>();
1918 let fut = handler(request, ctx, tx);
1919 tokio::spawn(fut);
1920 run_server_stream_loop(&transport, &call_id, Some(RpcMethodKind::Subscribe), rx).await;
1921 transport.send(encode_rpc(RpcFrame::RpcStream {
1922 call_id,
1923 seq: -1,
1924 more: false,
1925 value: None,
1926 error: None,
1927 ext: Some(RpcFrameExt {
1928 method_kind: Some(RpcMethodKind::Subscribe),
1929 ack: Some("unsubscribed".into()),
1930 ..Default::default()
1931 }),
1932 }));
1933}
1934
1935fn install_client_pipe(
1939 inflight: &Arc<Mutex<HashMap<String, InflightCall>>>,
1940 call_id: &str,
1941) -> mpsc::UnboundedReceiver<InflightMsg> {
1942 let (tx, rx) = mpsc::unbounded_channel::<InflightMsg>();
1943 let push = Arc::new(move |msg: InflightMsg| {
1944 let _ = tx.send(msg);
1945 });
1946 inflight
1947 .lock()
1948 .unwrap()
1949 .insert(call_id.to_owned(), InflightCall { push });
1950 rx
1951}
1952
1953fn remove_inflight(inflight: &Arc<Mutex<HashMap<String, InflightCall>>>, call_id: &str) {
1954 inflight.lock().unwrap().remove(call_id);
1955}
1956
1957fn pipe_to_value_rx(
1960 mut raw_rx: mpsc::UnboundedReceiver<InflightMsg>,
1961) -> mpsc::UnboundedReceiver<Result<Value, RpcError>> {
1962 let (tx, rx) = mpsc::unbounded_channel::<Result<Value, RpcError>>();
1963 tokio::spawn(async move {
1964 while let Some(msg) = raw_rx.recv().await {
1965 match msg {
1966 InflightMsg::Value(v, _) => {
1967 if tx.send(Ok(v)).is_err() {
1968 return;
1969 }
1970 }
1971 InflightMsg::Done => return,
1972 InflightMsg::Error(err) => {
1973 let _ = tx.send(Err(err));
1974 return;
1975 }
1976 }
1977 }
1978 });
1979 rx
1980}
1981
1982async fn dispatch_client_stream<T: RpcTransport + 'static>(
1983 transport: Arc<T>,
1984 inflight: Arc<Mutex<HashMap<String, InflightCall>>>,
1985 ctx: RpcContext,
1986 request: Value,
1987 handler: ClientStreamHandler,
1988) {
1989 let call_id = ctx.call_id.clone();
1990 let raw_rx = install_client_pipe(&inflight, &call_id);
1991 let value_rx = pipe_to_value_rx(raw_rx);
1992 let result = handler(request, ctx, value_rx).await;
1993 remove_inflight(&inflight, &call_id);
1994 match result {
1995 Ok(v) => transport.send(encode_rpc(RpcFrame::RpcResponse {
1996 call_id,
1997 status: ResponseStatus::Ok,
1998 response: Some(v),
1999 error: None,
2000 ext: None,
2001 })),
2002 Err(err) => transport.send(encode_rpc(RpcFrame::RpcResponse {
2003 call_id,
2004 status: ResponseStatus::Error,
2005 response: None,
2006 error: Some(err),
2007 ext: None,
2008 })),
2009 }
2010}
2011
2012async fn dispatch_bidi<T: RpcTransport + 'static>(
2013 transport: Arc<T>,
2014 inflight: Arc<Mutex<HashMap<String, InflightCall>>>,
2015 ctx: RpcContext,
2016 request: Value,
2017 handler: BidiHandler,
2018) {
2019 let call_id = ctx.call_id.clone();
2020 let raw_rx = install_client_pipe(&inflight, &call_id);
2021 let value_rx = pipe_to_value_rx(raw_rx);
2022 let (server_tx, server_rx) = mpsc::unbounded_channel::<Result<Value, RpcError>>();
2023 let fut = handler(request, ctx, value_rx, server_tx);
2024 tokio::spawn(fut);
2025 run_server_stream_loop(
2026 &transport,
2027 &call_id,
2028 Some(RpcMethodKind::BidiStreaming),
2029 server_rx,
2030 )
2031 .await;
2032 remove_inflight(&inflight, &call_id);
2033}
2034
2035async fn dispatch_command_channel<T: RpcTransport + 'static>(
2036 transport: Arc<T>,
2037 inflight: Arc<Mutex<HashMap<String, InflightCall>>>,
2038 ctx: RpcContext,
2039 request: Value,
2040 handler: CommandChannelHandler,
2041 initial_credit: u32,
2042) {
2043 let call_id = ctx.call_id.clone();
2044 let raw_rx = install_client_pipe(&inflight, &call_id);
2045 let value_rx = pipe_to_value_rx(raw_rx);
2046 transport.send(encode_rpc(RpcFrame::RpcStream {
2048 call_id: call_id.clone(),
2049 seq: -1,
2050 more: true,
2051 value: None,
2052 error: None,
2053 ext: Some(RpcFrameExt {
2054 method_kind: Some(RpcMethodKind::CommandChannel),
2055 credit: Some(initial_credit),
2056 ..Default::default()
2057 }),
2058 }));
2059 let (server_tx, mut server_rx) = mpsc::unbounded_channel::<Result<Value, RpcError>>();
2060 let fut = handler(request, ctx, value_rx, server_tx);
2061 tokio::spawn(fut);
2062 let ext = Some(RpcFrameExt {
2063 method_kind: Some(RpcMethodKind::CommandChannel),
2064 ..Default::default()
2065 });
2066 let mut seq: i64 = 0;
2067 while let Some(item) = server_rx.recv().await {
2068 match item {
2069 Ok(v) => {
2070 transport.send(encode_rpc(RpcFrame::RpcStream {
2071 call_id: call_id.clone(),
2072 seq,
2073 more: true,
2074 value: Some(v),
2075 error: None,
2076 ext: ext.clone(),
2077 }));
2078 seq += 1;
2079 }
2080 Err(err) => {
2081 transport.send(encode_rpc(RpcFrame::RpcStream {
2082 call_id: call_id.clone(),
2083 seq,
2084 more: false,
2085 value: None,
2086 error: Some(err),
2087 ext: ext.clone(),
2088 }));
2089 remove_inflight(&inflight, &call_id);
2090 return;
2091 }
2092 }
2093 }
2094 transport.send(encode_rpc(RpcFrame::RpcStream {
2095 call_id: call_id.clone(),
2096 seq,
2097 more: false,
2098 value: None,
2099 error: None,
2100 ext,
2101 }));
2102 remove_inflight(&inflight, &call_id);
2103}
2104
/// Dispatches a bulk-transfer call.
///
/// Client chunks are streamed to the handler, and the digest of everything
/// received is compared with the client-supplied hash.
///
/// If `expected_hash` (taken from the call frame's `ext.bulk.expected_hash`)
/// is absent, the call is rejected with `InvalidArgument` before the handler
/// runs. Otherwise each client-stream value is decoded with
/// `decode_bulk_chunk`, recorded for hashing, and forwarded to the handler.
/// After the handler returns, the digest is computed with `sha256_of_chunks`:
///
/// * handler `Err`: error response with the handler's error and no ext;
/// * handler `Ok`, digest differs: `InvalidArgument` error response whose
///   ext reports the digest actually computed;
/// * handler `Ok`, digest matches: `Ok` response with the digest in the ext.
async fn dispatch_bulk_transfer<T: RpcTransport + 'static>(
    transport: Arc<T>,
    inflight: Arc<Mutex<HashMap<String, InflightCall>>>,
    ctx: RpcContext,
    request: Value,
    handler: BulkTransferHandler,
    expected_hash: Option<String>,
) {
    let call_id = ctx.call_id.clone();
    // Without a hash there is nothing to verify against. The call is refused
    // up front and no client pipe is installed.
    let Some(expected_hash) = expected_hash else {
        transport.send(encode_rpc(RpcFrame::RpcResponse {
            call_id,
            status: ResponseStatus::Error,
            response: None,
            error: Some(RpcError {
                code: RpcErrorCode::InvalidArgument,
                message: "bulk-transfer requires ext.bulk.expected_hash".into(),
            }),
            ext: None,
        }));
        return;
    };
    let raw_rx = install_client_pipe(&inflight, &call_id);
    // Every chunk the pump decodes is recorded here before it is offered to
    // the handler, whether or not the handler ends up reading it.
    let collected: Arc<Mutex<Vec<Vec<u8>>>> = Arc::new(Mutex::new(Vec::new()));
    let collected_for_pump = collected.clone();
    let (handler_tx, handler_rx) = mpsc::unbounded_channel::<Vec<u8>>();
    let mut raw_rx = raw_rx;
    // Pump: client-stream events -> decoded chunks for the handler.
    // NOTE(review): a client `Error` only ends the stream. The handler sees a
    // normal end of input and the client's error is not reported.
    tokio::spawn(async move {
        while let Some(msg) = raw_rx.recv().await {
            match msg {
                InflightMsg::Value(v, _) => {
                    let bytes = decode_bulk_chunk(&v);
                    collected_for_pump.lock().unwrap().push(bytes.clone());
                    if handler_tx.send(bytes).is_err() {
                        return;
                    }
                }
                InflightMsg::Done => return,
                InflightMsg::Error(_) => return,
            }
        }
    });
    let result = handler(request, ctx, handler_rx).await;
    remove_inflight(&inflight, &call_id);
    // NOTE(review): the pump task may still be running at this point. If the
    // handler returned before end-of-stream, a chunk already queued may or may
    // not be included in the digest below. Confirm handlers always drain
    // their receiver.
    let actual_hash = {
        let chunks = collected.lock().unwrap();
        sha256_of_chunks(&chunks)
    };
    match result {
        Ok(v) => {
            if actual_hash != expected_hash {
                transport.send(encode_rpc(RpcFrame::RpcResponse {
                    call_id,
                    status: ResponseStatus::Error,
                    response: None,
                    error: Some(RpcError {
                        code: RpcErrorCode::InvalidArgument,
                        message: format!(
                            "bulk-transfer hash mismatch: got {}, expected {}",
                            actual_hash, expected_hash
                        ),
                    }),
                    // Report the digest actually computed so the client can
                    // compare.
                    ext: Some(RpcFrameExt {
                        method_kind: Some(RpcMethodKind::BulkTransfer),
                        bulk: Some(RpcBulkExt {
                            expected_hash: Some(actual_hash),
                            ..Default::default()
                        }),
                        ..Default::default()
                    }),
                }));
                return;
            }
            transport.send(encode_rpc(RpcFrame::RpcResponse {
                call_id,
                status: ResponseStatus::Ok,
                response: Some(v),
                error: None,
                ext: Some(RpcFrameExt {
                    method_kind: Some(RpcMethodKind::BulkTransfer),
                    bulk: Some(RpcBulkExt {
                        expected_hash: Some(actual_hash),
                        ..Default::default()
                    }),
                    ..Default::default()
                }),
            }));
        }
        Err(err) => {
            transport.send(encode_rpc(RpcFrame::RpcResponse {
                call_id,
                status: ResponseStatus::Error,
                response: None,
                error: Some(err),
                ext: None,
            }));
        }
    }
}
2207
2208async fn dispatch_telemetry<T: RpcTransport + 'static>(
2209 transport: Arc<T>,
2210 inflight: Arc<Mutex<HashMap<String, InflightCall>>>,
2211 ctx: RpcContext,
2212 request: Value,
2213 handler: TelemetryHandler,
2214 priority: StreamingPriority,
2215) {
2216 let call_id = ctx.call_id.clone();
2217 let raw_rx = install_client_pipe(&inflight, &call_id);
2218 let value_rx = pipe_to_value_rx(raw_rx);
2219 let result = handler(request, ctx, priority.clone(), value_rx).await;
2220 remove_inflight(&inflight, &call_id);
2221 let ext = Some(RpcFrameExt {
2222 method_kind: Some(RpcMethodKind::Telemetry),
2223 streaming_priority: Some(priority),
2224 ..Default::default()
2225 });
2226 match result {
2227 Ok(()) => transport.send(encode_rpc(RpcFrame::RpcResponse {
2228 call_id,
2229 status: ResponseStatus::Ok,
2230 response: Some(Value::Null),
2231 error: None,
2232 ext,
2233 })),
2234 Err(err) => transport.send(encode_rpc(RpcFrame::RpcResponse {
2235 call_id,
2236 status: ResponseStatus::Error,
2237 response: None,
2238 error: Some(err),
2239 ext,
2240 })),
2241 }
2242}
2243
2244async fn dispatch_remote_shell<T: RpcTransport + 'static>(
2245 transport: Arc<T>,
2246 inflight: Arc<Mutex<HashMap<String, InflightCall>>>,
2247 ctx: RpcContext,
2248 request: Value,
2249 handler: RemoteShellHandler,
2250) {
2251 let call_id = ctx.call_id.clone();
2252 let raw_rx = install_client_pipe(&inflight, &call_id);
2253 let (stdin_tx, stdin_rx) = mpsc::unbounded_channel::<Vec<u8>>();
2256 let mut raw_rx = raw_rx;
2257 tokio::spawn(async move {
2258 while let Some(msg) = raw_rx.recv().await {
2259 match msg {
2260 InflightMsg::Value(v, ext) => {
2261 let tag = ext
2262 .as_ref()
2263 .and_then(|e| e.shell_stream.clone())
2264 .unwrap_or(RemoteShellStream::Stdin);
2265 if !matches!(tag, RemoteShellStream::Stdin) {
2266 continue;
2269 }
2270 let bytes = decode_bulk_chunk(&v);
2271 if stdin_tx.send(bytes).is_err() {
2272 return;
2273 }
2274 }
2275 InflightMsg::Done | InflightMsg::Error(_) => return,
2276 }
2277 }
2278 });
2279 let (out_tx, mut out_rx) = mpsc::unbounded_channel::<RemoteShellOut>();
2280 let fut = handler(request, ctx, stdin_rx, out_tx);
2281 tokio::spawn(fut);
2282 let mut seq: i64 = 0;
2283 while let Some(frame) = out_rx.recv().await {
2284 let encoded = B64.encode(&frame.data);
2285 transport.send(encode_rpc(RpcFrame::RpcStream {
2286 call_id: call_id.clone(),
2287 seq,
2288 more: true,
2289 value: Some(Value::String(encoded)),
2290 error: None,
2291 ext: Some(RpcFrameExt {
2292 method_kind: Some(RpcMethodKind::RemoteShell),
2293 shell_stream: Some(frame.stream),
2294 ..Default::default()
2295 }),
2296 }));
2297 seq += 1;
2298 }
2299 transport.send(encode_rpc(RpcFrame::RpcStream {
2300 call_id: call_id.clone(),
2301 seq,
2302 more: false,
2303 value: None,
2304 error: None,
2305 ext: Some(RpcFrameExt {
2306 method_kind: Some(RpcMethodKind::RemoteShell),
2307 ..Default::default()
2308 }),
2309 }));
2310 remove_inflight(&inflight, &call_id);
2311}
2312
2313async fn dispatch_agent_session<T: RpcTransport + 'static>(
2314 transport: Arc<T>,
2315 inflight: Arc<Mutex<HashMap<String, InflightCall>>>,
2316 ctx: RpcContext,
2317 request: Value,
2318 handler: AgentSessionHandler,
2319) {
2320 let call_id = ctx.call_id.clone();
2321 let initial_chain = ctx.initial_chain.clone();
2322 let raw_rx = install_client_pipe(&inflight, &call_id);
2323 let (frames_tx, frames_rx) = mpsc::unbounded_channel::<AgentSessionFrame>();
2324 let initial_for_pump = initial_chain.clone();
2325 let mut raw_rx = raw_rx;
2326 tokio::spawn(async move {
2327 while let Some(msg) = raw_rx.recv().await {
2328 match msg {
2329 InflightMsg::Value(v, ext) => {
2330 let chain = ext
2331 .as_ref()
2332 .and_then(|e| e.responsibility_chain.clone())
2333 .unwrap_or_else(|| initial_for_pump.clone());
2334 if frames_tx
2335 .send(AgentSessionFrame {
2336 value: v,
2337 responsibility_chain: chain,
2338 })
2339 .is_err()
2340 {
2341 return;
2342 }
2343 }
2344 InflightMsg::Done | InflightMsg::Error(_) => return,
2345 }
2346 }
2347 });
2348 let (out_tx, mut out_rx) = mpsc::unbounded_channel::<AgentSessionFrame>();
2349 let fut = handler(request, ctx, frames_rx, out_tx);
2350 tokio::spawn(fut);
2351 let mut seq: i64 = 0;
2352 while let Some(frame) = out_rx.recv().await {
2353 transport.send(encode_rpc(RpcFrame::RpcStream {
2354 call_id: call_id.clone(),
2355 seq,
2356 more: true,
2357 value: Some(frame.value),
2358 error: None,
2359 ext: Some(RpcFrameExt {
2360 method_kind: Some(RpcMethodKind::AgentSession),
2361 responsibility_chain: Some(frame.responsibility_chain),
2362 ..Default::default()
2363 }),
2364 }));
2365 seq += 1;
2366 }
2367 transport.send(encode_rpc(RpcFrame::RpcStream {
2368 call_id: call_id.clone(),
2369 seq,
2370 more: false,
2371 value: None,
2372 error: None,
2373 ext: Some(RpcFrameExt {
2374 method_kind: Some(RpcMethodKind::AgentSession),
2375 ..Default::default()
2376 }),
2377 }));
2378 remove_inflight(&inflight, &call_id);
2379}
2380
2381async fn dispatch_http_bridge<T: RpcTransport>(
2382 transport: Arc<T>,
2383 inflight: Arc<Mutex<HashMap<String, InflightCall>>>,
2384 ctx: RpcContext,
2385 request: Value,
2386 handler: HttpBridgeHandler,
2387) {
2388 let call_id = ctx.call_id.clone();
2389 let mut rx = install_client_pipe(&inflight, &call_id);
2390
2391 let (frames_tx, frames_rx) = mpsc::unbounded_channel::<HttpFrame>();
2392 let inflight_inner = inflight.clone();
2393 let call_id_inner = call_id.clone();
2394
2395 tokio::spawn(async move {
2396 while let Some(msg) = rx.recv().await {
2397 match msg {
2398 InflightMsg::Value(v, _) => {
2399 if let Ok(frame) = serde_json::from_value::<HttpFrame>(v) {
2400 let _ = frames_tx.send(frame);
2401 }
2402 }
2403 InflightMsg::Done | InflightMsg::Error(_) => {
2404 remove_inflight(&inflight_inner, &call_id_inner);
2405 return;
2406 }
2407 }
2408 }
2409 });
2410
2411 let (out_tx, mut out_rx) = mpsc::unbounded_channel::<HttpFrame>();
2412 let fut = handler(request, ctx, frames_rx, out_tx);
2413 tokio::spawn(fut);
2414
2415 let mut seq: i64 = 0;
2416 while let Some(frame) = out_rx.recv().await {
2417 transport.send(encode_rpc(RpcFrame::RpcStream {
2418 call_id: call_id.clone(),
2419 seq,
2420 more: true,
2421 value: Some(serde_json::to_value(frame).unwrap_or_default()),
2422 error: None,
2423 ext: Some(RpcFrameExt {
2424 method_kind: Some(RpcMethodKind::HttpBridge),
2425 ..Default::default()
2426 }),
2427 }));
2428 seq += 1;
2429 }
2430 transport.send(encode_rpc(RpcFrame::RpcStream {
2431 call_id: call_id.clone(),
2432 seq,
2433 more: false,
2434 value: None,
2435 error: None,
2436 ext: Some(RpcFrameExt {
2437 method_kind: Some(RpcMethodKind::HttpBridge),
2438 ..Default::default()
2439 }),
2440 }));
2441 remove_inflight(&inflight, &call_id);
2442}
2443
// Wire-format tests for the method-kind extensions: kebab-case enum names,
// `RpcFrameExt` round-tripping, and frame / proof-event serialization.
#[cfg(test)]
mod method_kind_tests {
    use super::*;

    // Every RpcMethodKind serializes to its kebab-case name and parses back.
    #[test]
    fn rpc_method_kind_serde_kebab_case() {
        let kinds = [
            RpcMethodKind::Unary,
            RpcMethodKind::ServerStreaming,
            RpcMethodKind::ClientStreaming,
            RpcMethodKind::BidiStreaming,
            RpcMethodKind::Subscribe,
            RpcMethodKind::CommandChannel,
            RpcMethodKind::BulkTransfer,
            RpcMethodKind::Telemetry,
            RpcMethodKind::RemoteShell,
            RpcMethodKind::AgentSession,
            RpcMethodKind::HttpBridge,
        ];
        let json = serde_json::to_string(&kinds).unwrap();
        assert!(json.contains("unary"));
        assert!(json.contains("server-streaming"));
        assert!(json.contains("client-streaming"));
        assert!(json.contains("bidi-streaming"));
        assert!(json.contains("subscribe"));
        assert!(json.contains("command-channel"));
        assert!(json.contains("bulk-transfer"));
        assert!(json.contains("telemetry"));
        assert!(json.contains("remote-shell"));
        assert!(json.contains("agent-session"));
        assert!(json.contains("http-bridge"));
        let parsed: Vec<RpcMethodKind> = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, kinds);
    }

    // An RpcFrameExt with (almost) every field set survives a JSON round trip.
    #[test]
    fn rpc_frame_ext_round_trip() {
        let ext = RpcFrameExt {
            method_kind: Some(RpcMethodKind::BulkTransfer),
            streaming_priority: Some(StreamingPriority::P1),
            subscribe_topic: None,
            credit: Some(8),
            bulk: Some(RpcBulkExt {
                chunk_index: Some(3),
                total_chunks: Some(4),
                expected_hash: Some("sha256:abcd".into()),
            }),
            shell_stream: Some(RemoteShellStream::Stderr),
            responsibility_chain: Some(vec!["tf:actor:human:example.com/alice".into()]),
            ack: Some("subscribed".into()),
        };
        let json = serde_json::to_string(&ext).unwrap();
        let parsed: RpcFrameExt = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ext);
    }

    // Pins the wire tag of client-stream frames and the ext field encodings.
    #[test]
    fn rpc_client_stream_frame_serializes_with_kebab_kind() {
        let frame = RpcFrame::RpcClientStream {
            call_id: "c1".into(),
            seq: 0,
            more: true,
            value: Some(serde_json::json!("payload")),
            error: None,
            ext: Some(RpcFrameExt {
                method_kind: Some(RpcMethodKind::Telemetry),
                streaming_priority: Some(StreamingPriority::P3),
                ..Default::default()
            }),
        };
        let json = serde_json::to_value(&frame).unwrap();
        assert_eq!(json["kind"], "rpc-client-stream");
        assert_eq!(json["ext"]["method_kind"], "telemetry");
        assert_eq!(json["ext"]["streaming_priority"], "P3");
    }

    // Proof events expose the method kind and bulk-hash verdict when set.
    #[test]
    fn proof_event_carries_method_kind_when_set() {
        let ev = RpcProofEventStub {
            kind: "rpc.call".into(),
            method: "blob.upload".into(),
            call_id: "c1".into(),
            caller: "tf:actor:agent:example.com/x".into(),
            result: "ok".into(),
            error_code: None,
            method_kind: Some(RpcMethodKind::BulkTransfer),
            streaming_priority: None,
            bulk_hash_verified: Some(true),
        };
        let json = serde_json::to_value(&ev).unwrap();
        assert_eq!(json["method_kind"], "bulk-transfer");
        assert_eq!(json["bulk_hash_verified"], true);
    }
}