triton_distributed/transports/etcd.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{error, log, CancellationToken, ErrorContext, Result, Runtime};
17
18use async_nats::jetstream::kv;
19use derive_builder::Builder;
20use derive_getters::Dissolve;
21use futures::StreamExt;
22use tokio::sync::mpsc;
23use validator::Validate;
24
25use etcd_client::{
26    Compare, CompareOp, GetOptions, KeyValue, PutOptions, Txn, TxnOp, WatchOptions, Watcher,
27};
28
29pub use etcd_client::{ConnectOptions, LeaseClient};
30
31mod lease;
32use lease::*;
33
34//pub use etcd::ConnectOptions as EtcdConnectOptions;
35
36/// ETCD Client
37#[derive(Clone)]
38pub struct Client {
39    client: etcd_client::Client,
40    primary_lease: i64,
41    runtime: Runtime,
42}
43
44#[derive(Debug, Clone)]
45pub struct Lease {
46    /// ETCD lease ID
47    id: i64,
48
49    /// [`CancellationToken`] associated with the lease
50    cancel_token: CancellationToken,
51}
52
53impl Lease {
54    /// Get the lease ID
55    pub fn id(&self) -> i64 {
56        self.id
57    }
58
59    /// Get the primary [`CancellationToken`] associated with the lease.
60    /// This token will revoke the lease if canceled.
61    pub fn primary_token(&self) -> CancellationToken {
62        self.cancel_token.clone()
63    }
64
65    /// Get a child [`CancellationToken`] from the lease's [`CancellationToken`].
66    /// This child token will be triggered if the lease is revoked, but will not revoke the lease if canceled.
67    pub fn child_token(&self) -> CancellationToken {
68        self.cancel_token.child_token()
69    }
70
71    /// Revoke the lease triggering the [`CancellationToken`].
72    pub fn revoke(&self) {
73        self.cancel_token.cancel();
74    }
75}
76
77impl Client {
78    pub fn builder() -> ClientOptionsBuilder {
79        ClientOptionsBuilder::default()
80    }
81
82    /// Create a new discovery client
83    ///
84    /// This will establish a connection to the etcd server, create a primary lease,
85    /// and spawn a task to keep the lease alive and tie the lifetime of the [`Runtime`]
86    /// to the lease.
87    ///
88    /// If the lease expires, the [`Runtime`] will be shutdown.
89    /// If the [`Runtime`] is shutdown, the lease will be revoked.
90    pub async fn new(config: ClientOptions, runtime: Runtime) -> Result<Self> {
91        runtime
92            .secondary()
93            .spawn(Self::create(config, runtime.clone()))
94            .await?
95    }
96
97    /// Create a new etcd client and tie the primary [`CancellationToken`] to the primary etcd lease.
98    async fn create(config: ClientOptions, runtime: Runtime) -> Result<Self> {
99        let token = runtime.primary_token();
100        let client =
101            etcd_client::Client::connect(config.etcd_url, config.etcd_connect_options).await?;
102        let lease_client = client.lease_client();
103
104        let lease = create_lease(lease_client, 10, token)
105            .await
106            .context("creating primary lease")?;
107
108        Ok(Client {
109            client,
110            primary_lease: lease.id,
111            runtime,
112        })
113    }
114
115    /// Get a reference to the underlying [`etcd_client::Client`] instance.
116    pub fn etcd_client(&self) -> &etcd_client::Client {
117        &self.client
118    }
119
120    /// Get the primary lease ID.
121    pub fn lease_id(&self) -> i64 {
122        self.primary_lease
123    }
124
125    /// Primary [`Lease`]
126    pub fn primary_lease(&self) -> Lease {
127        Lease {
128            id: self.primary_lease,
129            cancel_token: self.runtime.primary_token(),
130        }
131    }
132
133    /// Create a [`Lease`] with a given time-to-live (TTL).
134    /// This [`Lease`] will be tied to the [`Runtime`], specifically a child [`CancellationToken`].
135    pub async fn create_lease(&self, ttl: i64) -> Result<Lease> {
136        let token = self.runtime.child_token();
137        let lease_client = self.client.lease_client();
138        self.runtime
139            .secondary()
140            .spawn(create_lease(lease_client, ttl, token))
141            .await?
142    }
143
144    pub async fn kv_create(
145        &self,
146        key: String,
147        value: Vec<u8>,
148        lease_id: Option<i64>,
149    ) -> Result<()> {
150        let put_options = lease_id.map(|id| PutOptions::new().with_lease(id));
151
152        // Build the transaction
153        let txn = Txn::new()
154            .when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Ensure the lock does not exist
155            .and_then(vec![
156                TxnOp::put(key.as_str(), value, put_options), // Create the object
157            ]);
158
159        // Execute the transaction
160        let _ = self.client.kv_client().txn(txn).await?;
161
162        Ok(())
163    }
164
165    pub async fn kv_get_prefix(&self, prefix: impl AsRef<str>) -> Result<Vec<KeyValue>> {
166        let mut get_response = self
167            .client
168            .kv_client()
169            .get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
170            .await?;
171
172        Ok(get_response.take_kvs())
173    }
174
175    pub async fn kv_get_and_watch_prefix(&self, prefix: impl AsRef<str>) -> Result<PrefixWatcher> {
176        let mut kv_client = self.client.kv_client();
177        let mut watch_client = self.client.watch_client();
178
179        let mut get_response = kv_client
180            .get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
181            .await?;
182
183        let start_revision = get_response
184            .header()
185            .ok_or(error!("missing header; unable to get revision"))?
186            .revision();
187
188        let (watcher, mut watch_stream) = watch_client
189            .watch(
190                prefix.as_ref(),
191                Some(
192                    WatchOptions::new()
193                        .with_prefix()
194                        .with_start_revision(start_revision),
195                ),
196            )
197            .await?;
198
199        let kvs = get_response.take_kvs();
200
201        let (tx, rx) = mpsc::channel(32);
202
203        self.runtime.secondary().spawn(async move {
204            for kv in kvs {
205                if tx.send(WatchEvent::Put(kv)).await.is_err() {
206                    // receiver is closed
207                    break;
208                }
209            }
210
211            while let Some(Ok(response)) = watch_stream.next().await {
212                for event in response.events() {
213                    match event.event_type() {
214                        etcd_client::EventType::Put => {
215                            if let Some(kv) = event.kv() {
216                                if tx.send(WatchEvent::Put(kv.clone())).await.is_err() {
217                                    // receiver is closed
218                                    break;
219                                }
220                            }
221                        }
222                        etcd_client::EventType::Delete => {
223                            if let Some(kv) = event.kv() {
224                                if tx.send(WatchEvent::Delete(kv.clone())).await.is_err() {
225                                    // receiver is closed
226                                    break;
227                                }
228                            }
229                        }
230                    }
231                }
232            }
233        });
234
235        Ok(PrefixWatcher {
236            prefix: prefix.as_ref().to_string(),
237            watcher,
238            rx,
239        })
240    }
241}
242
243#[derive(Dissolve)]
244pub struct PrefixWatcher {
245    prefix: String,
246    watcher: Watcher,
247    rx: mpsc::Receiver<WatchEvent>,
248}
249
250pub enum WatchEvent {
251    Put(KeyValue),
252    Delete(KeyValue),
253}
254
255/// ETCD client configuration options
256#[derive(Debug, Clone, Builder, Validate)]
257pub struct ClientOptions {
258    #[validate(length(min = 1))]
259    etcd_url: Vec<String>,
260
261    #[builder(default)]
262    etcd_connect_options: Option<ConnectOptions>,
263}
264
265impl Default for ClientOptions {
266    fn default() -> Self {
267        ClientOptions {
268            etcd_url: default_servers(),
269            etcd_connect_options: None,
270        }
271    }
272}
273
/// Default etcd endpoints.
///
/// Reads the comma-separated `ETCD_ENDPOINTS` environment variable. Each entry is trimmed
/// of surrounding whitespace and empty entries are dropped. If the variable is unset, is
/// not valid Unicode, or contains no non-empty entries, falls back to
/// `http://localhost:2379`.
fn default_servers() -> Vec<String> {
    const FALLBACK_URL: &str = "http://localhost:2379";

    let servers: Vec<String> = std::env::var("ETCD_ENDPOINTS")
        .map(|urls| {
            urls.split(',')
                .map(|url| url.trim())
                .filter(|url| !url.is_empty())
                .map(String::from)
                .collect()
        })
        .unwrap_or_default();

    // An empty list (e.g. `ETCD_ENDPOINTS=""`) would otherwise yield unusable endpoints.
    if servers.is_empty() {
        vec![FALLBACK_URL.to_string()]
    } else {
        servers
    }
}