triton_distributed/
service.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
// TODO - refactor this entire module
//
// We want to carry forward the concept of live vs. ready for the components.
// We will also want to associate each component's cancellation token with the
// component's "service state".
21
22use crate::{log, transports::nats, Result};
23
24use async_nats::Message;
25use async_stream::try_stream;
26use bytes::Bytes;
27use derive_getters::Dissolve;
28use futures::stream::StreamExt;
29use serde::{de::DeserializeOwned, Deserialize, Serialize};
30use std::time::Duration;
31
32pub struct ServiceClient {
33    nats_client: nats::Client,
34}
35
36impl ServiceClient {
37    #[allow(dead_code)]
38    pub(crate) fn new(nats_client: nats::Client) -> Self {
39        ServiceClient { nats_client }
40    }
41}
42
43pub struct ServiceSet {
44    services: Vec<ServiceInfo>,
45}
46
/// Description of one service instance, as decoded from a discovery reply payload.
// NOTE(review): the field set looks like the NATS micro-service `INFO` response — confirm.
// Field order and names define the serde wire format; do not reorder casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceInfo {
    /// Service name.
    pub name: String,
    /// Instance identifier (presumably distinguishes replicas sharing `name`).
    pub id: String,
    /// Version string reported by the service.
    pub version: String,
    /// Start time exactly as reported; kept as an unparsed string.
    pub started: String,
    /// Endpoints exposed by this instance.
    pub endpoints: Vec<EndpointInfo>,
}
55
/// One endpoint exposed by a service instance.
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct EndpointInfo {
    /// Endpoint name.
    pub name: String,
    /// NATS subject associated with this endpoint.
    pub subject: String,

    /// All remaining keys of the endpoint object, kept as raw JSON.
    /// `flatten` places these keys beside `name`/`subject` on the wire
    /// rather than under a nested `data` key.
    // NOTE(review): presumably per-endpoint stats/metadata; `Metrics::decode` gives a typed view.
    #[serde(flatten)]
    pub data: Metrics,
}
64
65#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
66pub struct Metrics(pub serde_json::Value);
67
68impl Metrics {
69    pub fn decode<T: DeserializeOwned>(self) -> Result<T> {
70        serde_json::from_value(self.0).map_err(Into::into)
71    }
72}
73
74impl ServiceClient {
75    pub async fn unary(
76        &self,
77        subject: impl Into<String>,
78        payload: impl Into<Bytes>,
79    ) -> Result<Message> {
80        let response = self
81            .nats_client
82            .client()
83            .request(subject.into(), payload.into())
84            .await?;
85        Ok(response)
86    }
87
88    pub async fn collect_services(&self, service_name: &str) -> Result<ServiceSet> {
89        let mut sub = self.nats_client.service_subscriber(service_name).await?;
90        let deadline = tokio::time::Instant::now() + Duration::from_secs(1);
91
92        let services: Vec<Result<ServiceInfo>> = try_stream! {
93            while let Ok(Some(message)) = tokio::time::timeout_at(deadline, sub.next()).await {
94                if message.payload.is_empty() {
95                    continue;
96                }
97                let service = serde_json::from_slice::<ServiceInfo>(&message.payload)?;
98                log::trace!("service: {:?}", service);
99                yield service;
100            }
101        }
102        .collect()
103        .await;
104
105        // split ok and error results
106        let (ok, err): (Vec<_>, Vec<_>) = services.into_iter().partition(Result::is_ok);
107
108        if !err.is_empty() {
109            log::error!("failed to collect services: {:?}", err);
110        }
111
112        Ok(ServiceSet {
113            services: ok.into_iter().map(Result::unwrap).collect(),
114        })
115    }
116}
117
118impl ServiceSet {
119    pub fn into_endpoints(self) -> impl Iterator<Item = EndpointInfo> {
120        self.services
121            .into_iter()
122            .flat_map(|s| s.endpoints.into_iter())
123    }
124}
125
#[cfg(test)]
mod tests {

    use super::*;

    /// Builds an endpoint whose metrics payload is `{"key": <value>}`.
    fn endpoint(name: &str, subject: &str, value: &str) -> EndpointInfo {
        EndpointInfo {
            name: name.to_string(),
            subject: subject.to_string(),
            data: Metrics(serde_json::json!({ "key": value })),
        }
    }

    /// Builds a `service1` instance with the given id and endpoints.
    fn service(id: &str, endpoints: Vec<EndpointInfo>) -> ServiceInfo {
        ServiceInfo {
            name: "service1".to_string(),
            id: id.to_string(),
            version: "1.0".to_string(),
            started: "2021-01-01".to_string(),
            endpoints,
        }
    }

    #[test]
    fn test_service_set() {
        let services = vec![
            service(
                "1",
                vec![
                    endpoint("endpoint1", "subject1", "value1"),
                    endpoint("endpoint2-foo", "subject2", "value1"),
                ],
            ),
            service(
                "2",
                vec![
                    endpoint("endpoint1", "subject1", "value1"),
                    endpoint("endpoint2-bar", "subject2", "value2"),
                ],
            ),
        ];

        let service_set = ServiceSet { services };

        let endpoints: Vec<_> = service_set
            .into_endpoints()
            .filter(|e| e.name.starts_with("endpoint2"))
            .collect();

        // Pin which endpoints survived and their order (service order, then
        // endpoint order), not just how many.
        let names: Vec<&str> = endpoints.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, ["endpoint2-foo", "endpoint2-bar"]);
        assert!(endpoints.iter().all(|e| e.subject == "subject2"));
    }

    #[test]
    fn test_into_endpoints_includes_every_service() {
        let service_set = ServiceSet {
            services: vec![
                service("1", vec![endpoint("a", "subject-a", "value1")]),
                // A service with no endpoints must contribute nothing.
                service("2", Vec::new()),
                service(
                    "3",
                    vec![
                        endpoint("b", "subject-b", "value2"),
                        endpoint("c", "subject-c", "value3"),
                    ],
                ),
            ],
        };

        let names: Vec<String> = service_set.into_endpoints().map(|e| e.name).collect();
        assert_eq!(names, ["a", "b", "c"]);
    }

    #[test]
    fn test_metrics_decode() {
        #[derive(Deserialize)]
        struct KeyOnly {
            key: String,
        }

        let decoded: KeyOnly = Metrics(serde_json::json!({ "key": "value2" }))
            .decode()
            .expect("metrics payload should decode into KeyOnly");
        assert_eq!(decoded.key, "value2");
    }
}