Skip to main content

temporalio_sdk/
activities.rs

1//! Functionality related to defining and interacting with activities
2//!
3//!
4//! An example of defining an activity:
5//! ```
6//! use std::sync::{
7//!     Arc,
8//!     atomic::{AtomicUsize, Ordering},
9//! };
10//! use temporalio_macros::activities;
11//! use temporalio_sdk::activities::{ActivityContext, ActivityError};
12//!
13//! struct MyActivities {
14//!     counter: AtomicUsize,
15//! }
16//!
17//! #[activities]
18//! impl MyActivities {
19//!     #[activity]
20//!     async fn echo(_ctx: ActivityContext, e: String) -> Result<String, ActivityError> {
21//!         Ok(e)
22//!     }
23//!
24//!     #[activity]
25//!     async fn uses_self(self: Arc<Self>, _ctx: ActivityContext) -> Result<(), ActivityError> {
26//!         self.counter.fetch_add(1, Ordering::Relaxed);
27//!         Ok(())
28//!     }
29//! }
30//!
31//! // If you need to refer to an activity that is defined externally, in a different codebase or
//! // possibly a different language, you can simply leave the function body unimplemented like so:
33//!
34//! struct ExternalActivities;
35//! #[activities]
36//! impl ExternalActivities {
37//!     #[activity(name = "foo")]
38//!     async fn foo(_ctx: ActivityContext, _: String) -> Result<String, ActivityError> {
39//!         unimplemented!()
40//!     }
41//! }
42//! ```
43//!
//! This still allows you to call the activity from workflow code, but the actual function will
//! never be invoked, since you won't have registered it with the worker.
46
47#[doc(inline)]
48pub use temporalio_macros::activities;
49
50use futures_util::{FutureExt, future::BoxFuture};
51use prost_types::{Duration, Timestamp};
52use std::{
53    collections::HashMap,
54    fmt::Debug,
55    sync::Arc,
56    time::{Duration as StdDuration, SystemTime},
57};
58use temporalio_client::Priority;
59use temporalio_common::{
60    ActivityDefinition,
61    data_converters::{
62        DataConverter, GenericPayloadConverter, SerializationContext, SerializationContextData,
63    },
64    error::{ApplicationFailure, FailurePayloads},
65    protos::{
66        coresdk::{ActivityHeartbeat, activity_task},
67        temporal::api::common::v1::{Payload, RetryPolicy, WorkflowExecution},
68        utilities::TryIntoOrNone,
69    },
70};
71use temporalio_sdk_core::Worker as CoreWorker;
72use tokio_util::sync::CancellationToken;
73
/// Used within activities to get info, heartbeat management etc.
///
/// Handed to every activity invocation; exposes the attempt's [`ActivityInfo`], cancellation
/// state, heartbeating, and the headers attached to the activity task.
#[derive(Clone)]
pub struct ActivityContext {
    /// Core worker used to record heartbeats for this activity.
    worker: Arc<CoreWorker>,
    /// Token observed by [`ActivityContext::cancelled`] and [`ActivityContext::is_cancelled`].
    cancellation_token: CancellationToken,
    /// Heartbeat details from the last failed attempt, if any.
    heartbeat_details: Vec<Payload>,
    /// Headers attached to the activity task.
    header_fields: HashMap<String, Payload>,
    /// Information about this specific activity attempt.
    info: ActivityInfo,
}
83
84impl ActivityContext {
85    /// Construct new Activity Context, returning the context and all arguments to the activity.
86    pub fn new(
87        worker: Arc<CoreWorker>,
88        cancellation_token: CancellationToken,
89        task_queue: String,
90        task_token: Vec<u8>,
91        task: activity_task::Start,
92    ) -> (Self, Vec<Payload>) {
93        let activity_task::Start {
94            workflow_namespace,
95            workflow_type,
96            workflow_execution,
97            activity_id,
98            activity_type,
99            header_fields,
100            input,
101            heartbeat_details,
102            scheduled_time,
103            current_attempt_scheduled_time,
104            started_time,
105            attempt,
106            schedule_to_close_timeout,
107            start_to_close_timeout,
108            heartbeat_timeout,
109            retry_policy,
110            is_local,
111            priority,
112            run_id,
113        } = task;
114        let deadline = calculate_deadline(
115            scheduled_time.as_ref(),
116            started_time.as_ref(),
117            start_to_close_timeout.as_ref(),
118            schedule_to_close_timeout.as_ref(),
119        );
120
121        (
122            ActivityContext {
123                worker,
124                cancellation_token,
125                heartbeat_details,
126                header_fields,
127                info: ActivityInfo {
128                    task_token,
129                    task_queue,
130                    workflow_type,
131                    workflow_namespace,
132                    workflow_execution,
133                    activity_id,
134                    activity_type,
135                    heartbeat_timeout: heartbeat_timeout.try_into_or_none(),
136                    scheduled_time: scheduled_time.try_into_or_none(),
137                    started_time: started_time.try_into_or_none(),
138                    deadline,
139                    attempt,
140                    current_attempt_scheduled_time: current_attempt_scheduled_time
141                        .try_into_or_none(),
142                    retry_policy,
143                    is_local,
144                    priority: priority.map(Into::into).unwrap_or_default(),
145                    run_id: (!run_id.is_empty()).then_some(run_id),
146                },
147            },
148            input,
149        )
150    }
151
152    /// Returns a future the completes if and when the activity this was called inside has been
153    /// cancelled
154    pub async fn cancelled(&self) {
155        self.cancellation_token.clone().cancelled().await
156    }
157
158    /// Returns true if this activity has already been cancelled
159    pub fn is_cancelled(&self) -> bool {
160        self.cancellation_token.is_cancelled()
161    }
162
163    /// Extract heartbeat details from last failed attempt. This is used in combination with retry
164    /// policy.
165    pub fn heartbeat_details(&self) -> &[Payload] {
166        &self.heartbeat_details
167    }
168
169    /// RecordHeartbeat sends heartbeat for the currently executing activity
170    pub fn record_heartbeat(&self, details: Vec<Payload>) {
171        if !self.info.is_local {
172            self.worker.record_activity_heartbeat(ActivityHeartbeat {
173                task_token: self.info.task_token.clone(),
174                details,
175            })
176        }
177    }
178
179    /// Returns activity info of the executing activity
180    pub fn info(&self) -> &ActivityInfo {
181        &self.info
182    }
183
184    /// Get headers attached to this activity
185    pub fn headers(&self) -> &HashMap<String, Payload> {
186        &self.header_fields
187    }
188}
189
/// Various information about a specific activity attempt.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ActivityInfo {
    /// An opaque token representing a specific Activity task.
    pub task_token: Vec<u8>,
    /// The type of the workflow that invoked this activity.
    pub workflow_type: String,
    /// The namespace of the workflow that invoked this activity.
    pub workflow_namespace: String,
    /// The execution of the workflow that invoked this activity.
    pub workflow_execution: Option<WorkflowExecution>,
    /// The ID of this activity.
    pub activity_id: String,
    /// The type of this activity.
    pub activity_type: String,
    /// The task queue of this activity.
    pub task_queue: String,
    /// The interval within which this activity must heartbeat or be timed out.
    pub heartbeat_timeout: Option<StdDuration>,
    /// Time the activity was scheduled by a workflow.
    pub scheduled_time: Option<SystemTime>,
    /// Time this activity attempt started.
    pub started_time: Option<SystemTime>,
    /// Deadline of this attempt: the earlier of the start-to-close and schedule-to-close
    /// deadlines, or `None` when it cannot be computed from the task.
    pub deadline: Option<SystemTime>,
    /// Attempt number: starts at 1 and increases by 1 for every retry, if a retry policy is
    /// specified.
    pub attempt: u32,
    /// Time this attempt at the activity was scheduled.
    pub current_attempt_scheduled_time: Option<SystemTime>,
    /// The retry policy for this activity.
    pub retry_policy: Option<RetryPolicy>,
    /// Whether or not this is a local activity.
    pub is_local: bool,
    /// Priority of this activity. If unset uses [Priority::default].
    pub priority: Priority,
    /// Run ID of this activity execution. Only set for standalone activities; `None` when the
    /// task carried an empty run ID.
    pub run_id: Option<String>,
}
229
/// Returned as errors from activity functions.
///
/// Any error convertible into `anyhow::Error` converts into this type (see the blanket `From`
/// implementation), so `?` can be used on most error types inside activity bodies.
///
/// NOTE(review): this type does not implement `std::error::Error`. Adding that impl would likely
/// make `ActivityError` itself satisfy the blanket `From<E: Into<anyhow::Error>>` impl and
/// overlap with the reflexive `From<T> for T` — confirm before adding one.
#[derive(Debug)]
pub enum ActivityError {
    /// Return this error to attach application-failure metadata to an activity failure.
    Application(Box<ApplicationFailure>),
    /// Return this error to indicate your activity is cancelling.
    Cancelled {
        /// Optional cancellation details.
        details: Option<FailurePayloads>,
    },
    /// Return this error to indicate that the activity will be completed outside of this activity
    /// definition, by an external client.
    WillCompleteAsync,
}
244
245impl ActivityError {
246    /// Construct a cancelled error without details
247    pub fn cancelled() -> Self {
248        Self::Cancelled { details: None }
249    }
250
251    /// Construct a cancelled error with details that will be converted using the active data
252    /// converter.
253    pub fn cancelled_with_details<T>(details: T) -> Self
254    where
255        T: Into<FailurePayloads>,
256    {
257        Self::Cancelled {
258            details: Some(details.into()),
259        }
260    }
261
262    /// Construct an application activity error.
263    pub fn application(err: ApplicationFailure) -> Self {
264        Self::Application(err.into())
265    }
266}
267
268impl<E> From<E> for ActivityError
269where
270    E: Into<anyhow::Error>,
271{
272    fn from(source: E) -> Self {
273        match source.into().downcast::<ApplicationFailure>() {
274            Ok(application_failure) => Self::Application(Box::new(application_failure)),
275            Err(err) => Self::Application(ApplicationFailure::new(err).into()),
276        }
277    }
278}
279
280/// Deadline calculation.  This is a port of
281/// https://github.com/temporalio/sdk-go/blob/8651550973088f27f678118f997839fb1bb9e62f/internal/activity.go#L225
282fn calculate_deadline(
283    scheduled_time: Option<&Timestamp>,
284    started_time: Option<&Timestamp>,
285    start_to_close_timeout: Option<&Duration>,
286    schedule_to_close_timeout: Option<&Duration>,
287) -> Option<SystemTime> {
288    match (
289        scheduled_time,
290        started_time,
291        start_to_close_timeout,
292        schedule_to_close_timeout,
293    ) {
294        (
295            Some(scheduled),
296            Some(started),
297            Some(start_to_close_timeout),
298            Some(schedule_to_close_timeout),
299        ) => {
300            let scheduled: SystemTime = maybe_convert_timestamp(scheduled)?;
301            let started: SystemTime = maybe_convert_timestamp(started)?;
302            let start_to_close_timeout: StdDuration = (*start_to_close_timeout).try_into().ok()?;
303            let schedule_to_close_timeout: StdDuration =
304                (*schedule_to_close_timeout).try_into().ok()?;
305
306            let start_to_close_deadline: SystemTime =
307                started.checked_add(start_to_close_timeout)?;
308            if schedule_to_close_timeout > StdDuration::ZERO {
309                let schedule_to_close_deadline =
310                    scheduled.checked_add(schedule_to_close_timeout)?;
311                // Minimum of the two deadlines.
312                if schedule_to_close_deadline < start_to_close_deadline {
313                    Some(schedule_to_close_deadline)
314                } else {
315                    Some(start_to_close_deadline)
316                }
317            } else {
318                Some(start_to_close_deadline)
319            }
320        }
321        _ => None,
322    }
323}
324
325/// Helper function lifted from prost_types::Timestamp implementation to prevent double cloning in
326/// error construction
327fn maybe_convert_timestamp(timestamp: &Timestamp) -> Option<SystemTime> {
328    let mut timestamp = *timestamp;
329    timestamp.normalize();
330
331    let system_time = if timestamp.seconds >= 0 {
332        std::time::UNIX_EPOCH.checked_add(StdDuration::from_secs(timestamp.seconds as u64))
333    } else {
334        std::time::UNIX_EPOCH.checked_sub(StdDuration::from_secs((-timestamp.seconds) as u64))
335    };
336
337    system_time.and_then(|system_time| {
338        system_time.checked_add(StdDuration::from_nanos(timestamp.nanos as u64))
339    })
340}
341
/// Type-erased, shareable entry point for a registered activity.
///
/// Takes the raw input payloads, the data converter, and the activity context. It resolves to
/// the serialized result payload or an [`ActivityError`]. Built by
/// [`ActivityDefinitions::register_activity`].
pub(crate) type ActivityInvocation = Arc<
    dyn Fn(
            Vec<Payload>,
            DataConverter,
            ActivityContext,
        ) -> BoxFuture<'static, Result<Payload, ActivityError>>
        + Send
        + Sync,
>;
351
/// A type that provides a set of activities which can all be registered at once.
///
/// Used by [`ActivityDefinitions::register_activities`]. NOTE(review): presumably implemented
/// by the `#[activities]` macro rather than by hand — confirm.
#[doc(hidden)]
pub trait ActivityImplementer {
    /// Registers every activity this implementer provides into `defs`.
    fn register_all(self: Arc<Self>, defs: &mut ActivityDefinitions);
}
356
/// An [`ActivityDefinition`] that can actually be executed by this worker.
///
/// NOTE(review): presumably implemented by the `#[activities]` macro for each `#[activity]`
/// method — confirm.
#[doc(hidden)]
pub trait ExecutableActivity: ActivityDefinition {
    /// The type on which this activity is defined.
    type Implementer: ActivityImplementer + Send + Sync + 'static;
    /// Runs the activity with already-deserialized input.
    ///
    /// [`ActivityDefinitions::register_activity`] always passes `Some(receiver)`.
    /// NOTE(review): `None` is presumably meant for activities without a `self` receiver.
    fn execute(
        receiver: Option<Arc<Self::Implementer>>,
        ctx: ActivityContext,
        input: Self::Input,
    ) -> BoxFuture<'static, Result<Self::Output, ActivityError>>;
}
366
/// Marker trait with no methods.
///
/// NOTE(review): presumably emitted by the `#[activities]` macro for implementers whose
/// activities take no `self` receiver — confirm against the macro.
#[doc(hidden)]
pub trait HasOnlyStaticMethods {}
369
/// Contains activity registrations in a form ready for execution by workers.
#[derive(Default, Clone)]
pub struct ActivityDefinitions {
    /// Registered activities, keyed by their [`ActivityDefinition::name`].
    activities: HashMap<&'static str, ActivityInvocation>,
}
375
376impl ActivityDefinitions {
377    /// Registers all activities on an activity implementer.
378    pub fn register_activities<AI: ActivityImplementer>(&mut self, instance: AI) -> &mut Self {
379        let arcd = Arc::new(instance);
380        AI::register_all(arcd, self);
381        self
382    }
383    /// Registers a specific activitiy.
384    pub fn register_activity<AD>(&mut self, instance: Arc<AD::Implementer>) -> &mut Self
385    where
386        AD: ActivityDefinition + ExecutableActivity,
387        AD::Output: Send + Sync,
388    {
389        self.activities.insert(
390            AD::name(),
391            Arc::new(move |payloads, dc, c| {
392                let instance = instance.clone();
393                let dc = dc.clone();
394                async move {
395                    // Use PayloadConverter (not DataConverter) since the codec is applied
396                    // at the SDK/Core boundary by the visitor, not here.
397                    let pc = dc.payload_converter();
398                    let ctx = SerializationContext {
399                        data: &SerializationContextData::Activity,
400                        converter: pc,
401                    };
402                    let deserialized: AD::Input = pc
403                        .from_payloads(&ctx, payloads)
404                        .map_err(ActivityError::from)?;
405                    let result = AD::execute(Some(instance), c, deserialized).await?;
406                    pc.to_payload(&ctx, &result).map_err(ActivityError::from)
407                }
408                .boxed()
409            }),
410        );
411        self
412    }
413
414    pub(crate) fn is_empty(&self) -> bool {
415        self.activities.is_empty()
416    }
417
418    pub(crate) fn get(&self, act_type: &str) -> Option<ActivityInvocation> {
419        self.activities.get(act_type).cloned()
420    }
421}
422
423impl Debug for ActivityDefinitions {
424    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425        f.debug_struct("ActivityDefinitions")
426            .field("activities", &self.activities.keys())
427            .finish()
428    }
429}
430
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Converting an `ApplicationFailure` into an `ActivityError` must keep its metadata intact.
    #[rstest]
    #[case(true)]
    #[case(false)]
    fn activity_error_conversion_is_not_lossy(#[case] non_retryable: bool) {
        use temporalio_common::protos::temporal::api::enums::v1::ApplicationErrorCategory;

        let retry_delay = StdDuration::from_secs(3);
        let source_failure = ApplicationFailure::builder(anyhow::anyhow!("big boom"))
            .type_name("BigBoom".to_owned())
            .non_retryable(non_retryable)
            .next_retry_delay(retry_delay)
            .category(ApplicationErrorCategory::Benign)
            .details("details")
            .build();

        let converted = match ActivityError::from(source_failure) {
            ActivityError::Application(failure) => failure,
            _ => panic!("application failure should become app failure"),
        };
        assert_eq!(converted.type_name(), Some("BigBoom"));
        assert_eq!(converted.is_non_retryable(), non_retryable);
        assert_eq!(converted.next_retry_delay(), Some(retry_delay));
        assert_eq!(converted.category(), ApplicationErrorCategory::Benign);
        assert_eq!(converted.to_string(), "big boom");
    }

    /// An arbitrary `std::error::Error` is wrapped into an application failure whose message is
    /// the error's `Display` output.
    #[test]
    fn activity_error_from_special_err_becomes_application() {
        #[derive(Debug, PartialEq)]
        struct MyError;

        impl std::fmt::Display for MyError {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("MyError")
            }
        }
        impl std::error::Error for MyError {}

        match ActivityError::from(MyError) {
            ActivityError::Application(failure) => assert_eq!(failure.to_string(), "MyError"),
            other => panic!("expected application failure, got {other:?}"),
        }
    }
}