Skip to main content

tierkreis_runtime/operations/graph/
checkpoint_client.rs

1use std::collections::HashMap;
2
3use tierkreis_core::graph::Value;
4
5use tierkreis_core::symbol::Label;
6
7use tierkreis_proto::messages::{JobHandle, Message, NodeTrace};
8use tierkreis_proto::protos_gen::v1alpha1::controller as pc;
9use tierkreis_proto::protos_gen::v1alpha1::graph as pg;
10
/// Shorthand for the checkpoint recording service client from the `pc`
/// protobuf module, communicating over a `tonic` transport channel.
type SystemServiceClient =
    pc::checkpoint_recording_service_client::CheckpointRecordingServiceClient<
        tonic::transport::Channel,
    >;
15
/// Allows reporting of checkpoint information for a specific job to a checkpointing server.
///
/// Pairs the handle of the job being reported on with the client used to
/// talk to the checkpoint recording service.
#[derive(Clone)]
pub struct CheckpointClient {
    /// Identifies the specific run of a graph (job uuid and attempt id)
    pub job_handle: JobHandle,
    /// Client for the checkpoint recording service that every report is sent through.
    client: SystemServiceClient,
}
23
24impl CheckpointClient {
25    /// Creates an instance for the specified job by connecting to a hostname and port
26    pub async fn new(job_handle: JobHandle, host: &str, port: u16) -> anyhow::Result<Self> {
27        let uri = format!("http://{}:{}", host, port);
28        let client = SystemServiceClient::connect(uri).await?;
29        Ok(Self { job_handle, client })
30    }
31
32    fn log_error(&self, message: &str, status: &tonic::Status) {
33        let (job_uuid, attempt) = self.job_handle.into_inner();
34
35        tracing::error!(
36            tierkreis.job = job_uuid.to_string(),
37            tierkreis.attempt = attempt,
38            request = message,
39            error = status.to_string(),
40            "Failed to report to checkpoint store.",
41        );
42    }
43
44    pub(super) async fn node_finished(
45        &mut self,
46        node_trace: NodeTrace,
47        outputs: HashMap<Label, Value>,
48    ) {
49        let outputs: pg::StructValue = outputs.into();
50        let (job_uuid, attempt) = self.job_handle.into_inner();
51
52        let request = tonic::Request::new(pc::RecordNodeFinishedRequest {
53            id: Some(node_trace.clone().into()),
54            outputs: outputs.encode_to_vec(),
55            job_id: job_uuid.to_string(),
56            attempt_id: attempt,
57        });
58
59        if let Err(s) = self.client.record_node_finished(request).await {
60            self.log_error(&format!("node {} finished.", node_trace), &s);
61        }
62    }
63
64    pub(super) async fn node_started(
65        &mut self,
66        node_trace: NodeTrace,
67        retry_after_secs: Option<u32>,
68    ) {
69        let (job_uuid, attempt) = self.job_handle.into_inner();
70        let request = tonic::Request::new(pc::RecordNodeRunRequest {
71            id: Some(node_trace.clone().into()),
72            job_id: job_uuid.to_string(),
73            expected_duration_sec: retry_after_secs,
74            attempt_id: attempt,
75        });
76
77        if let Err(s) = self.client.record_node_run(request).await {
78            self.log_error(&format!("node {} started.", node_trace), &s);
79        }
80    }
81
82    pub(super) async fn job_finished(&mut self, error_message: Option<String>) {
83        let (job_uuid, _) = self.job_handle.into_inner();
84        let request = tonic::Request::new(pc::RecordJobFinishedRequest {
85            job_id: job_uuid.to_string(),
86            error_message,
87        });
88
89        if let Err(s) = self.client.record_job_finished(request).await {
90            self.log_error("job finished.", &s);
91        }
92    }
93
94    pub(super) async fn record_graph_output(&mut self, port: Label, value: Value) {
95        let value = pg::Value::from(value);
96        let (job_uuid, _) = self.job_handle.into_inner();
97
98        let request = tonic::Request::new(pc::RecordOutputRequest {
99            job_id: job_uuid.to_string(),
100            label: port.to_string(),
101            value: value.encode_to_vec(),
102        });
103
104        if let Err(s) = self.client.record_output(request).await {
105            self.log_error(&format!("output {} recorded.", port), &s);
106        }
107    }
108}