1use modrpc::RoleSetup;
2
3use crate::{
4 proto::{
5 RequestClientConfig, RequestGen, RequestInitState, RequestLazy, Response, ResponseLazy,
6 },
7 request_tracker::{RequestTracker, get_request_tracker},
8};
9
10pub use sealed::ResponseWaiter;
11
12mod sealed {
13 use crate::{proto::Response, request_tracker::PendingRequestSubscription};
14
15 pub struct ResponseWaiter<Resp> {
16 pending_request: PendingRequestSubscription,
17 _phantom: core::marker::PhantomData<Resp>,
18 }
19
20 impl<Resp: for<'d> mproto::Decode<'d>> ResponseWaiter<Resp> {
21 pub async fn wait(self) -> mproto::DecodeResult<Resp> {
22 let packet_header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
23 let response_packet = self.pending_request.wait().await;
24
25 let response = mproto::decode_value(&response_packet.as_ref()[packet_header_len..])?;
26 let response: Response<Resp> = response;
27
28 Ok(response.payload)
29 }
30 }
31
32 pub fn wait_response_then_decode<Resp>(
33 pending_request: PendingRequestSubscription,
34 ) -> ResponseWaiter<Resp>
35 where
36 Resp: for<'d> mproto::Decode<'d>,
37 {
38 ResponseWaiter {
39 pending_request,
40 _phantom: core::marker::PhantomData,
41 }
42 }
43
44 }
64
65pub struct RequestClient<Req, Resp> {
66 name: &'static str,
67 rt: modrpc::RuntimeHandle,
68 worker_id: u16,
69 hooks: crate::RequestClientHooks<Req, Resp>,
70 tracker: RequestTracker,
71 spawner: modrpc::RoleSpawner,
72}
73
74impl<Req: mproto::Owned, Resp: mproto::Owned> RequestClient<Req, Resp> {
75 pub async fn call<LikeReq>(&self, payload: LikeReq) -> Resp
76 where
77 LikeReq: mproto::Compatible<Req>,
78 {
79 let pending_request = self.tracker.client_start_request();
80 self.hooks
81 .request
82 .send(RequestGen::<LikeReq> {
83 worker: self.worker_id,
84 request_id: pending_request.request_id(),
85 payload,
86 })
87 .await;
88
89 let response_buf = pending_request.await;
90 let header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
91 let response: Response<Resp> =
92 mproto::decode_value(&response_buf.as_ref()[header_len..]).unwrap();
93
94 response.payload
95 }
96
97 pub fn subscribe(
98 &self,
99 request_subscription: impl AsyncFnMut(modrpc::EndpointAddr, Req::Lazy<'_>, ResponseWaiter<Resp>)
100 + Clone
101 + 'static,
102 ) {
103 let local_worker_context = self
105 .rt
106 .local_worker_context()
107 .expect("modrpc::RequestClient::subscribe local worker context");
108 modrpc::add_topic_subscription(
109 local_worker_context,
110 "todo",
111 self.hooks.request.plane_id(),
112 self.hooks.request.topic(),
113 );
114
115 let spawner = self.spawner.clone();
116 self.tracker.subscribe(
117 self.hooks.request.plane_id(),
118 self.hooks.request.topic(),
119 Box::new(move |request_packet: modrpc::Packet, pending_request| {
120 let mut request_subscription = request_subscription.clone();
121 spawner.spawn(async move {
122 let header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
123
124 let Ok(header) = mproto::decode_value(&request_packet.as_ref()[..header_len])
125 else {
126 return;
127 };
128 let header: modrpc::TransmitPacket = header;
129
130 let Ok(request) = mproto::decode_value(&request_packet.as_ref()[header_len..])
131 else {
132 return;
133 };
134 let request: RequestLazy<Req> = request;
135
136 let Ok(request_payload) = request.payload() else {
137 return;
138 };
139
140 request_subscription(
141 header.source,
142 request_payload,
143 sealed::wait_response_then_decode(pending_request),
144 )
145 .await;
146 });
147 }),
148 );
149 }
150}
151
152pub struct RequestClientBuilder<Req, Resp> {
153 pub name: &'static str,
154 pub hooks: crate::RequestClientHooks<Req, Resp>,
155 pub stubs: crate::RequestClientStubs<Req, Resp>,
156 pub init: RequestInitState,
157}
158
159impl<Req: mproto::Owned, Resp: mproto::Owned> RequestClientBuilder<Req, Resp> {
160 pub fn new(
161 name: &'static str,
162 hooks: crate::RequestClientHooks<Req, Resp>,
163 stubs: crate::RequestClientStubs<Req, Resp>,
164 _config: &RequestClientConfig,
165 init: RequestInitState,
166 ) -> Self {
167 Self {
168 name,
169 hooks,
170 stubs,
171 init,
172 }
173 }
174
175 pub fn create_handle(&self, setup: &RoleSetup) -> RequestClient<Req, Resp> {
176 let worker_id = setup.worker_id();
177 let tracker = get_request_tracker(setup);
178
179 RequestClient {
180 name: self.name,
181 rt: setup.worker_context().rt().clone(),
182 worker_id,
183 hooks: self.hooks.clone(),
184 tracker,
185 spawner: setup.role_spawner().clone(),
186 }
187 }
188
189 pub fn build(self, setup: &RoleSetup) {
190 let local_addr = setup.endpoint_addr().endpoint;
191 let local_worker_id = setup.worker_id();
192 let plane_id = setup.plane_id();
193
194 let tracker = get_request_tracker(setup);
200 let request_topic = self.hooks.request.topic();
201 self.stubs
202 .request
203 .inline_untyped(setup, move |source, packet| {
204 tracker.handle_request::<Req>(
205 plane_id,
206 request_topic,
207 source.endpoint,
208 packet.clone(),
209 );
210 })
211 .local();
212
213 let tracker = get_request_tracker(setup);
214 self.stubs
215 .response
216 .clone()
217 .inline_untyped(setup, move |_source, packet| {
218 tracker.handle_response::<Resp>(
219 plane_id,
220 local_addr,
221 local_worker_id,
222 packet.clone(),
223 );
224 })
225 .local();
226
227 self.stubs
228 .response
229 .clone()
230 .route_to_worker(setup, move |_source, packet| {
231 let packet_header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
232 let Ok(response) = mproto::decode_value::<ResponseLazy<Resp>>(
233 &packet.as_ref()[packet_header_len..],
234 ) else {
235 return None;
236 };
237 let Ok(requester) = response.requester() else {
238 return None;
239 };
240 let Ok(requester_worker) = response.requester_worker() else {
241 return None;
242 };
243
244 probius::trace_label("route_to_worker");
245 probius::trace_branch(|| {
246 if requester == local_addr && requester_worker != local_worker_id {
247 probius::trace_metric("redirect", 1);
248 Some(modrpc::WorkerId(requester_worker))
251 } else {
252 probius::trace_metric("no-redirect", 1);
253 None
254 }
255 })
256 });
257 }
258}
259
260impl<Req, Resp> Clone for RequestClient<Req, Resp> {
261 fn clone(&self) -> Self {
262 Self {
263 name: self.name,
264 rt: self.rt.clone(),
265 worker_id: self.worker_id,
266 hooks: self.hooks.clone(),
267 tracker: self.tracker.clone(),
268 spawner: self.spawner.clone(),
269 }
270 }
271}