Skip to main content

tierkreis_runtime/workers/
external.rs

1use super::{FunctionWorker, RuntimeWorker, Worker};
2use crate::operations::RuntimeOperation;
3use crate::RunGraphError;
4use anyhow::{anyhow, Context};
5use hyper_util::rt::TokioIo;
6use std::collections::HashMap;
7use std::ffi::OsStr;
8use std::process::Stdio;
9use std::sync::Arc;
10use tierkreis_core::graph::{Graph, Value};
11use tierkreis_core::namespace::Signature;
12use tierkreis_core::prelude::TryInto;
13use tierkreis_core::symbol::{FunctionName, Label, Location};
14use tierkreis_proto::messages::{
15    self as proto_messages, Callback, JobHandle, NodeTrace, RunGraphResponse,
16};
17use tierkreis_proto::protos_gen::v1alpha1::runtime as proto_runtime;
18use tierkreis_proto::protos_gen::v1alpha1::signature as proto_sig;
19use tierkreis_proto::protos_gen::v1alpha1::worker as proto_worker;
20use tierkreis_proto::ConvertError;
21use tokio::io::{AsyncBufReadExt, BufReader};
22use tokio::net::UnixStream;
23use tokio::process::Child;
24use tokio::process::Command;
25use tonic::async_trait;
26use tonic::service::Interceptor;
27use tonic::transport::{Channel, Endpoint, Uri};
28use tracing::instrument;
/// A gRPC channel to an external worker or server, bundled with the interceptor
/// that is applied to every client built on top of it.
///
/// All the request-issuing helpers (`run_function`, `run_graph`, `get_signature`,
/// `type_check`) live on this type; the public wrappers in this module hold one.
#[derive(Clone)]
struct Connection {
    /// Transport shared by every gRPC client created from this connection.
    channel: Channel,
    /// Interceptor cloned into each gRPC client created from this connection.
    interceptor: ClientInterceptor,

    /// Keep a handle to the child process so that it is not killed.
    /// `None` when we connected to an already-running process rather than spawning one.
    /// The field is never read, hence the `dead_code` allowance.
    #[allow(dead_code)]
    process: Option<Arc<Child>>,
}
39
/// Connection to an external worker (perhaps a Tierkreis server),
/// that can execute `run_function` for some defined+reported set of functions.
///
/// The worker's signature at the local location is requested once on construction
/// and cached, so later queries for that location need no round trip.
#[derive(Clone)]
pub struct ExternalWorker {
    /// Channel (and, for spawned workers, child-process handle) used to reach the worker.
    connection: Connection,
    /// Cached signature, as reported by the worker for `Location::local()`.
    signature: Signature,
}
48
/// Forwards `run_graph` requests; used to route these from a worker back to the
/// server (ancestor) which issued the `run_function` request (to that worker).
///
/// Wraps the connection to that server.
pub struct CallbackForwarder(Connection);
52
/// A link in a forwarding chain - forwards `run_function` calls to a destination,
/// passing a [Location] which uniquely identifies the chain and the next link therein
/// (typically by the location being unique to each source/caller).
///
/// Field `0` is the connection to the destination; field `1` is the [Location]
/// that is prepended to the location of every forwarded call.
// TODO consider moving Location into Connection
#[derive(Clone)]
pub struct FunctionForwarder(Connection, Location);
59
/// Used for forwarding "escape hatch" calls (a `run_function` sent from a child runtime
/// back to an ancestor).
#[derive(Clone)]
pub struct EscapeHatch {
    /// The next link "above" this runtime, or `None` if
    /// the forwarding chain ends here (nowhere further to escape to).
    parent: Option<FunctionForwarder>,
    /// Tells other workers ("below") how to access the chain here,
    /// i.e. becomes the next link for such children.
    /// Sent as the `escape` of `run_graph` requests issued to child runtimes.
    this_runtime: Callback,
}
71
/// Placeholder for future code that adds authentication data to a client request.
///
/// Describes which metadata entries, if any, a [ClientInterceptor] attaches.
#[derive(Clone, Debug, Default)]
pub enum AuthInjector {
    /// No authentication data to add; requests pass through unchanged.
    #[default]
    NoAuth,
    /// Token and key metadata to add, sent as the `token` and `key`
    /// metadata entries respectively.
    // The two fields are covered by the variant doc above rather than individually.
    #[allow(missing_docs)]
    TokenKey { token: String, key: String },
}
82
/// Placeholder [Interceptor] that just adds some authentication (meta)data.
///
/// A clone of this is installed on every gRPC client created in this module.
#[derive(Clone, Debug, Default)]
pub struct ClientInterceptor {
    /// Authentication data to add to each outgoing request.
    pub auther: AuthInjector,
}
89
90impl ClientInterceptor {
91    /// Create a new instance to add specified authentication metadata
92    pub fn new(auther: AuthInjector) -> Self {
93        Self { auther }
94    }
95}
96
97impl Interceptor for ClientInterceptor {
98    fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
99        match &self.auther {
100            AuthInjector::NoAuth => Ok(req),
101            AuthInjector::TokenKey { token, key } => {
102                req.metadata_mut().insert("token", token.parse().unwrap());
103                req.metadata_mut().insert("key", key.parse().unwrap());
104                Ok(req)
105            }
106        }
107    }
108}
109
110impl Connection {
111    /// Spawn a process given an external command, and connect. See [FunctionForwarder::new_spawn],
112    /// although we do not request signature here.
113    #[instrument(
114        name="starting external worker process"
115        skip(command, interceptor),
116        fields(command = &*command.as_ref().to_string_lossy()))]
117    async fn new_spawn(
118        command: impl AsRef<OsStr>,
119        interceptor: ClientInterceptor,
120    ) -> anyhow::Result<Self> {
121        // Spawn the python subprocess
122        let mut process = Command::new(command.as_ref())
123            .stdout(Stdio::piped())
124            .stderr(Stdio::inherit())
125            .kill_on_drop(true)
126            .spawn()
127            .with_context(|| "failed to spawn worker process")?;
128
129        tracing::debug!("Spawned external worker process.");
130
131        // Wait for the child to write the path of the unix socket to stdout. If the worker process
132        // stops this will also fail with an error. After the line is received we drop the stdout
133        // handle, so the pipe is closed as we no longer need it.
134        // TODO: Maybe a timeout for this?
135        let socket_path = BufReader::new(process.stdout.take().unwrap())
136            .lines()
137            .next_line()
138            .await
139            .with_context(|| "failed to receive socket path from worker process")
140            .unwrap()
141            .ok_or_else(|| anyhow!("worker process did not send socket path"))
142            .unwrap();
143
144        tracing::debug!("Received unix domain socket path from python worker.");
145        // Open a communication channel for the socket path. We need to supply some uri here since the
146        // API is weird, but it is arbitrary and not used.
147        let channel = Endpoint::new("http://[::]:50051")?
148            .connect_with_connector(tower::service_fn({
149                let socket_path = socket_path.clone();
150                move |_: Uri| {
151                    let socket_path = socket_path.clone();
152                    async move {
153                        Ok::<_, std::io::Error>(TokioIo::new(
154                            UnixStream::connect(socket_path).await?,
155                        ))
156                    }
157                }
158            }))
159            .await
160            .with_context(|| "failed to connect to the socket of the python worker")?;
161
162        tracing::info!("Connected to spawned worker's grpc server.");
163
164        Ok(Self {
165            channel,
166            process: Some(Arc::new(process)),
167            interceptor,
168        })
169    }
170
171    #[instrument(
172        name="connecting to external worker"
173        skip(uri, interceptor),
174        fields(uri = uri.to_string().as_str())
175    )]
176    async fn new_connect(uri: &Uri, interceptor: ClientInterceptor) -> anyhow::Result<Self> {
177        let channel = Endpoint::from(uri.clone())
178            .connect()
179            .await
180            .with_context(|| format!("failed to connect to worker at uri {}", uri))?;
181
182        tracing::info!(
183            uri = uri.to_string().as_str(),
184            "Connected to the worker's grpc server."
185        );
186
187        Ok(Self {
188            channel,
189            process: None,
190            interceptor,
191        })
192    }
193
194    #[instrument(
195        name = "run external function",
196        fields(otel.kind = "client", _node_trace = node_trace.as_ref().map_or(String::new(), ToString::to_string)),
197        skip(self, inputs, node_trace, callback)
198    )]
199    async fn run_function(
200        &self,
201        function: FunctionName,
202        inputs: HashMap<Label, Value>,
203        loc: Location,
204        callback: Callback,
205        node_trace: Option<NodeTrace>,
206        job_handle: Option<JobHandle>,
207    ) -> anyhow::Result<HashMap<Label, Value>> {
208        let mut request = proto_messages::RunFunctionRequest::new(function, inputs, loc, callback);
209        let node_str = node_trace
210            .as_ref()
211            .map(|nt| nt.to_string())
212            .unwrap_or_else(|| "None".to_string());
213        if let Some(node_trace) = node_trace {
214            request.set_node_trace(node_trace);
215        }
216
217        if let Some(job_id) = job_handle {
218            request.set_job_id(job_id.into_inner().0);
219        }
220
221        let mut worker_client = proto_worker::worker_client::WorkerClient::with_interceptor(
222            self.channel.clone(),
223            self.interceptor.clone(),
224        );
225
226        let response = worker_client.run_function(inject_trace(request)).await;
227        if let Err(e) = &response {
228            tracing::info!(
229                error = e.to_string(),
230                node = node_str,
231                "failed to run function in worker."
232            );
233        }
234        let response = response
235            .context("failed to run function in worker")?
236            .into_inner();
237
238        tracing::debug!("received response from worker");
239
240        let outputs = TryInto::try_into(response.outputs.unwrap_or_default())
241            .context("failed to parse function outputs from worker response")?;
242
243        Ok(outputs)
244    }
245
246    async fn run_graph(
247        &self,
248        graph: Graph,
249        inputs: HashMap<Label, Value>,
250        loc: Location,
251        type_check: bool,
252        escape: Option<Callback>,
253    ) -> Result<HashMap<Label, Value>, RunGraphError> {
254        let request = proto_messages::RunGraphRequest {
255            graph,
256            inputs,
257            loc,
258            type_check,
259            escape,
260        };
261
262        let mut runtime_client = proto_runtime::runtime_client::RuntimeClient::with_interceptor(
263            self.channel.clone(),
264            self.interceptor.clone(),
265        );
266
267        let resp: RunGraphResponse = TryInto::try_into(
268            runtime_client
269                .run_graph(inject_trace(request))
270                .await
271                .context("failed to run graph in scope")?
272                .into_inner(),
273        )
274        .map_err(|e: ConvertError| anyhow!(e))?;
275
276        match resp {
277            RunGraphResponse::Success(x) => Ok(x),
278            RunGraphResponse::TypeError(x) => Err(RunGraphError::TypeError(x)),
279            RunGraphResponse::Error(x) => Err(RunGraphError::RuntimeError(Arc::new(anyhow!(x)))),
280        }
281    }
282
283    async fn get_signature(
284        &self,
285        loc: Location,
286    ) -> anyhow::Result<proto_sig::ListFunctionsResponse> {
287        let mut signature_client = proto_sig::signature_client::SignatureClient::with_interceptor(
288            self.channel.clone(),
289            self.interceptor.clone(),
290        );
291
292        // Retrieve the worker's signature
293        tracing::debug!("Retrieving signature for worker.");
294        let signature = signature_client
295            .list_functions(inject_trace(proto_messages::ListFunctionsRequest { loc }))
296            .await
297            .context("failed to retrieve signature from worker")?
298            .into_inner();
299
300        Ok(signature)
301    }
302
303    async fn type_check(
304        &self,
305        value: Value,
306        loc: Location,
307    ) -> anyhow::Result<proto_sig::InferTypeResponse> {
308        let mut tc_client = proto_sig::type_inference_client::TypeInferenceClient::with_interceptor(
309            self.channel.clone(),
310            self.interceptor.clone(),
311        );
312
313        tracing::debug!("Type checking on worker.");
314        let resp = tc_client
315            .infer_type(inject_trace(proto_messages::InferTypeRequest {
316                value,
317                loc,
318            }))
319            .await
320            .context("failed to infer type on worker")?
321            .into_inner();
322
323        Ok(resp)
324    }
325
326    fn spawn(&self, function: &FunctionName, loc: &Location) -> RuntimeOperation {
327        let function = function.clone();
328        let loc = loc.clone();
329        let this = self.clone();
330        RuntimeOperation::new_fn_async(move |inputs, context| async move {
331            let _ = &context;
332            this.run_function(
333                function,
334                inputs,
335                loc,
336                context.callback,
337                Some(context.graph_trace.as_node_trace()?),
338                context.checkpoint_client.map(|ch| ch.job_handle),
339            )
340            .await
341        })
342    }
343}
344
345#[async_trait]
346impl RuntimeWorker for Connection {
347    async fn execute_graph(
348        &self,
349        graph: Graph,
350        inputs: HashMap<Label, Value>,
351        location: Location,
352        type_check: bool,
353        escape: Option<Callback>,
354    ) -> Result<HashMap<Label, Value>, RunGraphError> {
355        self.run_graph(graph, inputs, location, type_check, escape)
356            .await
357    }
358
359    fn spawn_graph(&self, graph: Graph, location: &Location) -> RuntimeOperation {
360        let this = self.clone();
361        let loc = location.clone();
362        RuntimeOperation::new_fn_async(move |inputs, context| async move {
363            let _ = &context;
364            match this
365                .run_graph(graph, inputs, loc, false, Some(context.escape.this_runtime))
366                .await
367            {
368                Ok(x) => Ok(x),
369                Err(RunGraphError::TypeError(x)) => Err(anyhow!("Type errors: {}", x)),
370                Err(RunGraphError::RuntimeError(x)) => Err(anyhow!(x)),
371            }
372        })
373    }
374
375    async fn infer_type(
376        &self,
377        to_check: Value,
378        location: Location,
379    ) -> anyhow::Result<proto_sig::InferTypeResponse> {
380        self.type_check(to_check, location).await
381    }
382}
383
384impl ExternalWorker {
385    async fn from_connection(connection: Connection) -> anyhow::Result<Self> {
386        let signature = TryInto::try_into(connection.get_signature(Location::local()).await?)?;
387        Ok(Self {
388            connection,
389            signature,
390        })
391    }
392
393    /// Spawn a process that runs the specified command, which must print out the socket address
394    /// for a GRPC server as the first line of stdout. (We expect the command to start a new gRPC
395    /// server in its own process.)
396    ///
397    /// The socket address could be a hostname + port, or path to a UNIX filesystem socket.
398    ///
399    /// `new_spawn` then connects to the worker's gRPC server and requests its signature.
400    pub async fn new_spawn(
401        command: impl AsRef<OsStr>,
402        interceptor: ClientInterceptor,
403    ) -> anyhow::Result<Self> {
404        let connection = Connection::new_spawn(command, interceptor).await?;
405        Self::from_connection(connection).await
406    }
407
408    /// Connects to an already-running external process given its Uri
409    pub async fn new_connect(uri: &Uri, interceptor: ClientInterceptor) -> anyhow::Result<Self> {
410        let connection = Connection::new_connect(uri, interceptor).await?;
411        Self::from_connection(connection).await
412    }
413}
414
415#[async_trait]
416impl FunctionWorker for ExternalWorker {
417    fn spawn(&self, function: &FunctionName, loc: &Location) -> RuntimeOperation {
418        self.connection.spawn(function, loc)
419    }
420}
421
422#[async_trait]
423impl Worker for ExternalWorker {
424    async fn signature(
425        &self,
426        location: Location,
427    ) -> anyhow::Result<proto_sig::ListFunctionsResponse> {
428        if location == Location::local() {
429            Ok(self.signature.clone().into())
430        } else {
431            self.connection.get_signature(location).await
432        }
433    }
434
435    fn to_runtime_worker(&self) -> Option<&dyn RuntimeWorker> {
436        if self.signature.scopes.is_empty() {
437            None
438        } else {
439            Some(&self.connection)
440        }
441    }
442}
443
444impl CallbackForwarder {
445    // It would be easy to implement new_spawn, but it doesn't really make sense
446    // - we are forwarding callbacks from a worker back to the already-existing
447    // server that issued the original request to that worker.
448
449    /// Connects to a remote Tierkreis server in order to forward callbacks onto it
450    pub async fn new_connect(uri: &Uri, interceptor: ClientInterceptor) -> anyhow::Result<Self> {
451        Ok(Self(Connection::new_connect(uri, interceptor).await?))
452    }
453
454    /// Gets the signature (list of known functions + namespaces) of the server
455    pub async fn signature(
456        &self,
457        location: Location,
458    ) -> anyhow::Result<proto_sig::ListFunctionsResponse> {
459        self.0.get_signature(location).await
460    }
461
462    /// Allows running graphs (/typechecking) by forwarding requests
463    pub fn as_runtime_worker(&self) -> &dyn RuntimeWorker {
464        &self.0
465    }
466}
467
468impl FunctionForwarder {
469    /// Creates a new instance that forwards requests to a Uri, identifying itself
470    /// with a Location
471    pub async fn new(
472        uri: &Uri,
473        interceptor: ClientInterceptor,
474        loc: Location,
475    ) -> anyhow::Result<Self> {
476        // TODO it probably only makes sense if `loc` has exactly one LocationName.
477        // Should we enforce that?
478        let conn = Connection::new_connect(uri, interceptor).await?;
479        Ok(Self(conn, loc))
480    }
481
482    /// Forwards a call to run a function up the chain.
483    /// `loc` specifies the Location relative to the root of the forwarding chain.
484    pub async fn fwd_run_function(
485        &self,
486        function: FunctionName,
487        inputs: HashMap<Label, Value>,
488        loc: Location,
489        callback: Callback,
490        node_trace: Option<NodeTrace>,
491    ) -> anyhow::Result<HashMap<Label, Value>> {
492        self.0
493            .run_function(
494                function,
495                inputs,
496                // We expect self.1 to be a single LocationName identifying the route
497                // to the root, which then uses Location `loc` (untouched)
498                self.1.clone().concat(&loc),
499                callback,
500                node_trace,
501                None,
502            )
503            .await
504    }
505    fn spawn(&self, function: &FunctionName) -> RuntimeOperation {
506        self.0.spawn(function, &self.1)
507    }
508}
509
510impl EscapeHatch {
511    /// An instance whereby `run_function`s execute on this runtime,
512    /// given a Callback telling child workers how to connect to this runtime.
513    pub fn this_runtime(target: Callback) -> Self {
514        Self::new(None, target)
515    }
516
517    /// Creates a new instance from a Callback to this runtime and optionally a parent.
518    /// If `parent` is non-null, then `this_runtime` should identify `parent`
519    /// in the servers EscapeHatch-forwarding map.
520    pub fn new(parent: Option<FunctionForwarder>, this_runtime: Callback) -> Self {
521        Self {
522            this_runtime,
523            parent,
524        }
525    }
526
527    /// Cross-fingers that this runs a function.
528    /// Deliberately no way to specify a location.
529    pub fn spawn_escape(&self, function: FunctionName) -> Option<RuntimeOperation> {
530        // Note that the Connection will include the OperationContext's callback
531        // (perhaps to this runtime) in the run_function request it issues.
532        self.parent.as_ref().map(|ff| ff.spawn(&function))
533    }
534}
535
536fn inject_trace<R, T>(request: R) -> tonic::Request<T>
537where
538    R: tonic::IntoRequest<T>,
539{
540    use opentelemetry::propagation::TextMapPropagator;
541    use opentelemetry_http::HeaderInjector;
542    use opentelemetry_sdk::propagation::TraceContextPropagator;
543    use tracing_opentelemetry::OpenTelemetrySpanExt;
544
545    let mut request = request.into_request();
546
547    let mut headers = std::mem::take(request.metadata_mut()).into_headers();
548    let span = tracing::Span::current();
549    let context = span.context();
550
551    let propagator = TraceContextPropagator::new();
552    propagator.inject_context(&context, &mut HeaderInjector(&mut headers));
553
554    *request.metadata_mut() = tonic::metadata::MetadataMap::from_headers(headers);
555    request
556}
557
558#[cfg(test)]
559mod tests {
560    use super::*;
561    use crate::tests::{fake_callback, fake_escape, fake_interceptor, py_loc};
562    use std::error::Error;
563    use std::{collections::HashMap, time::Duration};
564    use tierkreis_core::namespace::Signature;
565    use tierkreis_core::prelude::TryInto;
566    use tierkreis_core::symbol::{Location, Name, Prefix};
567    use tokio::join;
568
569    #[tokio::test]
570    async fn simple_snippet() -> Result<(), Box<dyn Error + Send + Sync>> {
571        let python = format!(
572            "{}/../python/tests/test_worker/main.py",
573            env!("CARGO_MANIFEST_DIR")
574        );
575        let python = ExternalWorker::new_spawn(python, fake_interceptor()).await?;
576
577        let mut inputs = HashMap::new();
578        inputs.insert(TryInto::try_into("a")?, Value::Int(2));
579        inputs.insert(TryInto::try_into("b")?, Value::Int(3));
580
581        let outputs = python
582            .connection
583            .run_function(
584                "python_nodes::python_add".parse()?,
585                inputs,
586                Location::local(),
587                fake_callback(),
588                None,
589                None,
590            )
591            .await?;
592
593        assert_eq!(outputs.get(&Label::value()), Some(&Value::Int(5)));
594
595        Ok(())
596    }
597
598    #[tokio::test]
599    async fn test_mistyped_worker() -> Result<(), Box<dyn Error + Send + Sync>> {
600        use crate::Runtime;
601        use tierkreis_core::graph::GraphBuilder;
602
603        let python = format!(
604            "{}/../python/tests/test_worker/main.py",
605            env!("CARGO_MANIFEST_DIR")
606        );
607
608        let py_worker: ExternalWorker =
609            ExternalWorker::new_spawn(python, fake_interceptor()).await?;
610        let runtime = Runtime::builder()
611            .with_worker(py_worker, py_loc())
612            .await?
613            .start();
614
615        let graph = {
616            let mut builder = GraphBuilder::new();
617            let [input, output] = tierkreis_core::graph::Graph::boundary();
618
619            let bad_op = builder.add_node("python_nodes::mistyped_op")?;
620            builder.add_edge((input, "i"), (bad_op, "inp"), None)?;
621            builder.add_edge((bad_op, "value"), (output, "res"), None)?;
622            builder.build()?
623        };
624
625        let inputs = HashMap::from([(TryInto::try_into("i")?, Value::Int(6))]);
626
627        let err = runtime
628            .execute_graph_cb(graph, inputs, true, fake_callback(), fake_escape())
629            .await
630            .expect_err("Should detect runtime type mismatch");
631        // TODO worker function does not propagate internal betterproto
632        // conversion error (expects integer and finds float)
633        assert!(format!("{:?}", err).contains("failed to run function in worker"));
634        Ok(())
635    }
636
637    #[tokio::test]
638    async fn concurrent_nodes() -> Result<(), Box<dyn Error + Send + Sync>> {
639        let python = format!(
640            "{}/../python/tests/test_worker/main.py",
641            env!("CARGO_MANIFEST_DIR")
642        );
643        let python_1 = ExternalWorker::new_spawn(&python, fake_interceptor()).await?;
644        let python_2 = python_1.clone();
645
646        let mut inputs_1 = HashMap::new();
647        inputs_1.insert(TryInto::try_into("wait")?, Value::Int(1));
648        inputs_1.insert(Label::value(), Value::Int(0));
649
650        let mut inputs_2 = HashMap::new();
651        inputs_2.insert(TryInto::try_into("wait")?, Value::Int(1));
652        inputs_2.insert(Label::value(), Value::Int(1));
653
654        let earlier = std::time::Instant::now();
655
656        let fut_1 = python_1.connection.run_function(
657            "python_nodes::id_delay".parse()?,
658            inputs_1,
659            Location::local(),
660            fake_callback(),
661            None,
662            None,
663        );
664        let fut_2 = python_2.connection.run_function(
665            "python_nodes::id_delay".parse()?,
666            inputs_2,
667            Location::local(),
668            fake_callback(),
669            None,
670            None,
671        );
672
673        // We need to use `join!` instead of awaiting in sequence to make sure the futures
674        // are both polled simultaneously.
675        let (result_1, result_2) = join!(fut_1, fut_2);
676
677        let outputs_1 = result_1?;
678        let outputs_2 = result_2?;
679
680        let now = std::time::Instant::now();
681
682        assert!(now.duration_since(earlier) < Duration::from_millis(1250));
683        assert_eq!(outputs_1[&Label::value()], Value::Int(0));
684        assert_eq!(outputs_2[&Label::value()], Value::Int(1));
685
686        Ok(())
687    }
688
689    #[tokio::test]
690    async fn check_external_location() -> Result<(), Box<dyn Error + Send + Sync>> {
691        let python = format!(
692            "{}/../python/tests/test_worker/main.py",
693            env!("CARGO_MANIFEST_DIR")
694        );
695        let python = ExternalWorker::new_spawn(&python, fake_interceptor()).await?;
696        let pn: Prefix = TryInto::try_into("python_nodes")?;
697        let name: Name = TryInto::try_into("id_py")?;
698        let sig: Signature = TryInto::try_into(python.signature(Location::local()).await?)?;
699        let item = &sig.root.subspaces[&pn].functions[&name];
700        assert_eq!(item.locations, vec![Location(vec![])]);
701        Ok(())
702    }
703}