1use crate::{
8 data_converters::{PayloadCodec, SerializationContextData},
9 protos::temporal::api::common::v1::{Payload, Payloads},
10};
11use futures::future::BoxFuture;
12
13pub struct PayloadField<'a> {
16 pub path: &'static str,
19 pub data: PayloadFieldData<'a>,
21}
22
23pub enum PayloadFieldData<'a> {
25 Single(&'a mut Payload),
27 Repeated(&'a mut Vec<Payload>),
29 Payloads(&'a mut Payloads),
31}
32
33pub trait AsyncPayloadVisitor {
35 fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()>;
37}
38
39pub trait PayloadVisitable: Send {
42 fn visit_payloads_mut<'a>(
45 &'a mut self,
46 visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
47 ) -> BoxFuture<'a, ()>;
48}
49
/// Returns `true` when `path` contains the search-attributes field marker
/// `"SearchAttributes.indexed_fields"` anywhere within it.
fn is_search_attributes_path(path: &str) -> bool {
    const SEARCH_ATTRIBUTES_MARKER: &str = "SearchAttributes.indexed_fields";
    path.contains(SEARCH_ATTRIBUTES_MARKER)
}
56
57fn should_encode(path: &str) -> bool {
58 !is_search_attributes_path(path)
59}
60
61pub struct EncodeVisitor<'a> {
63 codec: &'a (dyn PayloadCodec + Send + Sync),
64 context: &'a SerializationContextData,
65}
66
67impl AsyncPayloadVisitor for EncodeVisitor<'_> {
68 fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
69 Box::pin(async move {
70 if !should_encode(field.path) {
71 return;
72 }
73 match field.data {
74 PayloadFieldData::Single(payload) => {
75 let encoded = self
76 .codec
77 .encode(self.context, vec![std::mem::take(payload)])
78 .await;
79 if let Some(p) = encoded.into_iter().next() {
80 *payload = p;
81 }
82 }
83 PayloadFieldData::Repeated(payloads) => {
84 *payloads = self
85 .codec
86 .encode(self.context, std::mem::take(payloads))
87 .await;
88 }
89 PayloadFieldData::Payloads(payloads_msg) => {
90 payloads_msg.payloads = self
91 .codec
92 .encode(self.context, std::mem::take(&mut payloads_msg.payloads))
93 .await;
94 }
95 }
96 })
97 }
98}
99
100pub struct DecodeVisitor<'a> {
102 codec: &'a (dyn PayloadCodec + Send + Sync),
103 context: &'a SerializationContextData,
104}
105
106impl AsyncPayloadVisitor for DecodeVisitor<'_> {
107 fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
108 Box::pin(async move {
109 match field.data {
110 PayloadFieldData::Single(payload) => {
111 let decoded = self
112 .codec
113 .decode(self.context, vec![std::mem::take(payload)])
114 .await;
115 if let Some(p) = decoded.into_iter().next() {
116 *payload = p;
117 }
118 }
119 PayloadFieldData::Repeated(payloads) => {
120 *payloads = self
121 .codec
122 .decode(self.context, std::mem::take(payloads))
123 .await;
124 }
125 PayloadFieldData::Payloads(payloads_msg) => {
126 payloads_msg.payloads = self
127 .codec
128 .decode(self.context, std::mem::take(&mut payloads_msg.payloads))
129 .await;
130 }
131 }
132 })
133 }
134}
135
136pub async fn encode_payloads<M: PayloadVisitable + Send>(
138 msg: &mut M,
139 codec: &(dyn PayloadCodec + Send + Sync),
140 context: &SerializationContextData,
141) {
142 let mut visitor = EncodeVisitor { codec, context };
143 msg.visit_payloads_mut(&mut visitor).await;
144}
145
146pub async fn decode_payloads<M: PayloadVisitable + Send>(
148 msg: &mut M,
149 codec: &(dyn PayloadCodec + Send + Sync),
150 context: &SerializationContextData,
151) {
152 let mut visitor = DecodeVisitor { codec, context };
153 msg.visit_payloads_mut(&mut visitor).await;
154}
155
156impl PayloadVisitable for Payload {
158 fn visit_payloads_mut<'a>(
159 &'a mut self,
160 visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
161 ) -> BoxFuture<'a, ()> {
162 Box::pin(async move {
163 visitor
164 .visit(PayloadField {
165 path: "temporal.api.common.v1.Payload",
166 data: PayloadFieldData::Single(self),
167 })
168 .await;
169 })
170 }
171}
172
173impl PayloadVisitable for Payloads {
175 fn visit_payloads_mut<'a>(
176 &'a mut self,
177 visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
178 ) -> BoxFuture<'a, ()> {
179 Box::pin(async move {
180 visitor
181 .visit(PayloadField {
182 path: "temporal.api.common.v1.Payloads",
183 data: PayloadFieldData::Payloads(self),
184 })
185 .await;
186 })
187 }
188}
189
// Textually includes `payload_visitor_impl.rs` from the build's `OUT_DIR`.
// NOTE(review): presumably emitted by the crate's build script and holding the
// `PayloadVisitable` impls for the remaining proto messages (the tests below
// call `visit_payloads_mut` on types with no impl in this file) — confirm
// against build.rs.
include!(concat!(env!("OUT_DIR"), "/payload_visitor_impl.rs"));
192
193#[cfg(test)]
194mod tests {
195 use super::*;
196 use crate::protos::{
197 coresdk::{
198 activity_result::{
199 ActivityResolution, Success, activity_resolution::Status as ActivityStatus,
200 },
201 workflow_activation::{
202 InitializeWorkflow, ResolveActivity, WorkflowActivation, WorkflowActivationJob,
203 workflow_activation_job::Variant,
204 },
205 workflow_commands::{
206 ContinueAsNewWorkflowExecution, ScheduleActivity, StartChildWorkflowExecution,
207 UpsertWorkflowSearchAttributes, WorkflowCommand,
208 workflow_command::Variant as CmdVariant,
209 },
210 workflow_completion::{
211 WorkflowActivationCompletion, workflow_activation_completion::Status,
212 },
213 },
214 temporal::api::common::v1::{Memo, SearchAttributes},
215 };
216 use futures::FutureExt;
217 use std::collections::HashMap;
218
219 struct MarkingCodec;
220 impl PayloadCodec for MarkingCodec {
221 fn encode(
222 &self,
223 _: &SerializationContextData,
224 payloads: Vec<Payload>,
225 ) -> BoxFuture<'static, Vec<Payload>> {
226 async move {
227 payloads
228 .into_iter()
229 .map(|mut p| {
230 p.metadata.insert("encoded".to_string(), b"true".to_vec());
231 p
232 })
233 .collect()
234 }
235 .boxed()
236 }
237
238 fn decode(
239 &self,
240 _: &SerializationContextData,
241 payloads: Vec<Payload>,
242 ) -> BoxFuture<'static, Vec<Payload>> {
243 async move {
244 payloads
245 .into_iter()
246 .map(|mut p| {
247 p.metadata.insert("decoded".to_string(), b"true".to_vec());
248 p
249 })
250 .collect()
251 }
252 .boxed()
253 }
254 }
255
256 struct PathRecordingVisitor {
257 visited_paths: Vec<String>,
258 }
259 impl PathRecordingVisitor {
260 fn new() -> Self {
261 Self {
262 visited_paths: Vec::new(),
263 }
264 }
265
266 fn paths(&self) -> Vec<String> {
267 self.visited_paths.clone()
268 }
269 }
270
271 impl AsyncPayloadVisitor for PathRecordingVisitor {
272 fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
273 let path = field.path.to_string();
274 self.visited_paths.push(path);
275 async move {}.boxed()
276 }
277 }
278
279 fn make_payload(data: &str) -> Payload {
280 Payload {
281 metadata: HashMap::new(),
282 data: data.as_bytes().to_vec(),
283 external_payloads: vec![],
284 }
285 }
286
287 fn is_encoded(p: &Payload) -> bool {
288 p.metadata.contains_key("encoded")
289 }
290
291 fn is_decoded(p: &Payload) -> bool {
292 p.metadata.contains_key("decoded")
293 }
294
295 #[tokio::test]
296 async fn test_direct_visitor_records_paths() {
297 let mut activation = WorkflowActivation {
298 run_id: "test-run".to_string(),
299 jobs: vec![WorkflowActivationJob {
300 variant: Some(Variant::InitializeWorkflow(InitializeWorkflow {
301 workflow_type: "test-workflow".to_string(),
302 arguments: vec![make_payload("input1")],
303 headers: {
304 let mut h = HashMap::new();
305 h.insert("header-key".to_string(), make_payload("header-value"));
306 h
307 },
308 memo: Some(Memo {
309 fields: {
310 let mut m = HashMap::new();
311 m.insert("memo-key".to_string(), make_payload("memo-value"));
312 m
313 },
314 }),
315 ..Default::default()
316 })),
317 }],
318 ..Default::default()
319 };
320
321 let mut visitor = PathRecordingVisitor::new();
322 activation.visit_payloads_mut(&mut visitor).await;
323
324 let paths = visitor.paths();
325 assert!(
326 paths
327 .iter()
328 .any(|p| p.contains("InitializeWorkflow.arguments")),
329 "should visit arguments, got: {:?}",
330 paths
331 );
332 assert!(
333 paths
334 .iter()
335 .any(|p| p.contains("InitializeWorkflow.headers")),
336 "should visit headers, got: {:?}",
337 paths
338 );
339 assert!(
340 paths.iter().any(|p| p.contains("Memo.fields")),
341 "should visit memo fields, got: {:?}",
342 paths
343 );
344 }
345
346 #[tokio::test]
347 async fn test_encode_workflow_activation_completion_with_schedule_activity() {
348 let mut completion = WorkflowActivationCompletion {
349 run_id: "test-run".to_string(),
350 status: Some(Status::Successful(
351 crate::protos::coresdk::workflow_completion::Success {
352 commands: vec![WorkflowCommand {
353 variant: Some(CmdVariant::ScheduleActivity(ScheduleActivity {
354 activity_id: "act-1".to_string(),
355 activity_type: "test-activity".to_string(),
356 arguments: vec![make_payload("arg1"), make_payload("arg2")],
357 headers: {
358 let mut h = HashMap::new();
359 h.insert("header-key".to_string(), make_payload("header-value"));
360 h
361 },
362 ..Default::default()
363 })),
364 user_metadata: None,
365 }],
366 ..Default::default()
367 },
368 )),
369 };
370
371 encode_payloads(
372 &mut completion,
373 &MarkingCodec,
374 &SerializationContextData::Workflow,
375 )
376 .await;
377
378 let status = completion.status.as_ref().unwrap();
379 let Status::Successful(success) = status else {
380 panic!("Expected successful status")
381 };
382 let cmd = &success.commands[0];
383 let CmdVariant::ScheduleActivity(schedule) = cmd.variant.as_ref().unwrap() else {
384 panic!("Expected ScheduleActivity")
385 };
386
387 assert!(is_encoded(&schedule.arguments[0]), "arg1 should be encoded");
388 assert!(is_encoded(&schedule.arguments[1]), "arg2 should be encoded");
389 assert!(
390 is_encoded(schedule.headers.get("header-key").unwrap()),
391 "header should be encoded"
392 );
393 }
394
395 #[tokio::test]
396 async fn test_decode_workflow_activation_with_initialize() {
397 let mut activation = WorkflowActivation {
398 run_id: "test-run".to_string(),
399 jobs: vec![WorkflowActivationJob {
400 variant: Some(Variant::InitializeWorkflow(InitializeWorkflow {
401 workflow_type: "test-workflow".to_string(),
402 arguments: vec![make_payload("input1"), make_payload("input2")],
403 headers: {
404 let mut h = HashMap::new();
405 h.insert("header-key".to_string(), make_payload("header-value"));
406 h
407 },
408 ..Default::default()
409 })),
410 }],
411 ..Default::default()
412 };
413
414 decode_payloads(
415 &mut activation,
416 &MarkingCodec,
417 &SerializationContextData::Workflow,
418 )
419 .await;
420
421 let job = &activation.jobs[0];
422 let Variant::InitializeWorkflow(init) = job.variant.as_ref().unwrap() else {
423 panic!("Expected InitializeWorkflow")
424 };
425
426 assert!(is_decoded(&init.arguments[0]), "arg1 should be decoded");
427 assert!(is_decoded(&init.arguments[1]), "arg2 should be decoded");
428 assert!(
429 is_decoded(init.headers.get("header-key").unwrap()),
430 "header should be decoded"
431 );
432 }
433
434 #[tokio::test]
435 async fn test_decode_workflow_activation_with_resolve_activity() {
436 let mut activation = WorkflowActivation {
437 run_id: "test-run".to_string(),
438 jobs: vec![WorkflowActivationJob {
439 variant: Some(Variant::ResolveActivity(ResolveActivity {
440 seq: 1,
441 result: Some(ActivityResolution {
442 status: Some(ActivityStatus::Completed(Success {
443 result: Some(make_payload("activity-result")),
444 })),
445 }),
446 ..Default::default()
447 })),
448 }],
449 ..Default::default()
450 };
451
452 decode_payloads(
453 &mut activation,
454 &MarkingCodec,
455 &SerializationContextData::Workflow,
456 )
457 .await;
458
459 let job = &activation.jobs[0];
460 let Variant::ResolveActivity(resolve) = job.variant.as_ref().unwrap() else {
461 panic!("Expected ResolveActivity")
462 };
463 let ActivityStatus::Completed(success) =
464 resolve.result.as_ref().unwrap().status.as_ref().unwrap()
465 else {
466 panic!("Expected Completed status")
467 };
468
469 assert!(
470 is_decoded(success.result.as_ref().unwrap()),
471 "activity result should be decoded"
472 );
473 }
474
475 #[tokio::test]
476 async fn test_search_attributes_skipped_on_encode() {
477 let mut completion = WorkflowActivationCompletion {
479 run_id: "test-run".to_string(),
480 status: Some(Status::Successful(
481 crate::protos::coresdk::workflow_completion::Success {
482 commands: vec![
483 WorkflowCommand {
485 variant: Some(CmdVariant::UpsertWorkflowSearchAttributes(
486 UpsertWorkflowSearchAttributes {
487 search_attributes: Some(SearchAttributes {
488 indexed_fields: {
489 let mut sa = HashMap::new();
490 sa.insert(
491 "CustomField".to_string(),
492 make_payload("search-value"),
493 );
494 sa
495 },
496 }),
497 },
498 )),
499 user_metadata: None,
500 },
501 WorkflowCommand {
503 variant: Some(CmdVariant::ContinueAsNewWorkflowExecution(
504 ContinueAsNewWorkflowExecution {
505 arguments: vec![make_payload("continue-arg")],
506 search_attributes: Some(SearchAttributes {
507 indexed_fields: {
508 let mut sa = HashMap::new();
509 sa.insert(
510 "CustomField".to_string(),
511 make_payload("continue-search-value"),
512 );
513 sa
514 },
515 }),
516 ..Default::default()
517 },
518 )),
519 user_metadata: None,
520 },
521 WorkflowCommand {
523 variant: Some(CmdVariant::StartChildWorkflowExecution(
524 StartChildWorkflowExecution {
525 seq: 1,
526 workflow_type: "child-workflow".to_string(),
527 input: vec![make_payload("child-arg")],
528 search_attributes: Some(SearchAttributes {
529 indexed_fields: {
530 let mut sa = HashMap::new();
531 sa.insert(
532 "CustomField".to_string(),
533 make_payload("child-search-value"),
534 );
535 sa
536 },
537 }),
538 ..Default::default()
539 },
540 )),
541 user_metadata: None,
542 },
543 ],
544 ..Default::default()
545 },
546 )),
547 };
548
549 encode_payloads(
550 &mut completion,
551 &MarkingCodec,
552 &SerializationContextData::Workflow,
553 )
554 .await;
555
556 let status = completion.status.as_ref().unwrap();
557 let Status::Successful(success) = status else {
558 panic!("Expected successful status")
559 };
560
561 let CmdVariant::UpsertWorkflowSearchAttributes(upsert) =
563 success.commands[0].variant.as_ref().unwrap()
564 else {
565 panic!("Expected UpsertWorkflowSearchAttributes")
566 };
567 let sa = upsert.search_attributes.as_ref().unwrap();
568 assert!(
569 !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
570 "search attributes should NOT be encoded"
571 );
572
573 let CmdVariant::ContinueAsNewWorkflowExecution(continue_as_new) =
575 success.commands[1].variant.as_ref().unwrap()
576 else {
577 panic!("Expected ContinueAsNewWorkflowExecution")
578 };
579 assert!(
580 is_encoded(&continue_as_new.arguments[0]),
581 "arguments should be encoded"
582 );
583 let sa = continue_as_new.search_attributes.as_ref().unwrap();
584 assert!(
585 !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
586 "search attributes should NOT be encoded"
587 );
588
589 let CmdVariant::StartChildWorkflowExecution(start_child) =
591 success.commands[2].variant.as_ref().unwrap()
592 else {
593 panic!("Expected StartChildWorkflowExecution")
594 };
595 assert!(is_encoded(&start_child.input[0]), "input should be encoded");
596 let sa = start_child.search_attributes.as_ref().unwrap();
597 assert!(
598 !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
599 "search attributes should NOT be encoded"
600 );
601 }
602
603 #[tokio::test]
604 async fn test_encode_single_payload() {
605 let mut payload = make_payload("test-data");
606
607 encode_payloads(
608 &mut payload,
609 &MarkingCodec,
610 &SerializationContextData::Workflow,
611 )
612 .await;
613
614 assert!(is_encoded(&payload), "single payload should be encoded");
615 }
616
617 #[tokio::test]
618 async fn test_decode_single_payload() {
619 let mut payload = make_payload("test-data");
620
621 decode_payloads(
622 &mut payload,
623 &MarkingCodec,
624 &SerializationContextData::Workflow,
625 )
626 .await;
627
628 assert!(is_decoded(&payload), "single payload should be decoded");
629 }
630
631 #[tokio::test]
632 async fn test_encode_payloads_message() {
633 let mut payloads = Payloads {
634 payloads: vec![make_payload("p1"), make_payload("p2"), make_payload("p3")],
635 };
636
637 encode_payloads(
638 &mut payloads,
639 &MarkingCodec,
640 &SerializationContextData::Workflow,
641 )
642 .await;
643
644 for (i, p) in payloads.payloads.iter().enumerate() {
645 assert!(is_encoded(p), "payload {} should be encoded", i);
646 }
647 }
648}