triton_distributed/pipeline/
nodes.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
//! Pipeline Nodes
//!
//! A `ServicePipeline` is a directed graph of nodes where each node defines a behavior for both
//! the forward/request path and the backward/response path. In each direction, a node acts as
//! either a [`Source`] or a [`Sink`].
//!
//! A `Frontend` is the start of a graph: it is a [`Source`] for the forward path and a [`Sink`]
//! for the backward path.
//!
//! A `Backend` is the end of a graph: it is a [`Sink`] for the forward path and a [`Source`] for
//! the backward path.
//!
//! A [`PipelineOperator`] is a node that can transform both the forward and backward paths using
//! the logic supplied by an implementation of the [`Operator`] trait. A [`PipelineOperator`] is
//! both a [`Sink`] and a [`Source`] on the forward request path, and again on the backward
//! response path, so it is effectively two sources and two sinks. The
//! [`PipelineOperator::forward_edge`] and [`PipelineOperator::backward_edge`] methods separate
//! the two roles.
//!
//! - [`PipelineOperator::forward_edge`] returns a [`PipelineOperatorForwardEdge`], which is a
//!   [`Sink`] for the incoming/upstream request and a [`Source`] for the downstream request.
//! - [`PipelineOperator::backward_edge`] returns a [`PipelineOperatorBackwardEdge`], which is a
//!   [`Sink`] for the downstream response and a [`Source`] for the upstream response.
//!
//! An `EdgeOperator`, currently named [`PipelineNode`], is a node in the graph that transforms
//! either the forward path or the backward path, but not both.
//!
//! This makes [`Operator`] the more powerful trait, because it can propagate information from the
//! forward path to the backward path. An `EdgeOperator` on the forward path has no visibility into
//! the backward path and therefore cannot directly influence it.
//!
46use std::{
47    collections::HashMap,
48    sync::{Arc, Mutex, OnceLock},
49};
50
51use super::AsyncEngine;
52use async_trait::async_trait;
53use tokio::sync::oneshot;
54
55use super::{Data, Error, PipelineError, PipelineIO};
56
57mod sinks;
58mod sources;
59
60pub use sinks::{SegmentSink, ServiceBackend};
61pub use sources::{SegmentSource, ServiceFrontend};
62
63pub type Service<In, Out> = Arc<ServiceFrontend<In, Out>>;
64
/// Sealing module for the pipeline traits.
///
/// `Source::on_next`, `Source::set_edge` and `Sink::on_data` each take a `Token` argument. This
/// module is private, so `Token` can only be named from this module and its child modules.
/// Code elsewhere therefore can neither call those methods nor write implementations of them.
mod private {
    /// Zero-sized capability value passed to the sealed trait methods.
    pub struct Token;
}
68
69// todo rename `ServicePipelineExt`
70/// A [`Source`] trait defines how data is emitted from a source to a downstream sink
71/// over an [`Edge`].
72#[async_trait]
73pub trait Source<T: PipelineIO>: Data {
74    async fn on_next(&self, data: T, _: private::Token) -> Result<(), Error>;
75
76    fn set_edge(&self, edge: Edge<T>, _: private::Token) -> Result<(), PipelineError>;
77
78    fn link<S: Sink<T> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
79        let edge = Edge::new(sink.clone());
80        self.set_edge(edge, private::Token)?;
81        Ok(sink)
82    }
83}
84
85/// A [`Sink`] trait defines how data is received from a source and processed.
86#[async_trait]
87pub trait Sink<T: PipelineIO>: Data {
88    async fn on_data(&self, data: T, _: private::Token) -> Result<(), Error>;
89}
90
91/// An [`Edge`] is a connection between a [`Source`] and a [`Sink`]. Data flows over an [`Edge`].
92pub struct Edge<T: PipelineIO> {
93    downstream: Arc<dyn Sink<T>>,
94}
95
96impl<T: PipelineIO> Edge<T> {
97    fn new(downstream: Arc<dyn Sink<T>>) -> Self {
98        Edge { downstream }
99    }
100
101    async fn write(&self, data: T) -> Result<(), Error> {
102        self.downstream.on_data(data, private::Token).await
103    }
104}
105
106type NodeFn<In, Out> = Box<dyn Fn(In) -> Result<Out, Error> + Send + Sync>;
107
108/// An [`Operator`] is a trait that defines the behavior of how two [`AsyncEngine`] can be chained together.
109/// An [`Operator`] is not quite an [`AsyncEngine`] because its generate method requires both the upstream
110/// request, but also the downstream [`AsyncEngine`] to which it will pass the transformed request.
111/// The [`Operator`] logic must transform the upstream request `UpIn` to the downstream request `DownIn`,
112/// then transform the downstream response `DownOut` to the upstream response `UpOut`.
113///
114/// A [`PipelineOperator`] accepts an [`Operator`] and presents itself as an [`AsyncEngine`] for the upstream
115/// [`AsyncEngine<UpIn, UpOut, Error>`].
116///
117/// ### Example of type transformation and data flow
118/// ```text
119/// ... --> <UpIn> ---> [Operator] --> <DownIn> ---> ...
120/// ... <-- <UpOut> --> [Operator] <-- <DownOut> <-- ...
121/// ```
122#[async_trait]
123pub trait Operator<UpIn: PipelineIO, UpOut: PipelineIO, DownIn: PipelineIO, DownOut: PipelineIO>:
124    Data
125{
126    /// This method is expected to transform the upstream request `UpIn` to the downstream request `DownIn`,
127    /// call the next [`AsyncEngine`] with the transformed request, then transform the downstream response
128    /// `DownOut` to the upstream response `UpOut`.
129    async fn generate(
130        &self,
131        req: UpIn,
132        next: Arc<dyn AsyncEngine<DownIn, DownOut, Error>>,
133    ) -> Result<UpOut, Error>;
134
135    fn into_operator(self: &Arc<Self>) -> Arc<PipelineOperator<UpIn, UpOut, DownIn, DownOut>>
136    where
137        Self: Sized,
138    {
139        PipelineOperator::new(self.clone())
140    }
141}
142
143/// A [`PipelineOperatorForwardEdge`] is [`Sink`] for the upstream request type `UpIn` and a [`Source`] for the
144/// downstream request type `DownIn`.
145pub struct PipelineOperatorForwardEdge<
146    UpIn: PipelineIO,
147    UpOut: PipelineIO,
148    DownIn: PipelineIO,
149    DownOut: PipelineIO,
150> {
151    parent: Arc<PipelineOperator<UpIn, UpOut, DownIn, DownOut>>,
152}
153
154/// A [`PipelineOperatorBackwardEdge`] is [`Sink`] for the downstream response type `DownOut` and a [`Source`] for the
155/// upstream response type `UpOut`.
156pub struct PipelineOperatorBackwardEdge<
157    UpIn: PipelineIO,
158    UpOut: PipelineIO,
159    DownIn: PipelineIO,
160    DownOut: PipelineIO,
161> {
162    parent: Arc<PipelineOperator<UpIn, UpOut, DownIn, DownOut>>,
163}
164
165/// A [`PipelineOperator`] is a node that can transform both the forward and backward paths using the logic defined
166/// by the implementation of an [`Operator`] trait.
167pub struct PipelineOperator<
168    UpIn: PipelineIO,
169    UpOut: PipelineIO,
170    DownIn: PipelineIO,
171    DownOut: PipelineIO,
172> {
173    // core business logic of this object
174    operator: Arc<dyn Operator<UpIn, UpOut, DownIn, DownOut>>,
175
176    // this hold the downstream connections via the generic frontend
177    // frontends provide both a source and a sink interfaces
178    downstream: Arc<sources::Frontend<DownIn, DownOut>>,
179
180    // this hold the connection to the previous/upstream response sink
181    // we are a source to that upstream's response sink
182    upstream: sinks::SinkEdge<UpOut>,
183}
184
185impl<UpIn, UpOut, DownIn, DownOut> PipelineOperator<UpIn, UpOut, DownIn, DownOut>
186where
187    UpIn: PipelineIO,
188    UpOut: PipelineIO,
189    DownIn: PipelineIO,
190    DownOut: PipelineIO,
191{
192    /// Create a new [`PipelineOperator`] with the given [`Operator`] implementation.
193    pub fn new(operator: Arc<dyn Operator<UpIn, UpOut, DownIn, DownOut>>) -> Arc<Self> {
194        Arc::new(PipelineOperator {
195            operator,
196            downstream: Arc::new(sources::Frontend::default()),
197            upstream: sinks::SinkEdge::default(),
198        })
199    }
200
201    /// Access the forward edge of the [`PipelineOperator`] allowing the forward/requests paths to be linked.
202    pub fn forward_edge(
203        self: &Arc<Self>,
204    ) -> Arc<PipelineOperatorForwardEdge<UpIn, UpOut, DownIn, DownOut>> {
205        Arc::new(PipelineOperatorForwardEdge {
206            parent: self.clone(),
207        })
208    }
209
210    /// Access the backward edge of the [`PipelineOperator`] allowing the backward/responses paths to be linked.
211    pub fn backward_edge(
212        self: &Arc<Self>,
213    ) -> Arc<PipelineOperatorBackwardEdge<UpIn, UpOut, DownIn, DownOut>> {
214        Arc::new(PipelineOperatorBackwardEdge {
215            parent: self.clone(),
216        })
217    }
218}
219
220/// A [`PipelineOperator`] is an [`AsyncEngine`] for the upstream [`AsyncEngine<UpIn, UpOut, Error>`].
221#[async_trait]
222impl<UpIn, UpOut, DownIn, DownOut> AsyncEngine<UpIn, UpOut, Error>
223    for PipelineOperator<UpIn, UpOut, DownIn, DownOut>
224where
225    UpIn: PipelineIO,
226    DownIn: PipelineIO,
227    DownOut: PipelineIO,
228    UpOut: PipelineIO,
229{
230    async fn generate(&self, req: UpIn) -> Result<UpOut, Error> {
231        self.operator.generate(req, self.downstream.clone()).await
232    }
233}
234
235#[async_trait]
236impl<UpIn, UpOut, DownIn, DownOut> Sink<UpIn>
237    for PipelineOperatorForwardEdge<UpIn, UpOut, DownIn, DownOut>
238where
239    UpIn: PipelineIO,
240    DownIn: PipelineIO,
241    DownOut: PipelineIO,
242    UpOut: PipelineIO,
243{
244    async fn on_data(&self, data: UpIn, _token: private::Token) -> Result<(), Error> {
245        let stream = self.parent.generate(data).await?;
246        self.parent.upstream.on_next(stream, private::Token).await
247    }
248}
249
250#[async_trait]
251impl<UpIn, UpOut, DownIn, DownOut> Source<DownIn>
252    for PipelineOperatorForwardEdge<UpIn, UpOut, DownIn, DownOut>
253where
254    UpIn: PipelineIO,
255    DownIn: PipelineIO,
256    DownOut: PipelineIO,
257    UpOut: PipelineIO,
258{
259    async fn on_next(&self, data: DownIn, token: private::Token) -> Result<(), Error> {
260        self.parent.downstream.on_next(data, token).await
261    }
262
263    fn set_edge(&self, edge: Edge<DownIn>, token: private::Token) -> Result<(), PipelineError> {
264        self.parent.downstream.set_edge(edge, token)
265    }
266}
267
268#[async_trait]
269impl<UpIn, UpOut, DownIn, DownOut> Sink<DownOut>
270    for PipelineOperatorBackwardEdge<UpIn, UpOut, DownIn, DownOut>
271where
272    UpIn: PipelineIO,
273    DownIn: PipelineIO,
274    DownOut: PipelineIO,
275    UpOut: PipelineIO,
276{
277    async fn on_data(&self, data: DownOut, token: private::Token) -> Result<(), Error> {
278        self.parent.downstream.on_data(data, token).await
279    }
280}
281
282#[async_trait]
283impl<UpIn, UpOut, DownIn, DownOut> Source<UpOut>
284    for PipelineOperatorBackwardEdge<UpIn, UpOut, DownIn, DownOut>
285where
286    UpIn: PipelineIO,
287    DownIn: PipelineIO,
288    DownOut: PipelineIO,
289    UpOut: PipelineIO,
290{
291    async fn on_next(&self, data: UpOut, token: private::Token) -> Result<(), Error> {
292        self.parent.upstream.on_next(data, token).await
293    }
294
295    fn set_edge(&self, edge: Edge<UpOut>, token: private::Token) -> Result<(), PipelineError> {
296        self.parent.upstream.set_edge(edge, token)
297    }
298}
299
300pub struct PipelineNode<In: PipelineIO, Out: PipelineIO> {
301    edge: OnceLock<Edge<Out>>,
302    map_fn: NodeFn<In, Out>,
303}
304
305impl<In: PipelineIO, Out: PipelineIO> PipelineNode<In, Out> {
306    pub fn new(map_fn: NodeFn<In, Out>) -> Arc<Self> {
307        Arc::new(PipelineNode::<In, Out> {
308            edge: OnceLock::new(),
309            map_fn,
310        })
311    }
312}
313
314#[async_trait]
315impl<In: PipelineIO, Out: PipelineIO> Source<Out> for PipelineNode<In, Out> {
316    async fn on_next(&self, data: Out, _: private::Token) -> Result<(), Error> {
317        self.edge
318            .get()
319            .ok_or(PipelineError::NoEdge)?
320            .write(data)
321            .await
322    }
323
324    fn set_edge(&self, edge: Edge<Out>, _: private::Token) -> Result<(), PipelineError> {
325        self.edge
326            .set(edge)
327            .map_err(|_| PipelineError::EdgeAlreadySet)?;
328
329        Ok(())
330    }
331}
332
333#[async_trait]
334impl<In: PipelineIO, Out: PipelineIO> Sink<In> for PipelineNode<In, Out> {
335    async fn on_data(&self, data: In, _: private::Token) -> Result<(), Error> {
336        self.on_next((self.map_fn)(data)?, private::Token).await
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::*;

    /// Calling `generate` on a frontend that was never linked to anything downstream must return
    /// an error.
    #[tokio::test]
    async fn test_pipeline_source_no_edge() {
        let frontend = ServiceFrontend::<SingleIn<()>, ManyOut<()>>::new();
        let outcome = frontend.generate(().into()).await;
        assert!(outcome.is_err());
    }
}