triton_distributed/pipeline/network/ingress/push_endpoint.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16use super::*;
17use anyhow::Result;
18use async_nats::service::endpoint::Endpoint;
19use derive_builder::Builder;
20use tokio_util::sync::CancellationToken;
21use tracing as log;
22
23#[derive(Builder)]
24pub struct PushEndpoint {
25    pub service_handler: Arc<dyn PushWorkHandler>,
26    pub cancellation_token: CancellationToken,
27}
28
29/// version of crate
30pub const VERSION: &str = env!("CARGO_PKG_VERSION");
31
32impl PushEndpoint {
33    pub fn builder() -> PushEndpointBuilder {
34        PushEndpointBuilder::default()
35    }
36
37    pub async fn start(self, endpoint: Endpoint) -> Result<()> {
38        let mut endpoint = endpoint;
39
40        loop {
41            let req = tokio::select! {
42                biased;
43
44                // await on service request
45                req = endpoint.next() => {
46                    req
47                }
48
49                // process shutdown
50                _ = self.cancellation_token.cancelled() => {
51                    // log::trace!(worker_id, "Shutting down service {}", self.endpoint.name);
52                    if let Err(e) = endpoint.stop().await {
53                        log::warn!("Failed to stop NATS service: {:?}", e);
54                    }
55                    break;
56                }
57            };
58
59            if let Some(req) = req {
60                let response = "".to_string();
61                if let Err(e) = req.respond(Ok(response.into())).await {
62                    log::warn!("Failed to respond to request; this may indicate the request has shutdown: {:?}", e);
63                }
64
65                let ingress = self.service_handler.clone();
66                let worker_id = "".to_string();
67                tokio::spawn(async move {
68                    log::trace!(worker_id, "handling new request");
69                    let result = ingress.handle_payload(req.message.payload).await;
70                    log::trace!(worker_id, "request handled: {:?}", result);
71                });
72            } else {
73                break;
74            }
75        }
76
77        Ok(())
78    }
79}