Skip to main content

temporalio_common/
payload_visitor.rs

//! Payload visitor infrastructure for applying codecs to proto messages.
//!
//! This module provides a visitor pattern for traversing proto messages and transforming
//! payload fields. It's used to apply `PayloadCodec` encode/decode operations at the
//! boundary between the SDK and Core.
6
7use crate::{
8    data_converters::{PayloadCodec, SerializationContextData},
9    protos::temporal::api::common::v1::{Payload, Payloads},
10};
11use futures::future::BoxFuture;
12
13/// Represents a payload field in a proto message.
14/// Payloads within the same field may be processed together by the codec.
15pub struct PayloadField<'a> {
16    /// The fully-qualified field path (e.g.,
17    /// `coresdk.workflow_commands.ScheduleActivity.arguments`)
18    pub path: &'static str,
19    /// The payload data
20    pub data: PayloadFieldData<'a>,
21}
22
23/// The payload data within a field, varying by field type.
24pub enum PayloadFieldData<'a> {
25    /// A singular [Payload] field
26    Single(&'a mut Payload),
27    /// A repeated [Payload] field
28    Repeated(&'a mut Vec<Payload>),
29    /// A [Payloads] message field
30    Payloads(&'a mut Payloads),
31}
32
33/// Async visitor for transforming payload fields.
34pub trait AsyncPayloadVisitor {
35    /// Visit a payload field, potentially transforming it.
36    fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()>;
37}
38
39/// Trait for messages that contain Payload fields (directly or transitively).
40/// Generated via codegen for all relevant proto message types.
41pub trait PayloadVisitable: Send {
42    /// Visit all payload fields in this message.
43    /// The visitor is called once per field, receiving the field's payload(s).
44    fn visit_payloads_mut<'a>(
45        &'a mut self,
46        visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
47    ) -> BoxFuture<'a, ()>;
48}
49
/// Reports whether `path` points at search attribute payloads, which must not be
/// encoded: the server has to be able to read them in order to index them.
fn is_search_attributes_path(path: &str) -> bool {
    // Every search attribute travels inside the `indexed_fields` map of a
    // `SearchAttributes` message, so its field path always contains this segment.
    const MARKER: &str = "SearchAttributes.indexed_fields";
    path.contains(MARKER)
}
56
57fn should_encode(path: &str) -> bool {
58    !is_search_attributes_path(path)
59}
60
61/// Visitor that encodes payloads using a codec.
62pub struct EncodeVisitor<'a> {
63    codec: &'a (dyn PayloadCodec + Send + Sync),
64    context: &'a SerializationContextData,
65}
66
67impl AsyncPayloadVisitor for EncodeVisitor<'_> {
68    fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
69        Box::pin(async move {
70            if !should_encode(field.path) {
71                return;
72            }
73            match field.data {
74                PayloadFieldData::Single(payload) => {
75                    let encoded = self
76                        .codec
77                        .encode(self.context, vec![std::mem::take(payload)])
78                        .await;
79                    if let Some(p) = encoded.into_iter().next() {
80                        *payload = p;
81                    }
82                }
83                PayloadFieldData::Repeated(payloads) => {
84                    *payloads = self
85                        .codec
86                        .encode(self.context, std::mem::take(payloads))
87                        .await;
88                }
89                PayloadFieldData::Payloads(payloads_msg) => {
90                    payloads_msg.payloads = self
91                        .codec
92                        .encode(self.context, std::mem::take(&mut payloads_msg.payloads))
93                        .await;
94                }
95            }
96        })
97    }
98}
99
100/// Visitor that decodes payloads using a codec.
101pub struct DecodeVisitor<'a> {
102    codec: &'a (dyn PayloadCodec + Send + Sync),
103    context: &'a SerializationContextData,
104}
105
106impl AsyncPayloadVisitor for DecodeVisitor<'_> {
107    fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
108        Box::pin(async move {
109            match field.data {
110                PayloadFieldData::Single(payload) => {
111                    let decoded = self
112                        .codec
113                        .decode(self.context, vec![std::mem::take(payload)])
114                        .await;
115                    if let Some(p) = decoded.into_iter().next() {
116                        *payload = p;
117                    }
118                }
119                PayloadFieldData::Repeated(payloads) => {
120                    *payloads = self
121                        .codec
122                        .decode(self.context, std::mem::take(payloads))
123                        .await;
124                }
125                PayloadFieldData::Payloads(payloads_msg) => {
126                    payloads_msg.payloads = self
127                        .codec
128                        .decode(self.context, std::mem::take(&mut payloads_msg.payloads))
129                        .await;
130                }
131            }
132        })
133    }
134}
135
136/// Encode all payloads in a message using the given codec.
137pub async fn encode_payloads<M: PayloadVisitable + Send>(
138    msg: &mut M,
139    codec: &(dyn PayloadCodec + Send + Sync),
140    context: &SerializationContextData,
141) {
142    let mut visitor = EncodeVisitor { codec, context };
143    msg.visit_payloads_mut(&mut visitor).await;
144}
145
146/// Decode all payloads in a message using the given codec.
147pub async fn decode_payloads<M: PayloadVisitable + Send>(
148    msg: &mut M,
149    codec: &(dyn PayloadCodec + Send + Sync),
150    context: &SerializationContextData,
151) {
152    let mut visitor = DecodeVisitor { codec, context };
153    msg.visit_payloads_mut(&mut visitor).await;
154}
155
156// Manual impl for Payload - visits itself as a single payload
157impl PayloadVisitable for Payload {
158    fn visit_payloads_mut<'a>(
159        &'a mut self,
160        visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
161    ) -> BoxFuture<'a, ()> {
162        Box::pin(async move {
163            visitor
164                .visit(PayloadField {
165                    path: "temporal.api.common.v1.Payload",
166                    data: PayloadFieldData::Single(self),
167                })
168                .await;
169        })
170    }
171}
172
173// Manual impl for Payloads - visits itself as a Payloads field
174impl PayloadVisitable for Payloads {
175    fn visit_payloads_mut<'a>(
176        &'a mut self,
177        visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
178    ) -> BoxFuture<'a, ()> {
179        Box::pin(async move {
180            visitor
181                .visit(PayloadField {
182                    path: "temporal.api.common.v1.Payloads",
183                    data: PayloadFieldData::Payloads(self),
184                })
185                .await;
186        })
187    }
188}
189
190// Include the generated PayloadVisitable implementations
191include!(concat!(env!("OUT_DIR"), "/payload_visitor_impl.rs"));
192
193#[cfg(test)]
194mod tests {
195    use super::*;
196    use crate::protos::{
197        coresdk::{
198            activity_result::{
199                ActivityResolution, Success, activity_resolution::Status as ActivityStatus,
200            },
201            workflow_activation::{
202                InitializeWorkflow, ResolveActivity, WorkflowActivation, WorkflowActivationJob,
203                workflow_activation_job::Variant,
204            },
205            workflow_commands::{
206                ContinueAsNewWorkflowExecution, ScheduleActivity, StartChildWorkflowExecution,
207                UpsertWorkflowSearchAttributes, WorkflowCommand,
208                workflow_command::Variant as CmdVariant,
209            },
210            workflow_completion::{
211                WorkflowActivationCompletion, workflow_activation_completion::Status,
212            },
213        },
214        temporal::api::common::v1::{Memo, SearchAttributes},
215    };
216    use futures::FutureExt;
217    use std::collections::HashMap;
218
219    struct MarkingCodec;
220    impl PayloadCodec for MarkingCodec {
221        fn encode(
222            &self,
223            _: &SerializationContextData,
224            payloads: Vec<Payload>,
225        ) -> BoxFuture<'static, Vec<Payload>> {
226            async move {
227                payloads
228                    .into_iter()
229                    .map(|mut p| {
230                        p.metadata.insert("encoded".to_string(), b"true".to_vec());
231                        p
232                    })
233                    .collect()
234            }
235            .boxed()
236        }
237
238        fn decode(
239            &self,
240            _: &SerializationContextData,
241            payloads: Vec<Payload>,
242        ) -> BoxFuture<'static, Vec<Payload>> {
243            async move {
244                payloads
245                    .into_iter()
246                    .map(|mut p| {
247                        p.metadata.insert("decoded".to_string(), b"true".to_vec());
248                        p
249                    })
250                    .collect()
251            }
252            .boxed()
253        }
254    }
255
256    struct PathRecordingVisitor {
257        visited_paths: Vec<String>,
258    }
259    impl PathRecordingVisitor {
260        fn new() -> Self {
261            Self {
262                visited_paths: Vec::new(),
263            }
264        }
265
266        fn paths(&self) -> Vec<String> {
267            self.visited_paths.clone()
268        }
269    }
270
271    impl AsyncPayloadVisitor for PathRecordingVisitor {
272        fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
273            let path = field.path.to_string();
274            self.visited_paths.push(path);
275            async move {}.boxed()
276        }
277    }
278
279    fn make_payload(data: &str) -> Payload {
280        Payload {
281            metadata: HashMap::new(),
282            data: data.as_bytes().to_vec(),
283            external_payloads: vec![],
284        }
285    }
286
287    fn is_encoded(p: &Payload) -> bool {
288        p.metadata.contains_key("encoded")
289    }
290
291    fn is_decoded(p: &Payload) -> bool {
292        p.metadata.contains_key("decoded")
293    }
294
295    #[tokio::test]
296    async fn test_direct_visitor_records_paths() {
297        let mut activation = WorkflowActivation {
298            run_id: "test-run".to_string(),
299            jobs: vec![WorkflowActivationJob {
300                variant: Some(Variant::InitializeWorkflow(InitializeWorkflow {
301                    workflow_type: "test-workflow".to_string(),
302                    arguments: vec![make_payload("input1")],
303                    headers: {
304                        let mut h = HashMap::new();
305                        h.insert("header-key".to_string(), make_payload("header-value"));
306                        h
307                    },
308                    memo: Some(Memo {
309                        fields: {
310                            let mut m = HashMap::new();
311                            m.insert("memo-key".to_string(), make_payload("memo-value"));
312                            m
313                        },
314                    }),
315                    ..Default::default()
316                })),
317            }],
318            ..Default::default()
319        };
320
321        let mut visitor = PathRecordingVisitor::new();
322        activation.visit_payloads_mut(&mut visitor).await;
323
324        let paths = visitor.paths();
325        assert!(
326            paths
327                .iter()
328                .any(|p| p.contains("InitializeWorkflow.arguments")),
329            "should visit arguments, got: {:?}",
330            paths
331        );
332        assert!(
333            paths
334                .iter()
335                .any(|p| p.contains("InitializeWorkflow.headers")),
336            "should visit headers, got: {:?}",
337            paths
338        );
339        assert!(
340            paths.iter().any(|p| p.contains("Memo.fields")),
341            "should visit memo fields, got: {:?}",
342            paths
343        );
344    }
345
346    #[tokio::test]
347    async fn test_encode_workflow_activation_completion_with_schedule_activity() {
348        let mut completion = WorkflowActivationCompletion {
349            run_id: "test-run".to_string(),
350            status: Some(Status::Successful(
351                crate::protos::coresdk::workflow_completion::Success {
352                    commands: vec![WorkflowCommand {
353                        variant: Some(CmdVariant::ScheduleActivity(ScheduleActivity {
354                            activity_id: "act-1".to_string(),
355                            activity_type: "test-activity".to_string(),
356                            arguments: vec![make_payload("arg1"), make_payload("arg2")],
357                            headers: {
358                                let mut h = HashMap::new();
359                                h.insert("header-key".to_string(), make_payload("header-value"));
360                                h
361                            },
362                            ..Default::default()
363                        })),
364                        user_metadata: None,
365                    }],
366                    ..Default::default()
367                },
368            )),
369        };
370
371        encode_payloads(
372            &mut completion,
373            &MarkingCodec,
374            &SerializationContextData::Workflow,
375        )
376        .await;
377
378        let status = completion.status.as_ref().unwrap();
379        let Status::Successful(success) = status else {
380            panic!("Expected successful status")
381        };
382        let cmd = &success.commands[0];
383        let CmdVariant::ScheduleActivity(schedule) = cmd.variant.as_ref().unwrap() else {
384            panic!("Expected ScheduleActivity")
385        };
386
387        assert!(is_encoded(&schedule.arguments[0]), "arg1 should be encoded");
388        assert!(is_encoded(&schedule.arguments[1]), "arg2 should be encoded");
389        assert!(
390            is_encoded(schedule.headers.get("header-key").unwrap()),
391            "header should be encoded"
392        );
393    }
394
395    #[tokio::test]
396    async fn test_decode_workflow_activation_with_initialize() {
397        let mut activation = WorkflowActivation {
398            run_id: "test-run".to_string(),
399            jobs: vec![WorkflowActivationJob {
400                variant: Some(Variant::InitializeWorkflow(InitializeWorkflow {
401                    workflow_type: "test-workflow".to_string(),
402                    arguments: vec![make_payload("input1"), make_payload("input2")],
403                    headers: {
404                        let mut h = HashMap::new();
405                        h.insert("header-key".to_string(), make_payload("header-value"));
406                        h
407                    },
408                    ..Default::default()
409                })),
410            }],
411            ..Default::default()
412        };
413
414        decode_payloads(
415            &mut activation,
416            &MarkingCodec,
417            &SerializationContextData::Workflow,
418        )
419        .await;
420
421        let job = &activation.jobs[0];
422        let Variant::InitializeWorkflow(init) = job.variant.as_ref().unwrap() else {
423            panic!("Expected InitializeWorkflow")
424        };
425
426        assert!(is_decoded(&init.arguments[0]), "arg1 should be decoded");
427        assert!(is_decoded(&init.arguments[1]), "arg2 should be decoded");
428        assert!(
429            is_decoded(init.headers.get("header-key").unwrap()),
430            "header should be decoded"
431        );
432    }
433
434    #[tokio::test]
435    async fn test_decode_workflow_activation_with_resolve_activity() {
436        let mut activation = WorkflowActivation {
437            run_id: "test-run".to_string(),
438            jobs: vec![WorkflowActivationJob {
439                variant: Some(Variant::ResolveActivity(ResolveActivity {
440                    seq: 1,
441                    result: Some(ActivityResolution {
442                        status: Some(ActivityStatus::Completed(Success {
443                            result: Some(make_payload("activity-result")),
444                        })),
445                    }),
446                    ..Default::default()
447                })),
448            }],
449            ..Default::default()
450        };
451
452        decode_payloads(
453            &mut activation,
454            &MarkingCodec,
455            &SerializationContextData::Workflow,
456        )
457        .await;
458
459        let job = &activation.jobs[0];
460        let Variant::ResolveActivity(resolve) = job.variant.as_ref().unwrap() else {
461            panic!("Expected ResolveActivity")
462        };
463        let ActivityStatus::Completed(success) =
464            resolve.result.as_ref().unwrap().status.as_ref().unwrap()
465        else {
466            panic!("Expected Completed status")
467        };
468
469        assert!(
470            is_decoded(success.result.as_ref().unwrap()),
471            "activity result should be decoded"
472        );
473    }
474
475    #[tokio::test]
476    async fn test_search_attributes_skipped_on_encode() {
477        // Test that search attributes are NOT encoded (they must remain server-readable)
478        let mut completion = WorkflowActivationCompletion {
479            run_id: "test-run".to_string(),
480            status: Some(Status::Successful(
481                crate::protos::coresdk::workflow_completion::Success {
482                    commands: vec![
483                        // UpsertWorkflowSearchAttributes command
484                        WorkflowCommand {
485                            variant: Some(CmdVariant::UpsertWorkflowSearchAttributes(
486                                UpsertWorkflowSearchAttributes {
487                                    search_attributes: Some(SearchAttributes {
488                                        indexed_fields: {
489                                            let mut sa = HashMap::new();
490                                            sa.insert(
491                                                "CustomField".to_string(),
492                                                make_payload("search-value"),
493                                            );
494                                            sa
495                                        },
496                                    }),
497                                },
498                            )),
499                            user_metadata: None,
500                        },
501                        // ContinueAsNewWorkflowExecution command
502                        WorkflowCommand {
503                            variant: Some(CmdVariant::ContinueAsNewWorkflowExecution(
504                                ContinueAsNewWorkflowExecution {
505                                    arguments: vec![make_payload("continue-arg")],
506                                    search_attributes: Some(SearchAttributes {
507                                        indexed_fields: {
508                                            let mut sa = HashMap::new();
509                                            sa.insert(
510                                                "CustomField".to_string(),
511                                                make_payload("continue-search-value"),
512                                            );
513                                            sa
514                                        },
515                                    }),
516                                    ..Default::default()
517                                },
518                            )),
519                            user_metadata: None,
520                        },
521                        // StartChildWorkflowExecution command
522                        WorkflowCommand {
523                            variant: Some(CmdVariant::StartChildWorkflowExecution(
524                                StartChildWorkflowExecution {
525                                    seq: 1,
526                                    workflow_type: "child-workflow".to_string(),
527                                    input: vec![make_payload("child-arg")],
528                                    search_attributes: Some(SearchAttributes {
529                                        indexed_fields: {
530                                            let mut sa = HashMap::new();
531                                            sa.insert(
532                                                "CustomField".to_string(),
533                                                make_payload("child-search-value"),
534                                            );
535                                            sa
536                                        },
537                                    }),
538                                    ..Default::default()
539                                },
540                            )),
541                            user_metadata: None,
542                        },
543                    ],
544                    ..Default::default()
545                },
546            )),
547        };
548
549        encode_payloads(
550            &mut completion,
551            &MarkingCodec,
552            &SerializationContextData::Workflow,
553        )
554        .await;
555
556        let status = completion.status.as_ref().unwrap();
557        let Status::Successful(success) = status else {
558            panic!("Expected successful status")
559        };
560
561        // UpsertWorkflowSearchAttributes - search attributes should NOT be encoded
562        let CmdVariant::UpsertWorkflowSearchAttributes(upsert) =
563            success.commands[0].variant.as_ref().unwrap()
564        else {
565            panic!("Expected UpsertWorkflowSearchAttributes")
566        };
567        let sa = upsert.search_attributes.as_ref().unwrap();
568        assert!(
569            !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
570            "search attributes should NOT be encoded"
571        );
572
573        // ContinueAsNewWorkflowExecution - arguments encoded, search attributes NOT
574        let CmdVariant::ContinueAsNewWorkflowExecution(continue_as_new) =
575            success.commands[1].variant.as_ref().unwrap()
576        else {
577            panic!("Expected ContinueAsNewWorkflowExecution")
578        };
579        assert!(
580            is_encoded(&continue_as_new.arguments[0]),
581            "arguments should be encoded"
582        );
583        let sa = continue_as_new.search_attributes.as_ref().unwrap();
584        assert!(
585            !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
586            "search attributes should NOT be encoded"
587        );
588
589        // StartChildWorkflowExecution - input encoded, search attributes NOT
590        let CmdVariant::StartChildWorkflowExecution(start_child) =
591            success.commands[2].variant.as_ref().unwrap()
592        else {
593            panic!("Expected StartChildWorkflowExecution")
594        };
595        assert!(is_encoded(&start_child.input[0]), "input should be encoded");
596        let sa = start_child.search_attributes.as_ref().unwrap();
597        assert!(
598            !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
599            "search attributes should NOT be encoded"
600        );
601    }
602
603    #[tokio::test]
604    async fn test_encode_single_payload() {
605        let mut payload = make_payload("test-data");
606
607        encode_payloads(
608            &mut payload,
609            &MarkingCodec,
610            &SerializationContextData::Workflow,
611        )
612        .await;
613
614        assert!(is_encoded(&payload), "single payload should be encoded");
615    }
616
617    #[tokio::test]
618    async fn test_decode_single_payload() {
619        let mut payload = make_payload("test-data");
620
621        decode_payloads(
622            &mut payload,
623            &MarkingCodec,
624            &SerializationContextData::Workflow,
625        )
626        .await;
627
628        assert!(is_decoded(&payload), "single payload should be decoded");
629    }
630
631    #[tokio::test]
632    async fn test_encode_payloads_message() {
633        let mut payloads = Payloads {
634            payloads: vec![make_payload("p1"), make_payload("p2"), make_payload("p3")],
635        };
636
637        encode_payloads(
638            &mut payloads,
639            &MarkingCodec,
640            &SerializationContextData::Workflow,
641        )
642        .await;
643
644        for (i, p) in payloads.payloads.iter().enumerate() {
645            assert!(is_encoded(p), "payload {} should be encoded", i);
646        }
647    }
648}