1use super::{FunctionWorker, RuntimeWorker, Worker};
2use crate::operations::RuntimeOperation;
3use crate::RunGraphError;
4use anyhow::{anyhow, Context};
5use hyper_util::rt::TokioIo;
6use std::collections::HashMap;
7use std::ffi::OsStr;
8use std::process::Stdio;
9use std::sync::Arc;
10use tierkreis_core::graph::{Graph, Value};
11use tierkreis_core::namespace::Signature;
12use tierkreis_core::prelude::TryInto;
13use tierkreis_core::symbol::{FunctionName, Label, Location};
14use tierkreis_proto::messages::{
15 self as proto_messages, Callback, JobHandle, NodeTrace, RunGraphResponse,
16};
17use tierkreis_proto::protos_gen::v1alpha1::runtime as proto_runtime;
18use tierkreis_proto::protos_gen::v1alpha1::signature as proto_sig;
19use tierkreis_proto::protos_gen::v1alpha1::worker as proto_worker;
20use tierkreis_proto::ConvertError;
21use tokio::io::{AsyncBufReadExt, BufReader};
22use tokio::net::UnixStream;
23use tokio::process::Child;
24use tokio::process::Command;
25use tonic::async_trait;
26use tonic::service::Interceptor;
27use tonic::transport::{Channel, Endpoint, Uri};
28use tracing::instrument;
/// A gRPC connection to a worker, optionally tied to the child process that
/// was spawned to serve it.
#[derive(Clone)]
struct Connection {
    /// Channel on which every gRPC client of this connection is built.
    channel: Channel,
    /// Interceptor attached to every client built on `channel`.
    interceptor: ClientInterceptor,

    /// Handle of the spawned worker process; `None` when the connection was
    /// made to an already running worker. The field is never read: it is
    /// stored only so the child handle (spawned with `kill_on_drop(true)`)
    /// is not dropped while a clone of the connection is still alive.
    #[allow(dead_code)]
    process: Option<Arc<Child>>,
}
39
/// A worker reached over a `Connection`, together with the signature it
/// reported for the local location when the worker was set up.
#[derive(Clone)]
pub struct ExternalWorker {
    /// Connection used for every request sent to the worker.
    connection: Connection,
    /// Signature fetched once for `Location::local()` and served from memory
    /// for later requests about that location.
    signature: Signature,
}
48
/// Wrapper around a `Connection` that exposes the remote signature and the
/// connection's `RuntimeWorker` implementation.
pub struct CallbackForwarder(Connection);
52
/// A `Connection` paired with a base `Location`. Function calls forwarded
/// through it are addressed relative to that location.
#[derive(Clone)]
pub struct FunctionForwarder(Connection, Location);
59
/// Holds the callback of this runtime and, optionally, a forwarder through
/// which functions can be run outside of it.
#[derive(Clone)]
pub struct EscapeHatch {
    /// Forwarder used by `spawn_escape`; when `None`, `spawn_escape` yields `None`.
    parent: Option<FunctionForwarder>,
    /// Callback stored for this runtime.
    /// NOTE(review): presumably the address at which this runtime can be
    /// called back - confirm against `Callback`.
    this_runtime: Callback,
}
71
/// Selects which authentication metadata, if any, `ClientInterceptor` adds
/// to outgoing requests.
#[derive(Clone, Debug, Default)]
pub enum AuthInjector {
    /// Requests are passed through unchanged. This is the default.
    #[default]
    NoAuth,
    /// Requests get a `token` and a `key` metadata entry set from these fields.
    // The individual fields carry no doc comments, hence the lint allowance.
    #[allow(missing_docs)]
    TokenKey { token: String, key: String },
}
82
/// Request interceptor attached to the gRPC clients created in this module;
/// it applies the configured `AuthInjector` to each request.
#[derive(Clone, Debug, Default)]
pub struct ClientInterceptor {
    /// Authentication scheme applied to each outgoing request.
    pub auther: AuthInjector,
}
89
90impl ClientInterceptor {
91 pub fn new(auther: AuthInjector) -> Self {
93 Self { auther }
94 }
95}
96
97impl Interceptor for ClientInterceptor {
98 fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
99 match &self.auther {
100 AuthInjector::NoAuth => Ok(req),
101 AuthInjector::TokenKey { token, key } => {
102 req.metadata_mut().insert("token", token.parse().unwrap());
103 req.metadata_mut().insert("key", key.parse().unwrap());
104 Ok(req)
105 }
106 }
107 }
108}
109
110impl Connection {
111 #[instrument(
114 name="starting external worker process"
115 skip(command, interceptor),
116 fields(command = &*command.as_ref().to_string_lossy()))]
117 async fn new_spawn(
118 command: impl AsRef<OsStr>,
119 interceptor: ClientInterceptor,
120 ) -> anyhow::Result<Self> {
121 let mut process = Command::new(command.as_ref())
123 .stdout(Stdio::piped())
124 .stderr(Stdio::inherit())
125 .kill_on_drop(true)
126 .spawn()
127 .with_context(|| "failed to spawn worker process")?;
128
129 tracing::debug!("Spawned external worker process.");
130
131 let socket_path = BufReader::new(process.stdout.take().unwrap())
136 .lines()
137 .next_line()
138 .await
139 .with_context(|| "failed to receive socket path from worker process")
140 .unwrap()
141 .ok_or_else(|| anyhow!("worker process did not send socket path"))
142 .unwrap();
143
144 tracing::debug!("Received unix domain socket path from python worker.");
145 let channel = Endpoint::new("http://[::]:50051")?
148 .connect_with_connector(tower::service_fn({
149 let socket_path = socket_path.clone();
150 move |_: Uri| {
151 let socket_path = socket_path.clone();
152 async move {
153 Ok::<_, std::io::Error>(TokioIo::new(
154 UnixStream::connect(socket_path).await?,
155 ))
156 }
157 }
158 }))
159 .await
160 .with_context(|| "failed to connect to the socket of the python worker")?;
161
162 tracing::info!("Connected to spawned worker's grpc server.");
163
164 Ok(Self {
165 channel,
166 process: Some(Arc::new(process)),
167 interceptor,
168 })
169 }
170
171 #[instrument(
172 name="connecting to external worker"
173 skip(uri, interceptor),
174 fields(uri = uri.to_string().as_str())
175 )]
176 async fn new_connect(uri: &Uri, interceptor: ClientInterceptor) -> anyhow::Result<Self> {
177 let channel = Endpoint::from(uri.clone())
178 .connect()
179 .await
180 .with_context(|| format!("failed to connect to worker at uri {}", uri))?;
181
182 tracing::info!(
183 uri = uri.to_string().as_str(),
184 "Connected to the worker's grpc server."
185 );
186
187 Ok(Self {
188 channel,
189 process: None,
190 interceptor,
191 })
192 }
193
194 #[instrument(
195 name = "run external function",
196 fields(otel.kind = "client", _node_trace = node_trace.as_ref().map_or(String::new(), ToString::to_string)),
197 skip(self, inputs, node_trace, callback)
198 )]
199 async fn run_function(
200 &self,
201 function: FunctionName,
202 inputs: HashMap<Label, Value>,
203 loc: Location,
204 callback: Callback,
205 node_trace: Option<NodeTrace>,
206 job_handle: Option<JobHandle>,
207 ) -> anyhow::Result<HashMap<Label, Value>> {
208 let mut request = proto_messages::RunFunctionRequest::new(function, inputs, loc, callback);
209 let node_str = node_trace
210 .as_ref()
211 .map(|nt| nt.to_string())
212 .unwrap_or_else(|| "None".to_string());
213 if let Some(node_trace) = node_trace {
214 request.set_node_trace(node_trace);
215 }
216
217 if let Some(job_id) = job_handle {
218 request.set_job_id(job_id.into_inner().0);
219 }
220
221 let mut worker_client = proto_worker::worker_client::WorkerClient::with_interceptor(
222 self.channel.clone(),
223 self.interceptor.clone(),
224 );
225
226 let response = worker_client.run_function(inject_trace(request)).await;
227 if let Err(e) = &response {
228 tracing::info!(
229 error = e.to_string(),
230 node = node_str,
231 "failed to run function in worker."
232 );
233 }
234 let response = response
235 .context("failed to run function in worker")?
236 .into_inner();
237
238 tracing::debug!("received response from worker");
239
240 let outputs = TryInto::try_into(response.outputs.unwrap_or_default())
241 .context("failed to parse function outputs from worker response")?;
242
243 Ok(outputs)
244 }
245
246 async fn run_graph(
247 &self,
248 graph: Graph,
249 inputs: HashMap<Label, Value>,
250 loc: Location,
251 type_check: bool,
252 escape: Option<Callback>,
253 ) -> Result<HashMap<Label, Value>, RunGraphError> {
254 let request = proto_messages::RunGraphRequest {
255 graph,
256 inputs,
257 loc,
258 type_check,
259 escape,
260 };
261
262 let mut runtime_client = proto_runtime::runtime_client::RuntimeClient::with_interceptor(
263 self.channel.clone(),
264 self.interceptor.clone(),
265 );
266
267 let resp: RunGraphResponse = TryInto::try_into(
268 runtime_client
269 .run_graph(inject_trace(request))
270 .await
271 .context("failed to run graph in scope")?
272 .into_inner(),
273 )
274 .map_err(|e: ConvertError| anyhow!(e))?;
275
276 match resp {
277 RunGraphResponse::Success(x) => Ok(x),
278 RunGraphResponse::TypeError(x) => Err(RunGraphError::TypeError(x)),
279 RunGraphResponse::Error(x) => Err(RunGraphError::RuntimeError(Arc::new(anyhow!(x)))),
280 }
281 }
282
283 async fn get_signature(
284 &self,
285 loc: Location,
286 ) -> anyhow::Result<proto_sig::ListFunctionsResponse> {
287 let mut signature_client = proto_sig::signature_client::SignatureClient::with_interceptor(
288 self.channel.clone(),
289 self.interceptor.clone(),
290 );
291
292 tracing::debug!("Retrieving signature for worker.");
294 let signature = signature_client
295 .list_functions(inject_trace(proto_messages::ListFunctionsRequest { loc }))
296 .await
297 .context("failed to retrieve signature from worker")?
298 .into_inner();
299
300 Ok(signature)
301 }
302
303 async fn type_check(
304 &self,
305 value: Value,
306 loc: Location,
307 ) -> anyhow::Result<proto_sig::InferTypeResponse> {
308 let mut tc_client = proto_sig::type_inference_client::TypeInferenceClient::with_interceptor(
309 self.channel.clone(),
310 self.interceptor.clone(),
311 );
312
313 tracing::debug!("Type checking on worker.");
314 let resp = tc_client
315 .infer_type(inject_trace(proto_messages::InferTypeRequest {
316 value,
317 loc,
318 }))
319 .await
320 .context("failed to infer type on worker")?
321 .into_inner();
322
323 Ok(resp)
324 }
325
326 fn spawn(&self, function: &FunctionName, loc: &Location) -> RuntimeOperation {
327 let function = function.clone();
328 let loc = loc.clone();
329 let this = self.clone();
330 RuntimeOperation::new_fn_async(move |inputs, context| async move {
331 let _ = &context;
332 this.run_function(
333 function,
334 inputs,
335 loc,
336 context.callback,
337 Some(context.graph_trace.as_node_trace()?),
338 context.checkpoint_client.map(|ch| ch.job_handle),
339 )
340 .await
341 })
342 }
343}
344
345#[async_trait]
346impl RuntimeWorker for Connection {
347 async fn execute_graph(
348 &self,
349 graph: Graph,
350 inputs: HashMap<Label, Value>,
351 location: Location,
352 type_check: bool,
353 escape: Option<Callback>,
354 ) -> Result<HashMap<Label, Value>, RunGraphError> {
355 self.run_graph(graph, inputs, location, type_check, escape)
356 .await
357 }
358
359 fn spawn_graph(&self, graph: Graph, location: &Location) -> RuntimeOperation {
360 let this = self.clone();
361 let loc = location.clone();
362 RuntimeOperation::new_fn_async(move |inputs, context| async move {
363 let _ = &context;
364 match this
365 .run_graph(graph, inputs, loc, false, Some(context.escape.this_runtime))
366 .await
367 {
368 Ok(x) => Ok(x),
369 Err(RunGraphError::TypeError(x)) => Err(anyhow!("Type errors: {}", x)),
370 Err(RunGraphError::RuntimeError(x)) => Err(anyhow!(x)),
371 }
372 })
373 }
374
375 async fn infer_type(
376 &self,
377 to_check: Value,
378 location: Location,
379 ) -> anyhow::Result<proto_sig::InferTypeResponse> {
380 self.type_check(to_check, location).await
381 }
382}
383
384impl ExternalWorker {
385 async fn from_connection(connection: Connection) -> anyhow::Result<Self> {
386 let signature = TryInto::try_into(connection.get_signature(Location::local()).await?)?;
387 Ok(Self {
388 connection,
389 signature,
390 })
391 }
392
393 pub async fn new_spawn(
401 command: impl AsRef<OsStr>,
402 interceptor: ClientInterceptor,
403 ) -> anyhow::Result<Self> {
404 let connection = Connection::new_spawn(command, interceptor).await?;
405 Self::from_connection(connection).await
406 }
407
408 pub async fn new_connect(uri: &Uri, interceptor: ClientInterceptor) -> anyhow::Result<Self> {
410 let connection = Connection::new_connect(uri, interceptor).await?;
411 Self::from_connection(connection).await
412 }
413}
414
415#[async_trait]
416impl FunctionWorker for ExternalWorker {
417 fn spawn(&self, function: &FunctionName, loc: &Location) -> RuntimeOperation {
418 self.connection.spawn(function, loc)
419 }
420}
421
422#[async_trait]
423impl Worker for ExternalWorker {
424 async fn signature(
425 &self,
426 location: Location,
427 ) -> anyhow::Result<proto_sig::ListFunctionsResponse> {
428 if location == Location::local() {
429 Ok(self.signature.clone().into())
430 } else {
431 self.connection.get_signature(location).await
432 }
433 }
434
435 fn to_runtime_worker(&self) -> Option<&dyn RuntimeWorker> {
436 if self.signature.scopes.is_empty() {
437 None
438 } else {
439 Some(&self.connection)
440 }
441 }
442}
443
444impl CallbackForwarder {
445 pub async fn new_connect(uri: &Uri, interceptor: ClientInterceptor) -> anyhow::Result<Self> {
451 Ok(Self(Connection::new_connect(uri, interceptor).await?))
452 }
453
454 pub async fn signature(
456 &self,
457 location: Location,
458 ) -> anyhow::Result<proto_sig::ListFunctionsResponse> {
459 self.0.get_signature(location).await
460 }
461
462 pub fn as_runtime_worker(&self) -> &dyn RuntimeWorker {
464 &self.0
465 }
466}
467
468impl FunctionForwarder {
469 pub async fn new(
472 uri: &Uri,
473 interceptor: ClientInterceptor,
474 loc: Location,
475 ) -> anyhow::Result<Self> {
476 let conn = Connection::new_connect(uri, interceptor).await?;
479 Ok(Self(conn, loc))
480 }
481
482 pub async fn fwd_run_function(
485 &self,
486 function: FunctionName,
487 inputs: HashMap<Label, Value>,
488 loc: Location,
489 callback: Callback,
490 node_trace: Option<NodeTrace>,
491 ) -> anyhow::Result<HashMap<Label, Value>> {
492 self.0
493 .run_function(
494 function,
495 inputs,
496 self.1.clone().concat(&loc),
499 callback,
500 node_trace,
501 None,
502 )
503 .await
504 }
505 fn spawn(&self, function: &FunctionName) -> RuntimeOperation {
506 self.0.spawn(function, &self.1)
507 }
508}
509
510impl EscapeHatch {
511 pub fn this_runtime(target: Callback) -> Self {
514 Self::new(None, target)
515 }
516
517 pub fn new(parent: Option<FunctionForwarder>, this_runtime: Callback) -> Self {
521 Self {
522 this_runtime,
523 parent,
524 }
525 }
526
527 pub fn spawn_escape(&self, function: FunctionName) -> Option<RuntimeOperation> {
530 self.parent.as_ref().map(|ff| ff.spawn(&function))
533 }
534}
535
536fn inject_trace<R, T>(request: R) -> tonic::Request<T>
537where
538 R: tonic::IntoRequest<T>,
539{
540 use opentelemetry::propagation::TextMapPropagator;
541 use opentelemetry_http::HeaderInjector;
542 use opentelemetry_sdk::propagation::TraceContextPropagator;
543 use tracing_opentelemetry::OpenTelemetrySpanExt;
544
545 let mut request = request.into_request();
546
547 let mut headers = std::mem::take(request.metadata_mut()).into_headers();
548 let span = tracing::Span::current();
549 let context = span.context();
550
551 let propagator = TraceContextPropagator::new();
552 propagator.inject_context(&context, &mut HeaderInjector(&mut headers));
553
554 *request.metadata_mut() = tonic::metadata::MetadataMap::from_headers(headers);
555 request
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{fake_callback, fake_escape, fake_interceptor, py_loc};
    use std::error::Error;
    use std::{collections::HashMap, time::Duration};
    use tierkreis_core::namespace::Signature;
    use tierkreis_core::prelude::TryInto;
    use tierkreis_core::symbol::{Location, Name, Prefix};
    use tokio::join;

    /// Result type shared by the tests in this module.
    type TestResult = Result<(), Box<dyn Error + Send + Sync>>;

    /// Path of the python test worker's entry point, relative to this
    /// crate's manifest directory.
    fn test_worker_script() -> String {
        format!(
            "{}/../python/tests/test_worker/main.py",
            env!("CARGO_MANIFEST_DIR")
        )
    }

    /// Calls `python_add` on a spawned worker and checks the sum.
    #[tokio::test]
    async fn simple_snippet() -> TestResult {
        let worker = ExternalWorker::new_spawn(test_worker_script(), fake_interceptor()).await?;

        let inputs: HashMap<Label, Value> = HashMap::from([
            (TryInto::try_into("a")?, Value::Int(2)),
            (TryInto::try_into("b")?, Value::Int(3)),
        ]);

        let outputs = worker
            .connection
            .run_function(
                "python_nodes::python_add".parse()?,
                inputs,
                Location::local(),
                fake_callback(),
                None,
                None,
            )
            .await?;

        assert_eq!(outputs.get(&Label::value()), Some(&Value::Int(5)));

        Ok(())
    }

    /// Runs a graph containing `mistyped_op` and expects the run to fail
    /// with the worker's function-call error.
    #[tokio::test]
    async fn test_mistyped_worker() -> TestResult {
        use crate::Runtime;
        use tierkreis_core::graph::GraphBuilder;

        let py_worker: ExternalWorker =
            ExternalWorker::new_spawn(test_worker_script(), fake_interceptor()).await?;
        let runtime = Runtime::builder()
            .with_worker(py_worker, py_loc())
            .await?
            .start();

        // input.i -> mistyped_op.inp, mistyped_op.value -> output.res
        let mut builder = GraphBuilder::new();
        let [input, output] = tierkreis_core::graph::Graph::boundary();
        let bad_op = builder.add_node("python_nodes::mistyped_op")?;
        builder.add_edge((input, "i"), (bad_op, "inp"), None)?;
        builder.add_edge((bad_op, "value"), (output, "res"), None)?;
        let graph = builder.build()?;

        let inputs = HashMap::from([(TryInto::try_into("i")?, Value::Int(6))]);

        let err = runtime
            .execute_graph_cb(graph, inputs, true, fake_callback(), fake_escape())
            .await
            .expect_err("Should detect runtime type mismatch");
        let rendered = format!("{:?}", err);
        assert!(rendered.contains("failed to run function in worker"));
        Ok(())
    }

    /// Issues two delayed calls on clones of one worker and checks that
    /// they overlap rather than run back to back.
    #[tokio::test]
    async fn concurrent_nodes() -> TestResult {
        let script = test_worker_script();
        let python_1 = ExternalWorker::new_spawn(&script, fake_interceptor()).await?;
        let python_2 = python_1.clone();

        let inputs_1: HashMap<Label, Value> = HashMap::from([
            (TryInto::try_into("wait")?, Value::Int(1)),
            (Label::value(), Value::Int(0)),
        ]);
        let inputs_2: HashMap<Label, Value> = HashMap::from([
            (TryInto::try_into("wait")?, Value::Int(1)),
            (Label::value(), Value::Int(1)),
        ]);

        let started = std::time::Instant::now();

        let fut_1 = python_1.connection.run_function(
            "python_nodes::id_delay".parse()?,
            inputs_1,
            Location::local(),
            fake_callback(),
            None,
            None,
        );
        let fut_2 = python_2.connection.run_function(
            "python_nodes::id_delay".parse()?,
            inputs_2,
            Location::local(),
            fake_callback(),
            None,
            None,
        );

        let (result_1, result_2) = join!(fut_1, fut_2);

        let outputs_1 = result_1?;
        let outputs_2 = result_2?;

        let elapsed = started.elapsed();

        // Each call waits one unit; running them one after the other would
        // exceed this bound.
        assert!(elapsed < Duration::from_millis(1250));
        assert_eq!(outputs_1[&Label::value()], Value::Int(0));
        assert_eq!(outputs_2[&Label::value()], Value::Int(1));

        Ok(())
    }

    /// Checks that a function reported by a spawned worker is listed at the
    /// empty location.
    #[tokio::test]
    async fn check_external_location() -> TestResult {
        let script = test_worker_script();
        let worker = ExternalWorker::new_spawn(&script, fake_interceptor()).await?;
        let pn: Prefix = TryInto::try_into("python_nodes")?;
        let name: Name = TryInto::try_into("id_py")?;
        let listed = worker.signature(Location::local()).await?;
        let sig: Signature = TryInto::try_into(listed)?;
        let item = &sig.root.subspaces[&pn].functions[&name];
        assert_eq!(item.locations, vec![Location(vec![])]);
        Ok(())
    }
}