1use modrpc::RoleSetup;
2
3use crate::{
4 proto::{
5 RequestClientConfig,
6 RequestGen,
7 RequestLazy,
8 RequestInitState,
9 Response,
10 ResponseLazy,
11 },
12 request_tracker::{RequestTracker, get_request_tracker},
13};
14
15pub use sealed::ResponseWaiter;
16
17mod sealed {
18 use crate::{
19 proto::Response,
20 request_tracker::PendingRequestSubscription,
21 };
22
23 pub struct ResponseWaiter<Resp> {
24 pending_request: PendingRequestSubscription,
25 _phantom: core::marker::PhantomData<Resp>,
26 }
27
28 impl<Resp: for<'d> mproto::Decode<'d>> ResponseWaiter<Resp> {
29 pub async fn wait(self) -> mproto::DecodeResult<Resp> {
30 let packet_header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
31 let response_packet = self.pending_request.wait().await;
32
33 let response = mproto::decode_value(&response_packet.as_ref()[packet_header_len..])?;
34 let response: Response<Resp> = response;
35
36 Ok(response.payload)
37 }
38 }
39
40 pub fn wait_response_then_decode<Resp>(pending_request: PendingRequestSubscription)
41 -> ResponseWaiter<Resp>
42 where Resp: for<'d> mproto::Decode<'d>
43 {
44 ResponseWaiter {
45 pending_request,
46 _phantom: core::marker::PhantomData,
47 }
48 }
49
50 }
70
/// Client handle for a request/response interface.
///
/// Requests are sent through the `request` hook, and their responses are awaited through the
/// shared [`RequestTracker`]. Handles are produced by `RequestClientBuilder::create_handle`.
pub struct RequestClient<Req, Resp> {
    // Name supplied to the builder; carried along with the handle (copied by `Clone`).
    name: &'static str,
    // Runtime handle; `subscribe` uses it to obtain the local worker context.
    rt: modrpc::RuntimeHandle,
    // Worker this handle was created on; written into each outgoing request.
    worker_id: u16,
    // Typed hooks for the interface; `request` is used to send and to identify the topic.
    hooks: crate::RequestClientHooks<Req, Resp>,
    // Allocates pending requests (and their ids) and delivers incoming request packets
    // to subscribers.
    tracker: RequestTracker,
    // Spawner used to run each subscriber invocation as its own task.
    spawner: modrpc::RoleSpawner,
}
79
80impl<
81 Req: mproto::Owned,
82 Resp: mproto::Owned,
83> RequestClient<Req, Resp> {
84 pub async fn call<LikeReq>(&self, payload: LikeReq) -> Resp
85 where LikeReq: mproto::Compatible<Req>
86 {
87 let pending_request = self.tracker.client_start_request();
88 self.hooks.request.send(RequestGen::<LikeReq> {
89 worker: self.worker_id,
90 request_id: pending_request.request_id(),
91 payload,
92 })
93 .await;
94
95 let response_buf = pending_request.await;
96 let header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
97 let response: Response<Resp>
98 = mproto::decode_value(&response_buf.as_ref()[header_len..]).unwrap();
99
100 response.payload
101 }
102
103 pub fn subscribe(
104 &self,
105 request_subscription:
106 impl AsyncFnMut(
107 modrpc::EndpointAddr,
108 Req::Lazy<'_>,
109 ResponseWaiter<Resp>,
110 ) + Clone + 'static,
111 )
112 where
113 {
114 let local_worker_context = self.rt.local_worker_context()
116 .expect("modrpc::RequestClient::subscribe local worker context");
117 modrpc::add_topic_subscription(
118 local_worker_context,
119 "todo",
120 self.hooks.request.plane_id(),
121 self.hooks.request.topic(),
122 );
123
124 let spawner = self.spawner.clone();
125 self.tracker.subscribe(
126 self.hooks.request.plane_id(),
127 self.hooks.request.topic(),
128 Box::new(move |request_packet: modrpc::Packet, pending_request| {
129 let mut request_subscription = request_subscription.clone();
130 spawner.spawn(async move {
131 let header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
132
133 let Ok(header) =
134 mproto::decode_value(&request_packet.as_ref()[..header_len])
135 else { return; };
136 let header: modrpc::TransmitPacket = header;
137
138 let Ok(request) =
139 mproto::decode_value(&request_packet.as_ref()[header_len..])
140 else { return; };
141 let request: RequestLazy<Req> = request;
142
143 let Ok(request_payload) = request.payload() else { return; };
144
145 request_subscription(
146 header.source,
147 request_payload,
148 sealed::wait_response_then_decode(pending_request),
149 )
150 .await;
151 });
152 }),
153 );
154 }
155}
156
/// Builder for the client role of a request/response interface.
///
/// `create_handle` produces [`RequestClient`] handles, and `build` installs the role's
/// request and response packet handlers.
pub struct RequestClientBuilder<Req, Resp> {
    // Interface name, copied into every handle created by `create_handle`.
    pub name: &'static str,
    // Typed hooks for the interface (cloned into each handle).
    pub hooks: crate::RequestClientHooks<Req, Resp>,
    // Untyped request/response stubs that `build` attaches handlers to.
    pub stubs: crate::RequestClientStubs<Req, Resp>,
    // Initial state supplied to `new`.
    pub init: RequestInitState,
}
163
164impl<
165 Req: mproto::Owned,
166 Resp: mproto::Owned,
167> RequestClientBuilder<Req, Resp> {
168 pub fn new(
169 name: &'static str,
170 hooks: crate::RequestClientHooks<Req, Resp>,
171 stubs: crate::RequestClientStubs<Req, Resp>,
172 _config: &RequestClientConfig,
173 init: RequestInitState,
174 ) -> Self {
175 Self { name, hooks, stubs, init }
176 }
177
178 pub fn create_handle(&self, setup: &RoleSetup) -> RequestClient<Req, Resp> {
179 let worker_id = setup.worker_id();
180 let tracker = get_request_tracker(setup);
181
182 RequestClient {
183 name: self.name,
184 rt: setup.worker_context().rt().clone(),
185 worker_id,
186 hooks: self.hooks.clone(),
187 tracker,
188 spawner: setup.role_spawner().clone(),
189 }
190 }
191
192 pub fn build(
193 self,
194 setup: &RoleSetup,
195 ) {
196 let local_addr = setup.endpoint_addr().endpoint;
197 let local_worker_id = setup.worker_id();
198 let plane_id = setup.plane_id();
199
200 let tracker = get_request_tracker(setup);
206 let request_topic = self.hooks.request.topic();
207 self.stubs.request
208 .inline_untyped(setup, move |source, packet| {
209 tracker.handle_request::<Req>(plane_id, request_topic, source.endpoint, packet.clone());
210 })
211 .local();
212
213 let tracker = get_request_tracker(setup);
214 self.stubs.response.clone()
215 .inline_untyped(setup, move |_source, packet| {
216 tracker.handle_response::<Resp>(plane_id, local_addr, local_worker_id, packet.clone());
217 })
218 .local();
219
220 self.stubs.response.clone()
221 .route_to_worker(setup, move |_source, packet| {
222 let packet_header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
223 let Ok(response) =
224 mproto::decode_value::<ResponseLazy<Resp>>(&packet.as_ref()[packet_header_len..])
225 else {
226 return None;
227 };
228 let Ok(requester) = response.requester() else {
229 return None;
230 };
231 let Ok(requester_worker) = response.requester_worker() else {
232 return None;
233 };
234
235 probius::trace_label("route_to_worker");
236 probius::trace_branch(|| {
237 if requester == local_addr && requester_worker != local_worker_id {
238 probius::trace_metric("redirect", 1);
239 Some(modrpc::WorkerId(requester_worker))
242 } else {
243 probius::trace_metric("no-redirect", 1);
244 None
245 }
246 })
247 });
248 }
249}
250
251impl<Req, Resp> Clone for RequestClient<Req, Resp> {
252 fn clone(&self) -> Self {
253 Self {
254 name: self.name,
255 rt: self.rt.clone(),
256 worker_id: self.worker_id,
257 hooks: self.hooks.clone(),
258 tracker: self.tracker.clone(),
259 spawner: self.spawner.clone(),
260 }
261 }
262}
263