triton_distributed/transports/nats.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
//! NATS transport
//!
//! The following environment variables are used to configure the NATS client:
//!
//! - `NATS_SERVER`: the NATS server address. Defaults to `nats://localhost:4222` when unset;
//!   the address must start with `nats://`.
//!
//! For authentication, the following environment variables are checked in the order listed
//! below, and the first authentication method whose variables are set is used:
//!
//! - `NATS_AUTH_USERNAME`: the username for authentication
//! - `NATS_AUTH_PASSWORD`: the password for authentication
//! - `NATS_AUTH_TOKEN`: the token for authentication
//! - `NATS_AUTH_NKEY`: the nkey for authentication
//! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file
//!
//! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together; if only one of
//! them is set, both are ignored. If none of the authentication variables are set, the client
//! falls back to username `user` and password `user`.
31use crate::Result;
32
33use async_nats::{client, jetstream, Subscriber};
34use derive_builder::Builder;
35use futures::TryStreamExt;
36use std::path::PathBuf;
37use validator::{Validate, ValidationError};
38
39mod slug;
40pub use slug::Slug;
41
42#[derive(Clone)]
43pub struct Client {
44 client: client::Client,
45 js_ctx: jetstream::Context,
46}
47
48impl Client {
49 /// Create a NATS [`ClientOptionsBuilder`].
50 pub fn builder() -> ClientOptionsBuilder {
51 ClientOptionsBuilder::default()
52 }
53
54 /// Returns a reference to the underlying [`async_nats::client::Client`] instance
55 pub fn client(&self) -> &client::Client {
56 &self.client
57 }
58
59 /// Returns a reference to the underlying [`async_nats::jetstream::Context`] instance
60 pub fn jetstream(&self) -> &jetstream::Context {
61 &self.js_ctx
62 }
63
64 /// fetch the list of streams
65 pub async fn list_streams(&self) -> Result<Vec<String>> {
66 let names = self.js_ctx.stream_names();
67 let stream_names: Vec<String> = names.try_collect().await?;
68 Ok(stream_names)
69 }
70
71 /// fetch the list of consumers for a given stream
72 pub async fn list_consumers(&self, stream_name: &str) -> Result<Vec<String>> {
73 let stream = self.js_ctx.get_stream(stream_name).await?;
74 let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
75 Ok(consumers)
76 }
77
78 pub async fn stream_info(&self, stream_name: &str) -> Result<jetstream::stream::State> {
79 let mut stream = self.js_ctx.get_stream(stream_name).await?;
80 let info = stream.info().await?;
81 Ok(info.state.clone())
82 }
83
84 pub async fn get_stream(&self, name: &str) -> Result<jetstream::stream::Stream> {
85 let stream = self.js_ctx.get_stream(name).await?;
86 Ok(stream)
87 }
88
89 pub async fn service_subscriber(&self, service_name: &str) -> Result<Subscriber> {
90 let subject = format!("$SRV.STATS.{}", service_name);
91 let reply_subject = format!("_INBOX.{}", nuid::next());
92 let subscription = self.client.subscribe(reply_subject.clone()).await?;
93
94 // Publish the request with the reply-to subject
95 self.client
96 .publish_with_reply(subject, reply_subject, "".into())
97 .await?;
98
99 // // Set a timeout to gather responses
100 // let mut responses = Vec::new();
101 // // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
102
103 // let start = time::Instant::now();
104 // while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
105 // tx.send(message.payload);
106 // if start.elapsed() > timeout {
107 // break;
108 // }
109 // }
110
111 // Ok(responses)
112
113 Ok(subscription)
114 }
115
116 // /// create a new stream
117 // async fn get_or_create_work_queue_stream(
118 // &self,
119 // name: &super::Namespace,
120 // ) -> Result<jetstream::stream::Stream> {
121 // let stream = self
122 // .js_ctx
123 // .get_or_create_stream(async_nats::jetstream::stream::Config {
124 // name: name.to_string(),
125 // retention: async_nats::jetstream::stream::RetentionPolicy::WorkQueue,
126 // subjects: vec![format!("{name}.>")],
127 // ..Default::default()
128 // })
129 // .await?;
130 // Ok(stream)
131 // }
132
133 // // get work queue
134 // pub async fn get_or_create_work_queue(
135 // &self,
136 // namespace: &super::Namespace,
137 // queue_name: &Slug,
138 // ) -> Result<WorkQueue> {
139 // let stream = self.get_or_create_work_queue_stream(namespace).await?;
140
141 // let consumer_name = single_name(namespace, queue_name);
142 // let subject_name = subject_name(namespace, queue_name);
143 // let subject_name = format!("{}.*", subject_name);
144
145 // tracing::trace!(
146 // durable_name = consumer_name,
147 // filter_subject = subject_name,
148 // "get_or_create_work_queue"
149 // );
150 // let consumer = stream
151 // .get_or_create_consumer(
152 // &consumer_name,
153 // jetstream::consumer::pull::Config {
154 // durable_name: Some(consumer_name.clone()),
155 // filter_subject: subject_name,
156 // ack_policy: jetstream::consumer::AckPolicy::Explicit,
157 // ..Default::default()
158 // },
159 // )
160 // .await?;
161 // Ok(WorkQueue::new(consumer))
162 // }
163
164 // pub async fn get_or_create_work_queue_publisher(
165 // &self,
166 // namespace: &super::Namespace,
167 // queue_name: &Slug,
168 // ) -> Result<WorkQueuePublisher> {
169 // let _stream = self.get_or_create_work_queue_stream(namespace).await?;
170 // let _subject = subject_name(namespace, queue_name);
171 // Ok(WorkQueuePublisher {
172 // client: self.clone(),
173 // namespace: namespace.clone(),
174 // queue_name: queue_name.clone(),
175 // })
176 // }
177
178 // pub async fn list_work_queues(
179 // &self,
180 // namespace: &super::Namespace,
181 // ) -> Result<Vec<String>> {
182 // let stream = self.get_stream(namespace.as_ref()).await?;
183 // let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
184 // Ok(consumers)
185 // }
186
187 // /// remove a work queue
188 // pub async fn remove_work_queue(
189 // &self,
190 // namespace: &super::Namespace,
191 // queue_name: &Slug,
192 // ) -> Result<()> {
193 // let stream = self.get_stream(namespace.as_ref()).await?;
194 // let consumer_name = single_name(namespace, queue_name);
195 // let consumers = self.list_consumers(namespace.as_ref()).await?;
196 // if consumers.contains(&consumer_name) {
197 // stream.delete_consumer(&consumer_name).await?;
198 // }
199 // Ok(())
200 // }
201
202 // /// publish a message to a subject
203 // pub async fn publish(&self, subject: String, msg: Vec<u8>) -> Result<()> {
204 // self.client.publish(subject, msg.into()).await?;
205 // Ok(())
206 // }
207
208 // /// subscribe to a subject
209 // pub async fn subscribe(
210 // &self,
211 // subject: String,
212 // ) -> Result<async_nats::Subscriber> {
213 // let sub = self.client.subscribe(subject).await?;
214 // Ok(sub)
215 // }
216
217 // pub async fn enqueue(
218 // &self,
219 // namespace: &super::Namespace,
220 // queue_name: &Slug,
221 // payload: Bytes,
222 // ) -> Result<String> {
223 // // let mut headers = HeaderMap::new();
224 // let subject = subject_name(namespace, queue_name);
225 // let request_id = uuid::Uuid::new_v4().to_string();
226 // let subject = format!("{}.{}", subject, request_id);
227
228 // self.client.publish(subject, payload).await?;
229
230 // // self.client
231 // // .publish_with_headers(subject, headers, payload.into())
232 // // .await?;
233
234 // Ok(request_id)
235 // }
236
237 // pub async fn enqueue_with_id(
238 // &self,
239 // namespace: &super::Namespace,
240 // queue_name: &Slug,
241 // request_id: &str,
242 // payload: Vec<u8>,
243 // ) -> Result<()> {
244 // let subject = subject_name(namespace, queue_name);
245 // let subject = format!("{}.{}", subject, request_id);
246
247 // self.client.publish(subject, payload.into()).await?;
248 // Ok(())
249 // }
250
251 // pub async fn get_endpoints(
252 // &self,
253 // service_name: &str,
254 // timeout: Duration,
255 // ) -> Result<Vec<Bytes>, anyhow::Error> {
256 // let subject = format!("$SRV.STATS.{}", service_name);
257 // let reply_subject = format!("_INBOX.{}", nuid::next());
258 // let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
259
260 // // Publish the request with the reply-to subject
261 // self.client
262 // .publish_with_reply(subject, reply_subject, "".into())
263 // .await?;
264
265 // // Set a timeout to gather responses
266 // let mut responses = Vec::new();
267 // // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
268
269 // let start = time::Instant::now();
270 // while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
271 // responses.push(message.payload);
272 // if start.elapsed() > timeout {
273 // break;
274 // }
275 // }
276
277 // Ok(responses)
278 // }
279
280 // pub fn frontend_client(&self, request_id: String) -> SpecializedClient {
281 // SpecializedClient::new(self.client.clone(), ClientKind::Frontend, request_id)
282 // }
283
284 // pub fn backend_client(&self, request_id: String) -> SpecializedClient {
285 // SpecializedClient::new(self.client.clone(), ClientKind::Backend, request_id)
286 // }
287}
288
289/// NATS client options
290///
291/// This object uses the builder pattern with default values that are evaluates
292/// from the environment variables if they are not explicitly set by the builder.
293#[derive(Debug, Clone, Builder, Validate)]
294pub struct ClientOptions {
295 #[builder(setter(into), default = "default_server()")]
296 #[validate(custom(function = "validate_nats_server"))]
297 server: String,
298
299 #[builder(default)]
300 auth: NatsAuth,
301}
302
/// Resolve the default NATS server address.
///
/// Uses the `NATS_SERVER` environment variable when it is set to valid Unicode,
/// and `nats://localhost:4222` otherwise.
fn default_server() -> String {
    std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string())
}
310
311fn validate_nats_server(server: &str) -> Result<(), ValidationError> {
312 if server.starts_with("nats://") {
313 Ok(())
314 } else {
315 Err(ValidationError::new("server must start with 'nats://'"))
316 }
317}
318
319#[allow(dead_code)]
320impl ClientOptions {
321 /// Create a new [`ClientOptionsBuilder`]
322 pub fn builder() -> ClientOptionsBuilder {
323 ClientOptionsBuilder::default()
324 }
325
326 /// Validate the config and attempt to connection to the NATS server
327 pub async fn connect(self) -> Result<Client> {
328 self.validate()?;
329
330 let client = match self.auth {
331 NatsAuth::UserPass(username, password) => {
332 async_nats::ConnectOptions::with_user_and_password(username, password)
333 }
334 NatsAuth::Token(token) => async_nats::ConnectOptions::with_token(token),
335 NatsAuth::NKey(nkey) => async_nats::ConnectOptions::with_nkey(nkey),
336 NatsAuth::CredentialsFile(path) => {
337 async_nats::ConnectOptions::with_credentials_file(path).await?
338 }
339 };
340
341 let client = client.connect(self.server).await?;
342 let js_ctx = jetstream::new(client.clone());
343
344 Ok(Client { client, js_ctx })
345 }
346}
347
348impl Default for ClientOptions {
349 fn default() -> Self {
350 ClientOptions {
351 server: default_server(),
352 auth: NatsAuth::default(),
353 }
354 }
355}
356
/// Authentication method used when connecting to the NATS server.
///
/// The `Default` implementation picks a variant from the `NATS_AUTH_*`
/// environment variables. The `Debug` implementation redacts secret values.
#[derive(Clone, Eq, PartialEq)]
pub enum NatsAuth {
    /// Username and password, in that order.
    UserPass(String, String),
    /// Authentication token.
    Token(String),
    /// NKey, passed as-is to `async_nats::ConnectOptions::with_nkey`.
    NKey(String),
    /// Path to a NATS credentials file.
    CredentialsFile(PathBuf),
}
364
365impl std::fmt::Debug for NatsAuth {
366 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 match self {
368 NatsAuth::UserPass(user, _pass) => {
369 write!(f, "UserPass({}, <redacted>)", user)
370 }
371 NatsAuth::Token(_token) => write!(f, "Token(<redacted>)"),
372 NatsAuth::NKey(_nkey) => write!(f, "NKey(<redacted>)"),
373 NatsAuth::CredentialsFile(path) => write!(f, "CredentialsFile({:?})", path),
374 }
375 }
376}
377
378impl Default for NatsAuth {
379 fn default() -> Self {
380 if let (Ok(username), Ok(password)) = (
381 std::env::var("NATS_AUTH_USERNAME"),
382 std::env::var("NATS_AUTH_PASSWORD"),
383 ) {
384 return NatsAuth::UserPass(username, password);
385 }
386
387 if let Ok(token) = std::env::var("NATS_AUTH_TOKEN") {
388 return NatsAuth::Token(token);
389 }
390
391 if let Ok(nkey) = std::env::var("NATS_AUTH_NKEY") {
392 return NatsAuth::NKey(nkey);
393 }
394
395 if let Ok(path) = std::env::var("NATS_AUTH_CREDENTIALS_FILE") {
396 return NatsAuth::CredentialsFile(PathBuf::from(path));
397 }
398
399 NatsAuth::UserPass("user".to_string(), "user".to_string())
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use figment::Jail;
    // Explicit import so `.validate()` resolves regardless of glob-import rules.
    use validator::Validate;

    #[test]
    fn test_client_options_builder() {
        // With no explicit input, every field resolves to its default.
        Jail::expect_with(|_jail| {
            let opts = ClientOptions::builder().build();
            assert!(opts.is_ok());
            Ok(())
        });

        // Environment variables supply the defaults.
        Jail::expect_with(|jail| {
            jail.set_env("NATS_SERVER", "nats://localhost:5222");
            jail.set_env("NATS_AUTH_USERNAME", "user");
            jail.set_env("NATS_AUTH_PASSWORD", "pass");

            let opts = ClientOptions::builder().build();
            assert!(opts.is_ok());
            let opts = opts.unwrap();

            assert_eq!(opts.server, "nats://localhost:5222");
            assert_eq!(
                opts.auth,
                NatsAuth::UserPass("user".to_string(), "pass".to_string())
            );

            Ok(())
        });

        // Explicit builder values take precedence over environment variables.
        Jail::expect_with(|jail| {
            jail.set_env("NATS_SERVER", "nats://localhost:5222");
            jail.set_env("NATS_AUTH_USERNAME", "user");
            jail.set_env("NATS_AUTH_PASSWORD", "pass");

            let opts = ClientOptions::builder()
                .server("nats://localhost:6222")
                .auth(NatsAuth::Token("token".to_string()))
                .build();
            assert!(opts.is_ok());
            let opts = opts.unwrap();

            assert_eq!(opts.server, "nats://localhost:6222");
            assert_eq!(opts.auth, NatsAuth::Token("token".to_string()));

            Ok(())
        });
    }

    #[test]
    fn test_validate_server_scheme() {
        // Both fields are set explicitly, so no environment variables are read.
        let opts = ClientOptions::builder()
            .server("http://localhost:4222")
            .auth(NatsAuth::Token("token".to_string()))
            .build()
            .expect("the builder itself does not validate the server");
        assert!(opts.validate().is_err());

        let opts = ClientOptions::builder()
            .server("nats://localhost:4222")
            .auth(NatsAuth::Token("token".to_string()))
            .build()
            .expect("the builder itself does not validate the server");
        assert!(opts.validate().is_ok());
    }

    #[test]
    fn test_nats_auth_debug_redacts_secrets() {
        let rendered = format!(
            "{:?}",
            NatsAuth::UserPass("alice".to_string(), "hunter2".to_string())
        );
        assert_eq!(rendered, "UserPass(alice, <redacted>)");

        let rendered = format!("{:?}", NatsAuth::Token("s3cret".to_string()));
        assert!(!rendered.contains("s3cret"));

        let rendered = format!("{:?}", NatsAuth::NKey("SUAseed".to_string()));
        assert!(!rendered.contains("SUAseed"));
    }

    // NOTE(review): integration tests against a live NATS server (stream listing,
    // work-queue round trips) used to live here as commented-out code written
    // against a builder API that no longer exists (`.username()`, `build().await`).
    // They were removed; recover them from version control when that API is ported.
}
593// let frontend_client = client.frontend_client("test".to_string());
594
595// // the represents the frontend response subscription
596// let mut frontend_sub = frontend_client
597// .subscribe()
598// .await
599// .expect("failed to subscribe");
600
601// let backend_client = client.backend_client("test".to_string());
602
603// let mut backend_sub = backend_client
604// .subscribe()
605// .await
606// .expect("failed to subscribe");
607
608// let msg = messages[0].clone();
609// let req = serde_json::from_slice::<CompletionRequest>(&msg.payload)
610// .expect("failed to deserialize message");
611
612// msg.ack().await.expect("failed to ack");
613
614// assert_eq!(req.prompt, request.prompt);
615
616// // ping pong message between backend and frontend
617
618// // backend publishes to frontend
619// backend_client
620// .publish(&MessageKind::Initialize(Prologue {
621// formatted_prompt: None,
622// input_token_ids: None,
623// }))
624// .await
625// .expect("failed to publish");
626
627// // frontend receives initialize message
628// let msg = frontend_sub.next().await.expect("msg not received");
629// let msg = serde_json::from_slice::<MessageKind>(&msg.payload)
630// .expect("failed to deserialize message");
631
632// match msg {
633// MessageKind::Initialize(_) => {}
634// _ => panic!("unexpected message"),
635// }
636
637// // frontend publishes to backend
638// frontend_client
639// .publish(&MessageKind::Finalize(Epilogue {}))
640// .await
641// .expect("failed to publish");
642
643// // backend receives finalize message
644// let msg = backend_sub.next().await.expect("msg not received");
645// let msg = serde_json::from_slice::<MessageKind>(&msg.payload)
646// .expect("failed to deserialize message");
647
648// match &msg {
649// MessageKind::Finalize(_) => {}
650// _ => panic!("unexpected message"),
651// }
652
653// // delete the work queue
654// client
655// .remove_work_queue(model_name, TEST_STREAM)
656// .await
657// .expect("failed to remove work queue");
658
659// // new work queue count
660// let work_queue_count = client
661// .list_work_queues(TEST_STREAM)
662// .await
663// .expect("failed to list work queues")
664// .len();
665
666// // compare against the initial work queue count
667// assert_eq!(initial_work_queue_count, work_queue_count);
668// }
669
670// pub async fn get_endpoints(
671// &self,
672// service_name: &str,
673// timeout: Duration,
674// ) -> Result<Vec<Bytes>, anyhow::Error> {
675// let subject = format!("$SRV.STATS.{}", service_name);
676// let reply_subject = format!("_INBOX.{}", nuid::next());
677// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
678
679// // Publish the request with the reply-to subject
680// self.client
681// .publish_with_reply(subject, reply_subject, "".into())
682// .await?;
683
684// // Set a timeout to gather responses
685// let mut responses = Vec::new();
686// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
687
688// let start = time::Instant::now();
689// while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
690// responses.push(message.payload);
691// if start.elapsed() > timeout {
692// break;
693// }
694// }
695
696// Ok(responses)
697// }
698
699// async fn connect(config: Arc<Config>) -> Result<NatsClient> {
700// let client = ClientOptions::builder()
701// .server(config.nats_address.clone())
702// .build()
703// .await
704// .context("Creating NATS Client")?;
705
706// Ok(client)
707// }
708
709// async fn create_service(
710// nats: NatsClient,
711// config: Arc<Config>,
712// observer: ServiceObserver,
713// ) -> Result<NatsService> {
714// let service = nats
715// .client()
716// .service_builder()
717// .description(config.service_description.as_str())
718// .stats_handler(move |_name, _stats| {
719// let stats = InstanceStats {
720// stage: observer.stage(),
721// };
722// serde_json::to_value(&stats).unwrap()
723// })
724// .start(
725// config.service_name.as_str(),
726// config.service_version.as_str(),
727// )
728// .await
729// .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
730// Ok(service)
731// }
732
733// async fn create_endpoint(
734// endpoint_name: impl Into<String>,
735// service: &NatsService,
736// ) -> Result<Endpoint> {
737// let info = service.info().await;
738// let group_name = format!("{}-{}", info.name, info.id);
739// let group = service.group(group_name);
740
741// let endpoint = group
742// .endpoint(endpoint_name.into())
743// .await
744// .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
745
746// Ok(endpoint)
747// }
748
749// async fn shutdown_endpoint_handler(
750// controller: ServiceController,
751// endpoint: Endpoint,
752// ) -> Result<()> {
753// let mut endpoint = endpoint;
754
755// // note: this is a child cancellation token, canceling it will not cancel the parent
756// // but the parent will cancel the child -- we only use this to observe if another
757// // controller has cancelled the service
758// let cancellation_token = controller.cancel_token();
759
760// loop {
761// let req = tokio::select! {
762// _ = cancellation_token.cancelled() => {
763// // log::trace!(worker_id, "Shutting down service {}", self.endpoint.name);
764// return Ok(());
765// }
766
767// // await on service request
768// req = endpoint.next() => {
769// req
770// }
771// };
772
773// if let Some(req) = req {
774// let response = "DONE".to_string();
775// if let Err(e) = req.respond(Ok(response.into())).await {
776// log::warn!("Failed to respond to the shutdown request: {:?}", e);
777// }
778
779// controller.set_stage(ServiceStage::ShuttingDown);
780// }
781// }
782// }
783
784// #[derive(Debug, Clone, Builder)]
785// pub struct Config {
786// /// The NATS server address
787// #[builder(default = "String::from(\"nats://localhost:4222\")")]
788// pub nats_address: String,
789
790// #[builder(setter(into), default = "String::from(SERVICE_NAME)")]
791// pub service_name: String,
792
793// #[builder(setter(into), default = "String::from(SERVICE_VERSION)")]
794// pub service_version: String,
795
796// #[builder(setter(into), default = "String::from(SERVICE_DESCRIPTION)")]
797// pub service_description: String,
798// }
799
800// impl Config {
801// pub fn new() -> Result<Config> {
802// Ok(ConfigBuilder::default().build()?)
803// }
804
805// /// Create a new [`ConfigBuilder`]
806// pub fn builder() -> ConfigBuilder {
807// ConfigBuilder::default()
808// }
809// }
810
811// // todo: move to icp - transports
812
813// #[derive(Clone, Debug)]
814// pub struct NatsClient {
815// client: Client,
816// js_ctx: jetstream::Context,
817// }
818
819// impl NatsClient {
820// pub fn client(&self) -> &Client {
821// &self.client
822// }
823
824// pub fn jetstream(&self) -> &jetstream::Context {
825// &self.js_ctx
826// }
827
828// pub fn service_builder(&self) -> NatsServiceBuilder {
829// self.client.service_builder()
830// }
831
832// pub async fn get_endpoints(
833// &self,
834// service_name: &str,
835// timeout: Duration,
836// ) -> Result<Vec<Bytes>, anyhow::Error> {
837// let subject = format!("$SRV.STATS.{}", service_name);
838// let reply_subject = format!("_INBOX.{}", nuid::next());
839// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
840
841// // Publish the request with the reply-to subject
842// self.client
843// .publish_with_reply(subject, reply_subject, "".into())
844// .await?;
845
846// // Set a timeout to gather responses
847// let mut responses = Vec::new();
848// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
849
850// let start = tokio::time::Instant::now();
851// while let Ok(Some(message)) = tokio::time::timeout(timeout, subscription.next()).await {
852// responses.push(message.payload);
853// if start.elapsed() > timeout {
854// break;
855// }
856// }
857
858// Ok(responses)
859// }
860// }
861
862// #[derive(Debug, Clone, Serialize, Deserialize)]
863// pub struct ServiceInfo {
864// pub name: String,
865// pub id: String,
866// pub version: String,
867// pub started: String,
868// pub endpoints: Vec<EndpointInfo>,
869// }
870
871// #[derive(Debug, Clone, Serialize, Deserialize)]
872// pub struct EndpointInfo {
873// pub name: String,
874// pub subject: String,
875// pub data: serde_json::Value,
876// }
877
878// impl EndpointInfo {
879// pub fn get<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
880// serde_json::from_value(self.data.clone()).map_err(Into::into)
881// }
882// }
883
884// #[derive(Clone, Debug, Builder)]
885// #[builder(build_fn(private, name = "build_internal"))]
886// pub struct ClientOptions {
887// #[builder(setter(into))]
888// server: String,
889
890// #[builder(setter(into, strip_option), default)]
891// username: Option<String>,
892
893// #[builder(setter(into, strip_option), default)]
894// password: Option<String>,
895// }
896
897// #[allow(dead_code)]
898// impl ClientOptions {
899// pub fn builder() -> ClientOptionsBuilder {
900// ClientOptionsBuilder::default()
901// }
902// }
903
904// impl ClientOptionsBuilder {
905// pub async fn build(&self) -> Result<NatsClient> {
906// let opts = self.build_internal()?;
907
908// // Create an unauthenticated connection to NATS.
909// let client = async_nats::ConnectOptions::new();
910
911// let client = if let (Some(username), Some(password)) = (opts.username, opts.password) {
912// client.user_and_password(username, password)
913// } else {
914// client
915// };
916
917// let client = client.connect(&opts.server).await?;
918
919// let js_ctx = jetstream::new(client.clone());
920
921// Ok(NatsClient { client, js_ctx })
922// }
923// }