1use std::{collections::HashMap, sync::Arc};
20
21use async_trait::async_trait;
22use serde_json::Value;
23use symphony::models::{ComponentResultSpec, ComponentSpec, DeploymentSpec};
24use tracing::{debug, error, trace, warn, Level};
25
26use crate::{
27 communication::{RequestHandler, RpcServer, ServiceInvocationError, UPayload},
28 UAttributes, UPayloadFormat,
29};
30
/// Resource ID of the RPC method that queries a [`DeploymentTarget`].
///
/// [`register_target_provider_endpoints`] binds this ID to the Get handler.
pub const METHOD_GET_RESOURCE_ID: u16 = 0x0001;
/// Resource ID of the RPC method that asks a [`DeploymentTarget`] to update components.
///
/// [`register_target_provider_endpoints`] binds this ID to the shared Update/Delete handler.
pub const METHOD_UPDATE_RESOURCE_ID: u16 = 0x0002;
/// Resource ID of the RPC method that asks a [`DeploymentTarget`] to delete components.
///
/// [`register_target_provider_endpoints`] binds this ID to the shared Update/Delete handler.
pub const METHOD_DELETE_RESOURCE_ID: u16 = 0x0003;
34
35pub async fn register_target_provider_endpoints<R: RpcServer, T: DeploymentTarget + 'static>(
50 rpc_server: &R,
51 deployment_target: Arc<T>,
52) -> Result<(), Box<dyn std::error::Error>> {
53 let get_op = Arc::new(GetOperation {
54 target: deployment_target.clone(),
55 });
56 let apply_op = Arc::new(ApplyOperation {
57 target: deployment_target,
58 });
59 rpc_server
60 .register_endpoint(None, METHOD_GET_RESOURCE_ID, get_op)
61 .await
62 .inspect_err(|e| error!("failed to register Get operation on RPC Server: {e}"))?;
63 rpc_server
64 .register_endpoint(None, METHOD_UPDATE_RESOURCE_ID, apply_op.clone())
65 .await
66 .inspect_err(|e| error!("failed to register Update operation on RPC Server: {e}"))?;
67 rpc_server
68 .register_endpoint(None, METHOD_DELETE_RESOURCE_ID, apply_op)
69 .await
70 .inspect_err(|e| error!("failed to register Delete operation on RPC Server: {e}"))?;
71 Ok(())
72}
73
/// The operations a deployment target offers to remote callers.
///
/// [`register_target_provider_endpoints`] exposes an implementation of this trait through
/// RPC endpoints. The same `Arc`-wrapped instance is shared by all registered request
/// handlers, which is why implementations must be `Send + Sync`.
///
/// When built with `test` or the `test-util` feature, `mockall` generates a
/// `MockDeploymentTarget` for this trait.
#[cfg_attr(any(test, feature = "test-util"), mockall::automock)]
#[async_trait]
pub trait DeploymentTarget: Send + Sync {
    /// Handles a request to the [`METHOD_GET_RESOURCE_ID`] endpoint.
    ///
    /// `components` and `deployment_spec` are deserialized from the `components` and
    /// `deployment` members of the request's JSON payload. The returned components are
    /// serialized as JSON and sent back to the caller. NOTE(review): presumably these
    /// describe the components' current state (the caller logs errors as "error getting
    /// component status") — confirm.
    ///
    /// # Errors
    ///
    /// Any error is logged and reported to the caller as an internal error
    /// ("failed to get component status").
    async fn get(
        &self,
        components: Vec<ComponentSpec>,
        deployment_spec: DeploymentSpec,
    ) -> Result<Vec<ComponentSpec>, Box<dyn std::error::Error>>;

    /// Handles a request to the [`METHOD_UPDATE_RESOURCE_ID`] endpoint.
    ///
    /// Arguments are taken from the request's JSON payload as for [`DeploymentTarget::get`].
    /// The returned map is serialized as JSON and sent back to the caller. NOTE(review):
    /// presumably keyed by component name — confirm with implementations.
    ///
    /// # Errors
    ///
    /// Any error is logged and reported to the caller as an internal error
    /// ("failed to update components").
    async fn update(
        &self,
        components_to_update: Vec<ComponentSpec>,
        deployment_spec: DeploymentSpec,
    ) -> Result<HashMap<String, ComponentResultSpec>, Box<dyn std::error::Error>>;

    /// Handles a request to the [`METHOD_DELETE_RESOURCE_ID`] endpoint.
    ///
    /// Arguments are taken from the request's JSON payload as for [`DeploymentTarget::get`].
    /// The returned map is serialized as JSON and sent back to the caller. NOTE(review):
    /// presumably keyed by component name — confirm with implementations.
    ///
    /// # Errors
    ///
    /// Any error is logged and reported to the caller as an internal error
    /// ("failed to delete components").
    async fn delete(
        &self,
        components_to_delete: Vec<ComponentSpec>,
        deployment_spec: DeploymentSpec,
    ) -> Result<HashMap<String, ComponentResultSpec>, Box<dyn std::error::Error>>;
}
132
133fn extract_request_data(
134 request_payload: Option<UPayload>,
135) -> Result<Value, ServiceInvocationError> {
136 let Some(req_payload) = request_payload
137 .filter(|req_payload| req_payload.payload_format() == UPayloadFormat::UPAYLOAD_FORMAT_JSON)
138 else {
139 return Err(ServiceInvocationError::InvalidArgument(
140 "request has no JSON payload".to_string(),
141 ));
142 };
143
144 serde_json::from_slice(req_payload.payload().to_vec().as_slice()).map_err(|err| {
145 debug!("failed to deserialize request payload: {:?}", err);
146 ServiceInvocationError::InvalidArgument(
147 "request payload is not a valid UTF-8 string".to_string(),
148 )
149 })
150}
151
/// Request handler for the [`METHOD_GET_RESOURCE_ID`] endpoint.
///
/// It forwards requests to [`DeploymentTarget::get`] on `target`.
///
/// The `DeploymentTarget` bound is declared on the `RequestHandler` impl, which is the
/// only place that needs it, rather than on the struct definition.
struct GetOperation<T> {
    target: Arc<T>,
}
155
156#[async_trait::async_trait]
157impl<T: DeploymentTarget> RequestHandler for GetOperation<T> {
158 async fn handle_request(
160 &self,
161 _resource_id: u16,
162 message_attributes: &UAttributes,
163 request_payload: Option<UPayload>,
164 ) -> Result<Option<UPayload>, ServiceInvocationError> {
165 let source_uri = message_attributes.source_unchecked().to_uri(true);
166 if tracing::enabled!(Level::DEBUG) {
167 debug!(source = source_uri, "processing GET request");
168 }
169 let request_data = extract_request_data(request_payload)?;
170 if tracing::enabled!(Level::TRACE) {
171 trace!(
172 source = source_uri,
173 "payload: {}",
174 serde_json::to_string_pretty(&request_data).expect("failed to serialize Value")
175 );
176 }
177 let deployment_spec: DeploymentSpec =
178 serde_json::from_value(request_data["deployment"].clone()).map_err(|err| {
179 debug!(
180 source = source_uri,
181 "request does not contain DeploymentSpec: {err}"
182 );
183 ServiceInvocationError::InvalidArgument(
184 "request does not contain DeploymentSpec".to_string(),
185 )
186 })?;
187 let component_specs: Vec<ComponentSpec> =
188 serde_json::from_value(request_data["components"].clone()).map_err(|err| {
189 debug!(
190 source = source_uri,
191 "request does not contain ComponentSpec array: {err}"
192 );
193 ServiceInvocationError::InvalidArgument(
194 "request does not contain ComponentSpec array".to_string(),
195 )
196 })?;
197
198 let result = self
199 .target
200 .get(component_specs, deployment_spec)
201 .await
202 .map_err(|err| {
203 warn!(source = source_uri, "error getting component status: {err}");
204 ServiceInvocationError::Internal("failed to get component status".to_string())
205 })?;
206 let serialized_response_data = serde_json::to_vec(&result).map_err(|err| {
207 warn!(
208 source = source_uri,
209 "error serializing ComponentSpec: {err}"
210 );
211 ServiceInvocationError::Internal("failed to create response payload".to_string())
212 })?;
213 if tracing::enabled!(Level::TRACE) {
214 trace!(
215 source = source_uri,
216 "returning response: {}",
217 serde_json::to_string_pretty(&result).expect("failed to serialize Value")
218 );
219 }
220 let response_payload = UPayload::new(
221 serialized_response_data,
222 UPayloadFormat::UPAYLOAD_FORMAT_JSON,
223 );
224 Ok(Some(response_payload))
225 }
226}
227
/// Request handler shared by the [`METHOD_UPDATE_RESOURCE_ID`] and
/// [`METHOD_DELETE_RESOURCE_ID`] endpoints.
///
/// It forwards requests to [`DeploymentTarget::update`] or [`DeploymentTarget::delete`]
/// on `target`, depending on the resource ID it is invoked with.
///
/// The `DeploymentTarget` bound is declared on the `RequestHandler` impl, which is the
/// only place that needs it, rather than on the struct definition.
struct ApplyOperation<T> {
    target: Arc<T>,
}
231
232#[async_trait::async_trait]
233impl<T: DeploymentTarget> RequestHandler for ApplyOperation<T> {
234 async fn handle_request(
235 &self,
236 resource_id: u16,
237 message_attributes: &UAttributes,
238 request_payload: Option<UPayload>,
239 ) -> Result<Option<UPayload>, ServiceInvocationError> {
240 let source_uri = message_attributes.source_unchecked().to_uri(true);
241 let sink_uri = message_attributes.sink_unchecked().to_uri(true);
242 if tracing::enabled!(Level::DEBUG) {
243 debug!(source = source_uri, method = sink_uri, "processing request",);
244 }
245 let request_data = extract_request_data(request_payload)?;
246 if tracing::enabled!(Level::TRACE) {
247 let json =
248 serde_json::to_string_pretty(&request_data).expect("failed to serialize Value");
249 trace!("payload: {}", json);
250 }
251
252 let deployment_spec: DeploymentSpec =
253 serde_json::from_value(request_data["deployment"].clone()).map_err(|err| {
254 debug!(
255 source = source_uri,
256 method = sink_uri,
257 "request does not contain DeploymentSpec: {err}"
258 );
259 ServiceInvocationError::InvalidArgument(
260 "request does not contain DeploymentSpec".to_string(),
261 )
262 })?;
263
264 let affected_components: Vec<ComponentSpec> =
265 serde_json::from_value(request_data["components"].clone()).map_err(|err| {
266 debug!(
267 source = source_uri,
268 method = sink_uri,
269 "request does not contain ComponentSpec array: {err}"
270 );
271 ServiceInvocationError::InvalidArgument(
272 "request does not contain ComponentSpec array".to_string(),
273 )
274 })?;
275
276 let result = match resource_id {
277 METHOD_UPDATE_RESOURCE_ID => self
278 .target
279 .update(affected_components, deployment_spec)
280 .await
281 .map_err(|err| {
282 warn!(
283 source = source_uri,
284 method = sink_uri,
285 "error updating components: {err}"
286 );
287 ServiceInvocationError::Internal("failed to update components".to_string())
288 }),
289 METHOD_DELETE_RESOURCE_ID => self
290 .target
291 .delete(affected_components, deployment_spec)
292 .await
293 .map_err(|err| {
294 warn!(
295 source = source_uri,
296 method = sink_uri,
297 "error deleting components: {err}"
298 );
299 ServiceInvocationError::Internal("failed to delete components".to_string())
300 }),
301 _ => {
302 return Err(ServiceInvocationError::Unimplemented(
303 "no such operation".to_string(),
304 ));
305 }
306 }?;
307
308 let serialized_response_data = serde_json::to_vec(&result).map_err(|err| {
309 warn!(
310 source = source_uri,
311 method = sink_uri,
312 "error serializing HashMap: {err}"
313 );
314 ServiceInvocationError::Internal("failed to create response payload".to_string())
315 })?;
316
317 let response_payload = UPayload::new(
318 serialized_response_data,
319 UPayloadFormat::UPAYLOAD_FORMAT_JSON,
320 );
321 Ok(Some(response_payload))
322 }
323}
324
// Unit tests for endpoint registration and for delegation of RPC requests to the
// deployment target.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use serde_json::json;
    use tokio::sync::Notify;

    use crate::{
        communication::{
            CallOptions, InMemoryRpcClient, InMemoryRpcServer, MockRpcServerImpl, RpcClient,
        },
        local_transport::LocalTransport,
        StaticUriProvider, UUri,
    };

    use super::*;

    // Registration must report an error when the RPC server rejects an endpoint.
    #[tokio::test]
    async fn test_register_target_provider_endpoints_fails() {
        let mut rpc_server = MockRpcServerImpl::new();
        // Every registration attempt on the mock server fails.
        rpc_server
            .expect_do_register_endpoint()
            .returning(|_, _, _| {
                Err(crate::communication::RegistrationError::MaxListenersExceeded)
            });
        // No expectations are set: the target must not be called during registration.
        let deployment_target = MockDeploymentTarget::new();

        assert!(
            register_target_provider_endpoints(&rpc_server, Arc::new(deployment_target))
                .await
                .is_err()
        );
    }

    // Invoking the Get, Update and Delete methods over an in-memory transport must reach
    // the corresponding DeploymentTarget method.
    #[tokio::test]
    async fn test_endpoints_delegate_to_deployment_target() {
        let transport = Arc::new(LocalTransport::default());

        // The three method URIs share authority, entity and version and differ only in
        // their resource ID.
        let get_method =
            UUri::try_from_parts("local_authority", 0xAAA1, 0x01, METHOD_GET_RESOURCE_ID)
                .expect("failed to create get method URI");
        let update_method =
            UUri::try_from_parts("local_authority", 0xAAA1, 0x01, METHOD_UPDATE_RESOURCE_ID)
                .expect("failed to create update method URI");
        let delete_method =
            UUri::try_from_parts("local_authority", 0xAAA1, 0x01, METHOD_DELETE_RESOURCE_ID)
                .expect("failed to create delete method URI");
        // NOTE(review): the server's URI provider is derived from the Get method URI,
        // presumably to take its authority/entity/version — confirm.
        let uri_provider =
            StaticUriProvider::try_from(&get_method).expect("failed to create URI provider");
        let rpc_server = InMemoryRpcServer::new(transport.clone(), Arc::new(uri_provider));

        // Each mocked target method signals its own Notify so the test can check that it
        // was called.
        let mut mock_target = MockDeploymentTarget::default();
        let get_notify = Arc::new(Notify::new());
        let cloned_get_notify = get_notify.clone();
        mock_target.expect_get().returning(move |_, _| {
            cloned_get_notify.notify_one();
            Ok(vec![])
        });
        let update_notify = Arc::new(Notify::new());
        let cloned_update_notify = update_notify.clone();
        mock_target.expect_update().returning(move |_, _| {
            cloned_update_notify.notify_one();
            Ok(HashMap::new())
        });
        let delete_notify = Arc::new(Notify::new());
        let cloned_delete_notify = delete_notify.clone();
        mock_target.expect_delete().returning(move |_, _| {
            cloned_delete_notify.notify_one();
            Ok(HashMap::new())
        });
        register_target_provider_endpoints(&rpc_server, Arc::new(mock_target))
            .await
            .expect("failed to register endpoints");

        // The client uses a different entity ID (0xAAA2) from the server on the same
        // transport.
        let rpc_client = InMemoryRpcClient::new(
            transport.clone(),
            Arc::new(StaticUriProvider::new("local_authority", 0xAAA2, 0x01)),
        )
        .await
        .expect("failed to create RPC client");

        // Minimal valid request: an empty deployment and no components. The same payload
        // is sent to all three methods.
        let request_payload = json!({
            "deployment": DeploymentSpec::empty(),
            "components": []
        });
        let payload = UPayload::new(
            serde_json::to_vec(&request_payload).expect("failed to create request payload"),
            UPayloadFormat::UPAYLOAD_FORMAT_JSON,
        );
        // NOTE(review): 0x1000 is presumably the request TTL in milliseconds — confirm.
        let call_options = CallOptions::for_rpc_request(0x1000, None, None, None);
        rpc_client
            .invoke_method(get_method, call_options.clone(), Some(payload.clone()))
            .await
            .expect("Get invocation failed");
        rpc_client
            .invoke_method(update_method, call_options.clone(), Some(payload.clone()))
            .await
            .expect("Update invocation failed");
        rpc_client
            .invoke_method(delete_method, call_options, Some(payload))
            .await
            .expect("Delete invocation failed");

        // All three invocations have already returned when these waits start. The check
        // therefore relies on `Notify::notify_one` storing a permit that a later
        // `notified()` consumes (per tokio's `Notify` docs). Each wait is bounded by a
        // 2 second timeout.
        tokio::try_join!(
            tokio::time::timeout(Duration::from_secs(2), get_notify.notified()),
            tokio::time::timeout(Duration::from_secs(2), update_notify.notified()),
            tokio::time::timeout(Duration::from_secs(2), delete_notify.notified()),
        )
        .expect("failed to receive notification from deployment target");
    }
}
435}