Skip to main content

zlayer_agent/
worker_client.rs

//! Worker-side gRPC client for `ZLayer`'s worker tier.
//!
//! Lifecycle:
//! 1. On startup: load the persisted mTLS identity (cert + key) if present;
//!    otherwise read the bootstrap token, generate a fresh EC P-256 keypair,
//!    build a PKCS#10 CSR, call `Register`, and persist the signed cert, the
//!    key, and the CA chain under `<data_dir>/worker/identity/`.
//! 2. Background loops (each spawned as its own task per session):
//!    - `WatchAssignments`: server-streaming; receive `AssignmentEvent`s and
//!      forward them via `assignment_tx` to the agent's executor.
//!    - `ReportStatus`: bidi-streaming; tick at `(next_ttl_secs - jitter)`
//!      sending a `StatusReport` snapshot; receive `StatusAck` and update
//!      `next_ttl_secs`.
//!    - `WatchCommands`: server-streaming; receive `CommandEvent`s and
//!      forward them via `command_tx`.
//! 3. On disconnect: exponential backoff capped at 60s; on reconnect send a
//!    full snapshot (`StatusReport.full_snapshot = true`).
//!
//! Implements `zlayer_scheduler::cluster::WorkerClient` so a `WorkerTierCluster`
//! in worker mode can route Cluster trait calls through this.
21
22use std::net::SocketAddr;
23use std::path::{Path, PathBuf};
24use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
25use std::sync::Arc;
26use std::time::{Duration, SystemTime, UNIX_EPOCH};
27
28use async_trait::async_trait;
29use tokio::sync::{mpsc, RwLock};
30use tokio_stream::wrappers::ReceiverStream;
31use tokio_stream::StreamExt;
32use tonic::transport::{Certificate, ClientTlsConfig, Endpoint, Identity};
33use tonic::Request;
34use zlayer_cluster_rpc::proto;
35
36/// Errors produced by the worker client.
37#[derive(Debug, thiserror::Error)]
38pub enum WorkerClientError {
39    /// No configured control-plane servers.
40    #[error("no control-plane servers configured")]
41    NoServers,
42    /// We have neither a bootstrap token nor a persisted identity, so we
43    /// cannot authenticate against the control plane.
44    #[error("no bootstrap token and no persisted identity available")]
45    NoCredentials,
46    /// Failed to build a tonic `Endpoint` from a configured server URL.
47    #[error("invalid endpoint {endpoint:?}: {source}")]
48    InvalidEndpoint {
49        endpoint: String,
50        #[source]
51        source: tonic::transport::Error,
52    },
53    /// TLS configuration failed (e.g. malformed PEM).
54    #[error("tls config error: {0}")]
55    Tls(tonic::transport::Error),
56    /// gRPC transport error (connection refused, peer closed, etc.).
57    #[error("transport error: {0}")]
58    Transport(tonic::transport::Error),
59    /// A gRPC call returned a `tonic::Status` error.
60    #[error("grpc status: {0}")]
61    Status(tonic::Status),
62    /// Crypto / CSR generation failed.
63    #[error("rcgen error: {0}")]
64    Rcgen(rcgen::Error),
65    /// Persisted identity I/O failed.
66    #[error("identity io error: {0}")]
67    Io(std::io::Error),
68}
69
70impl From<tonic::Status> for WorkerClientError {
71    fn from(s: tonic::Status) -> Self {
72        Self::Status(s)
73    }
74}
75
76impl From<std::io::Error> for WorkerClientError {
77    fn from(e: std::io::Error) -> Self {
78        Self::Io(e)
79    }
80}
81
82impl From<rcgen::Error> for WorkerClientError {
83    fn from(e: rcgen::Error) -> Self {
84        Self::Rcgen(e)
85    }
86}
87
/// The worker's mTLS credentials, as stored on disk.
///
/// All three members are PEM text. They are written by `persist_identity`
/// (each file restricted to mode 0600 on Unix) and read back by
/// `load_identity`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkerIdentity {
    /// Worker certificate signed by the control plane.
    pub cert_pem: String,
    /// The worker's EC P-256 private key.
    pub key_pem: String,
    /// Certificates linking the worker certificate to the cluster root.
    pub ca_chain_pem: String,
}
98
99/// Status snapshot provider for the worker. The agent's `ServiceManager`
100/// implements this so the worker client can report current container state
101/// + resource usage on each heartbeat tick.
102#[async_trait]
103pub trait WorkerStatusProvider: Send + Sync + std::fmt::Debug {
104    /// Snapshot the worker's current container states for `ReportStatus`.
105    async fn snapshot_containers(&self) -> Vec<zlayer_types::cluster::WorkerContainerStatus>;
106    /// Snapshot the worker's current resource usage.
107    async fn snapshot_resources(&self) -> zlayer_types::cluster::WorkerResourceUsage;
108}
109
110/// Worker-side gRPC client. Construct via [`WorkerClientImpl::new`], then
111/// spawn the background loops with [`WorkerClientImpl::start`].
112pub struct WorkerClientImpl {
113    inner: Arc<WorkerClientState>,
114}
115
116impl std::fmt::Debug for WorkerClientImpl {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        f.debug_struct("WorkerClientImpl")
119            .field("node_id", &self.inner.node_id.load(Ordering::SeqCst))
120            .finish_non_exhaustive()
121    }
122}
123
124struct WorkerClientState {
125    /// Endpoints of all known control-plane nodes; tried in order.
126    servers: RwLock<Vec<String>>,
127    /// Bootstrap token (used until we have a signed cert).
128    token: RwLock<Option<String>>,
129    /// Persisted identity (loaded once we have one).
130    identity: RwLock<Option<WorkerIdentity>>,
131    /// Most recently assigned node id.
132    node_id: AtomicU64,
133    /// Most recent leader address (set after Register / on reconnect).
134    leader_addr: RwLock<Option<SocketAddr>>,
135    /// Snapshot of cluster peers, populated by `WatchAssignments` / `Register`.
136    peers: RwLock<Vec<zlayer_scheduler::cluster::NodeRecord>>,
137    /// Worker profile (os, arch, labels, etc.).
138    profile: zlayer_types::cluster::WorkerProfile,
139    /// Identity persistence directory.
140    identity_dir: PathBuf,
141    /// Outbound assignment channel — the agent's executor pulls from here.
142    assignment_tx: mpsc::UnboundedSender<proto::AssignmentEvent>,
143    /// Outbound command channel.
144    command_tx: mpsc::UnboundedSender<proto::CommandEvent>,
145    /// Current adaptive TTL (seconds). Updated on every `StatusAck`.
146    current_ttl_secs: AtomicU32,
147    /// Highest revision we've received on `WatchAssignments`. Used to
148    /// resume on reconnect.
149    last_seen_revision: AtomicU64,
150    /// Set when we should send a full snapshot on the next `ReportStatus`
151    /// tick (after `Register` or any reconnect).
152    full_snapshot_pending: AtomicBool,
153    /// Snapshot of containers + resources used in `StatusReport`.
154    status_provider: Arc<dyn WorkerStatusProvider>,
155}
156
157impl WorkerClientImpl {
158    /// Create a fresh worker client. Returns the client plus two receivers:
159    /// `assignment_rx` and `command_rx`, which the agent's executor drains
160    /// to apply control-plane decisions locally.
161    ///
162    /// Call [`start`](Self::start) to spawn the background loops.
163    #[allow(clippy::needless_pass_by_value)]
164    pub fn new(
165        servers: Vec<String>,
166        token: Option<String>,
167        profile: zlayer_types::cluster::WorkerProfile,
168        identity_dir: PathBuf,
169        status_provider: Arc<dyn WorkerStatusProvider>,
170    ) -> (
171        Self,
172        mpsc::UnboundedReceiver<proto::AssignmentEvent>,
173        mpsc::UnboundedReceiver<proto::CommandEvent>,
174    ) {
175        let (assignment_tx, assignment_rx) = mpsc::unbounded_channel();
176        let (command_tx, command_rx) = mpsc::unbounded_channel();
177
178        // Try to load a previously-persisted identity. Best-effort: any error
179        // means we'll have to re-register using the bootstrap token.
180        let identity = load_identity(&identity_dir).ok().flatten();
181
182        let inner = Arc::new(WorkerClientState {
183            servers: RwLock::new(servers),
184            token: RwLock::new(token),
185            identity: RwLock::new(identity),
186            node_id: AtomicU64::new(0),
187            leader_addr: RwLock::new(None),
188            peers: RwLock::new(Vec::new()),
189            profile,
190            identity_dir,
191            assignment_tx,
192            command_tx,
193            current_ttl_secs: AtomicU32::new(30),
194            last_seen_revision: AtomicU64::new(0),
195            full_snapshot_pending: AtomicBool::new(true),
196            status_provider,
197        });
198
199        (Self { inner }, assignment_rx, command_rx)
200    }
201
202    /// Start the background reconnect loop. Spawns one supervisor task that
203    /// in turn fans out `WatchAssignments`, `ReportStatus`, and
204    /// `WatchCommands` per connected session.
205    #[must_use]
206    pub fn start(&self) -> tokio::task::JoinSet<()> {
207        let mut set = tokio::task::JoinSet::new();
208        let state = Arc::clone(&self.inner);
209        set.spawn(run_loop(state));
210        set
211    }
212}
213
214#[async_trait]
215impl zlayer_scheduler::cluster::WorkerClient for WorkerClientImpl {
216    async fn current_leader_addr(&self) -> Option<SocketAddr> {
217        *self.inner.leader_addr.read().await
218    }
219
220    async fn known_peers(&self) -> Vec<zlayer_scheduler::cluster::NodeRecord> {
221        self.inner.peers.read().await.clone()
222    }
223
224    fn assigned_node_id(&self) -> u64 {
225        self.inner.node_id.load(Ordering::SeqCst)
226    }
227}
228
229// ----------------------------------------------------------------------------
230// Background loops
231// ----------------------------------------------------------------------------
232
233/// Top-level supervisor: keep trying to (re)establish a session forever.
234async fn run_loop(state: Arc<WorkerClientState>) {
235    let mut backoff = Duration::from_secs(1);
236    let mut server_idx: usize = 0;
237    loop {
238        match connect_and_run(&state, &mut server_idx).await {
239            Ok(()) => {
240                tracing::info!("worker session ended cleanly; reconnecting");
241                backoff = Duration::from_secs(1);
242            }
243            Err(WorkerClientError::NoServers) => {
244                // Fatal-ish: we can't do anything until somebody hands us a
245                // server list. Sleep long-ish and check again.
246                tracing::warn!("no control-plane servers configured; sleeping 30s");
247                tokio::time::sleep(Duration::from_secs(30)).await;
248            }
249            Err(WorkerClientError::NoCredentials) => {
250                tracing::error!(
251                    "no bootstrap token and no persisted identity; cannot register; sleeping 30s"
252                );
253                tokio::time::sleep(Duration::from_secs(30)).await;
254            }
255            Err(e) => {
256                tracing::warn!(error = %e, "worker session ended; reconnecting after backoff");
257                tokio::time::sleep(backoff).await;
258                backoff = (backoff * 2).min(Duration::from_secs(60));
259            }
260        }
261        // On reconnect, the next ReportStatus must be a full snapshot.
262        state.full_snapshot_pending.store(true, Ordering::SeqCst);
263    }
264}
265
266/// One session: connect, optionally Register, then fan out the three streams.
267/// Returns when any of them fails or the channel breaks.
268async fn connect_and_run(
269    state: &Arc<WorkerClientState>,
270    server_idx: &mut usize,
271) -> Result<(), WorkerClientError> {
272    let endpoint_url = {
273        let servers = state.servers.read().await;
274        if servers.is_empty() {
275            return Err(WorkerClientError::NoServers);
276        }
277        let idx = *server_idx % servers.len();
278        *server_idx = server_idx.wrapping_add(1);
279        servers[idx].clone()
280    };
281
282    // 1. Build the channel (mTLS if we have an identity, otherwise plain).
283    let channel = build_channel(state, &endpoint_url).await?;
284
285    // 2. Update leader_addr best-effort from the endpoint URL.
286    if let Some(addr) = parse_addr_from_url(&endpoint_url) {
287        *state.leader_addr.write().await = Some(addr);
288    }
289
290    let mut client =
291        proto::cluster_control_plane_client::ClusterControlPlaneClient::new(channel.clone());
292
293    // 3. If we don't have a signed identity, register first.
294    if state.identity.read().await.is_none() {
295        register(state, &mut client).await?;
296        // Rebuild the channel using the new identity so subsequent streams
297        // run over mTLS.
298        let channel = build_channel(state, &endpoint_url).await?;
299        client = proto::cluster_control_plane_client::ClusterControlPlaneClient::new(channel);
300    }
301
302    let node_id = state.node_id.load(Ordering::SeqCst);
303    if node_id == 0 {
304        return Err(WorkerClientError::Status(
305            tonic::Status::failed_precondition("register did not assign node_id"),
306        ));
307    }
308
309    // 4. Fan out the three streams. First failure cancels the session.
310    let assignments_state = Arc::clone(state);
311    let mut assignments_client = client.clone();
312    let assignments_task = tokio::spawn(async move {
313        run_watch_assignments(&assignments_state, &mut assignments_client, node_id).await
314    });
315
316    let status_state = Arc::clone(state);
317    let mut status_client = client.clone();
318    let status_task =
319        tokio::spawn(
320            async move { run_report_status(&status_state, &mut status_client, node_id).await },
321        );
322
323    let commands_state = Arc::clone(state);
324    let mut commands_client = client;
325    let commands_task = tokio::spawn(async move {
326        run_watch_commands(&commands_state, &mut commands_client, node_id).await
327    });
328
329    // Wait for any of them to terminate.
330    let result = tokio::select! {
331        r = assignments_task => unwrap_join(r),
332        r = status_task => unwrap_join(r),
333        r = commands_task => unwrap_join(r),
334    };
335    result
336}
337
338fn unwrap_join(
339    r: Result<Result<(), WorkerClientError>, tokio::task::JoinError>,
340) -> Result<(), WorkerClientError> {
341    match r {
342        Ok(inner) => inner,
343        Err(e) => Err(WorkerClientError::Status(tonic::Status::internal(format!(
344            "task join error: {e}"
345        )))),
346    }
347}
348
349// ----------------------------------------------------------------------------
350// Stream loops
351// ----------------------------------------------------------------------------
352
353async fn run_watch_assignments(
354    state: &Arc<WorkerClientState>,
355    client: &mut proto::cluster_control_plane_client::ClusterControlPlaneClient<
356        tonic::transport::Channel,
357    >,
358    node_id: u64,
359) -> Result<(), WorkerClientError> {
360    let req = proto::WatchAssignmentsRequest {
361        node_id,
362        last_seen_revision: state.last_seen_revision.load(Ordering::SeqCst),
363    };
364    let resp = client.watch_assignments(Request::new(req)).await?;
365    let mut stream = resp.into_inner();
366    while let Some(event) = stream.next().await {
367        match event {
368            Ok(ev) => {
369                if ev.revision > state.last_seen_revision.load(Ordering::SeqCst) {
370                    state
371                        .last_seen_revision
372                        .store(ev.revision, Ordering::SeqCst);
373                }
374                if state.assignment_tx.send(ev).is_err() {
375                    tracing::warn!("assignment receiver dropped; exiting watch loop");
376                    return Ok(());
377                }
378            }
379            Err(status) => {
380                return Err(WorkerClientError::Status(status));
381            }
382        }
383    }
384    Ok(())
385}
386
387async fn run_watch_commands(
388    state: &Arc<WorkerClientState>,
389    client: &mut proto::cluster_control_plane_client::ClusterControlPlaneClient<
390        tonic::transport::Channel,
391    >,
392    node_id: u64,
393) -> Result<(), WorkerClientError> {
394    let req = proto::WatchCommandsRequest { node_id };
395    let resp = client.watch_commands(Request::new(req)).await?;
396    let mut stream = resp.into_inner();
397    while let Some(event) = stream.next().await {
398        match event {
399            Ok(ev) => {
400                if state.command_tx.send(ev).is_err() {
401                    tracing::warn!("command receiver dropped; exiting watch loop");
402                    return Ok(());
403                }
404            }
405            Err(status) => {
406                return Err(WorkerClientError::Status(status));
407            }
408        }
409    }
410    Ok(())
411}
412
413async fn run_report_status(
414    state: &Arc<WorkerClientState>,
415    client: &mut proto::cluster_control_plane_client::ClusterControlPlaneClient<
416        tonic::transport::Channel,
417    >,
418    node_id: u64,
419) -> Result<(), WorkerClientError> {
420    // Bidirectional: spawn a producer task that pushes StatusReports on
421    // an adaptive tick into an mpsc; the receiver-stream wrapper feeds them
422    // into tonic. The bidi response stream is drained inline.
423    let (tx, rx) = mpsc::channel::<proto::StatusReport>(8);
424    let outbound = ReceiverStream::new(rx);
425
426    // Spawn the producer.
427    let prod_state = Arc::clone(state);
428    let producer = tokio::spawn(async move {
429        produce_status_reports(prod_state, tx, node_id).await;
430    });
431
432    let resp = client.report_status(Request::new(outbound)).await?;
433    let mut acks = resp.into_inner();
434    while let Some(ack) = acks.next().await {
435        match ack {
436            Ok(a) => {
437                if a.next_ttl_secs > 0 {
438                    state
439                        .current_ttl_secs
440                        .store(a.next_ttl_secs, Ordering::SeqCst);
441                }
442            }
443            Err(status) => {
444                producer.abort();
445                return Err(WorkerClientError::Status(status));
446            }
447        }
448    }
449    producer.abort();
450    Ok(())
451}
452
453async fn produce_status_reports(
454    state: Arc<WorkerClientState>,
455    tx: mpsc::Sender<proto::StatusReport>,
456    node_id: u64,
457) {
458    loop {
459        // Drift slightly: tick at (ttl - jitter) where jitter is up to 25% of ttl
460        // (min 1s, max 5s) so we stay comfortably inside the grace window.
461        let ttl = state.current_ttl_secs.load(Ordering::SeqCst).max(1);
462        let jitter = (ttl / 4).clamp(1, 5);
463        let interval = u64::from(ttl.saturating_sub(jitter)).max(1);
464        tokio::time::sleep(Duration::from_secs(interval)).await;
465
466        // Build a snapshot.
467        let containers = state.status_provider.snapshot_containers().await;
468        let resources = state.status_provider.snapshot_resources().await;
469        let full = state.full_snapshot_pending.swap(false, Ordering::SeqCst);
470
471        let report = proto::StatusReport {
472            node_id,
473            ts: Some(now_proto_timestamp()),
474            containers: containers.into_iter().map(Into::into).collect(),
475            resources: Some(resources.into()),
476            full_snapshot: full,
477        };
478
479        if tx.send(report).await.is_err() {
480            // Tonic closed the stream; producer task should exit.
481            return;
482        }
483    }
484}
485
486fn now_proto_timestamp() -> prost_types::Timestamp {
487    match SystemTime::now().duration_since(UNIX_EPOCH) {
488        Ok(d) => prost_types::Timestamp {
489            seconds: i64::try_from(d.as_secs()).unwrap_or(i64::MAX),
490            nanos: i32::try_from(d.subsec_nanos()).unwrap_or(0),
491        },
492        Err(_) => prost_types::Timestamp {
493            seconds: 0,
494            nanos: 0,
495        },
496    }
497}
498
499// ----------------------------------------------------------------------------
500// Register
501// ----------------------------------------------------------------------------
502
503async fn register(
504    state: &Arc<WorkerClientState>,
505    client: &mut proto::cluster_control_plane_client::ClusterControlPlaneClient<
506        tonic::transport::Channel,
507    >,
508) -> Result<(), WorkerClientError> {
509    // We need a bootstrap token for the very first registration.
510    let token = state
511        .token
512        .read()
513        .await
514        .clone()
515        .ok_or(WorkerClientError::NoCredentials)?;
516
517    // Generate a fresh EC P-256 keypair and matching CSR. The key never
518    // leaves this process unencrypted on the wire — the CSR carries only
519    // the public half.
520    let key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)?;
521    let key_pem = key_pair.serialize_pem();
522
523    let params = rcgen::CertificateParams::default();
524    let csr = params.serialize_request(&key_pair)?;
525    let csr_der: Vec<u8> = csr.der().as_ref().to_vec();
526
527    let req = proto::RegisterRequest {
528        bootstrap_token: token,
529        desired_node_id: 0,
530        profile: Some(state.profile.clone().into()),
531        csr_der,
532    };
533
534    let resp = client.register(Request::new(req)).await?.into_inner();
535
536    if resp.node_id == 0 {
537        return Err(WorkerClientError::Status(
538            tonic::Status::failed_precondition("control plane returned node_id=0"),
539        ));
540    }
541    state.node_id.store(resp.node_id, Ordering::SeqCst);
542    if resp.heartbeat_ttl_secs > 0 {
543        state
544            .current_ttl_secs
545            .store(resp.heartbeat_ttl_secs, Ordering::SeqCst);
546    }
547
548    // Persist identity. The control plane returns DER-encoded certs; turn
549    // them into PEM blocks for tonic's `Identity::from_pem`.
550    let cert_pem = der_to_pem(&resp.signed_cert_der, "CERTIFICATE");
551    let ca_chain_pem = resp
552        .ca_chain_der
553        .iter()
554        .map(|d| der_to_pem(d, "CERTIFICATE"))
555        .collect::<String>();
556    let identity = WorkerIdentity {
557        cert_pem,
558        key_pem,
559        ca_chain_pem,
560    };
561
562    persist_identity(&state.identity_dir, &identity)?;
563    *state.identity.write().await = Some(identity);
564
565    tracing::info!(
566        node_id = resp.node_id,
567        ttl_secs = resp.heartbeat_ttl_secs,
568        "worker registered with control plane"
569    );
570
571    Ok(())
572}
573
574// ----------------------------------------------------------------------------
575// Channel construction
576// ----------------------------------------------------------------------------
577
578async fn build_channel(
579    state: &Arc<WorkerClientState>,
580    endpoint_url: &str,
581) -> Result<tonic::transport::Channel, WorkerClientError> {
582    let endpoint = Endpoint::from_shared(endpoint_url.to_string()).map_err(|e| {
583        WorkerClientError::InvalidEndpoint {
584            endpoint: endpoint_url.to_string(),
585            source: e,
586        }
587    })?;
588
589    let endpoint = if let Some(identity) = state.identity.read().await.clone() {
590        let tls = ClientTlsConfig::new()
591            .ca_certificate(Certificate::from_pem(identity.ca_chain_pem.as_bytes()))
592            .identity(Identity::from_pem(
593                identity.cert_pem.as_bytes(),
594                identity.key_pem.as_bytes(),
595            ));
596        endpoint.tls_config(tls).map_err(WorkerClientError::Tls)?
597    } else {
598        endpoint
599    };
600
601    endpoint
602        .connect()
603        .await
604        .map_err(WorkerClientError::Transport)
605}
606
/// Best-effort extraction of a `SocketAddr` from an endpoint URL.
///
/// Drops a leading `scheme://` (any scheme) and everything from the first
/// `/`, `?` or `#` after the authority, then parses the remaining
/// `host:port`. Only literal IP addresses with an explicit port are
/// recognised (IPv6 in brackets, e.g. `http://[::1]:3669`). Hostnames and
/// port-less URLs yield `None`.
fn parse_addr_from_url(url: &str) -> Option<SocketAddr> {
    let without_scheme = url.split_once("://").map_or(url, |(_, rest)| rest);
    let authority = without_scheme
        .split(|c| matches!(c, '/' | '?' | '#'))
        .next()?;
    authority.parse().ok()
}
615
616// ----------------------------------------------------------------------------
617// Identity persistence
618// ----------------------------------------------------------------------------
619
/// Locations of the persisted identity files inside `dir`, returned as
/// `(cert.pem, key.pem, ca.pem)`.
fn identity_paths(dir: &Path) -> (PathBuf, PathBuf, PathBuf) {
    let cert = dir.join("cert.pem");
    let key = dir.join("key.pem");
    let ca = dir.join("ca.pem");
    (cert, key, ca)
}
627
628fn load_identity(dir: &Path) -> Result<Option<WorkerIdentity>, WorkerClientError> {
629    let (cert_path, key_path, ca_path) = identity_paths(dir);
630    if !cert_path.exists() || !key_path.exists() || !ca_path.exists() {
631        return Ok(None);
632    }
633    let cert_pem = std::fs::read_to_string(&cert_path)?;
634    let key_pem = std::fs::read_to_string(&key_path)?;
635    let ca_chain_pem = std::fs::read_to_string(&ca_path)?;
636    Ok(Some(WorkerIdentity {
637        cert_pem,
638        key_pem,
639        ca_chain_pem,
640    }))
641}
642
643fn persist_identity(dir: &Path, identity: &WorkerIdentity) -> Result<(), WorkerClientError> {
644    std::fs::create_dir_all(dir)?;
645    let (cert_path, key_path, ca_path) = identity_paths(dir);
646    write_mode_0600(&cert_path, identity.cert_pem.as_bytes())?;
647    write_mode_0600(&key_path, identity.key_pem.as_bytes())?;
648    write_mode_0600(&ca_path, identity.ca_chain_pem.as_bytes())?;
649    Ok(())
650}
651
652fn write_mode_0600(path: &Path, bytes: &[u8]) -> Result<(), WorkerClientError> {
653    std::fs::write(path, bytes)?;
654    #[cfg(unix)]
655    {
656        use std::os::unix::fs::PermissionsExt;
657        let mut perms = std::fs::metadata(path)?.permissions();
658        perms.set_mode(0o600);
659        std::fs::set_permissions(path, perms)?;
660    }
661    Ok(())
662}
663
664fn der_to_pem(der: &[u8], label: &str) -> String {
665    use std::fmt::Write;
666    let b64 = base64_encode(der);
667    let mut out = String::with_capacity(b64.len() + 64);
668    let _ = writeln!(out, "-----BEGIN {label}-----");
669    for chunk in b64.as_bytes().chunks(64) {
670        out.push_str(std::str::from_utf8(chunk).expect("base64 is ascii"));
671        out.push('\n');
672    }
673    let _ = writeln!(out, "-----END {label}-----");
674    out
675}
676
/// Minimal RFC 4648 base64 encoder (standard alphabet, `=` padding).
///
/// Hand-rolled to avoid pulling in another crate for a single use site.
/// Input is processed in 3-byte groups; a trailing 1- or 2-byte group is
/// padded with `==` or `=` respectively.
fn base64_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let sextet = |word: u32, shift: u32| ALPHABET[((word >> shift) & 0x3f) as usize] as char;

    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        // Pack up to three bytes into the low 24 bits, zero-filling the tail.
        let b0 = u32::from(group[0]);
        let b1 = group.get(1).copied().map_or(0, u32::from);
        let b2 = group.get(2).copied().map_or(0, u32::from);
        let word = (b0 << 16) | (b1 << 8) | b2;

        encoded.push(sextet(word, 18));
        encoded.push(sextet(word, 12));
        encoded.push(if group.len() > 1 { sextet(word, 6) } else { '=' });
        encoded.push(if group.len() > 2 { sextet(word, 0) } else { '=' });
    }
    encoded
}
714
715// ----------------------------------------------------------------------------
716// Tests
717// ----------------------------------------------------------------------------
718
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use zlayer_scheduler::cluster::WorkerClient as _;

    /// Status provider for an idle worker: no containers, zero usage.
    #[derive(Debug)]
    struct DummyStatusProvider;

    #[async_trait]
    impl WorkerStatusProvider for DummyStatusProvider {
        async fn snapshot_containers(&self) -> Vec<zlayer_types::cluster::WorkerContainerStatus> {
            Vec::new()
        }

        async fn snapshot_resources(&self) -> zlayer_types::cluster::WorkerResourceUsage {
            zlayer_types::cluster::WorkerResourceUsage {
                cpu_used: 0.0,
                memory_used_bytes: 0,
                gpu_used: 0,
            }
        }
    }

    /// A minimal linux/x86_64 profile for constructing clients under test.
    fn dummy_profile() -> zlayer_types::cluster::WorkerProfile {
        zlayer_types::cluster::WorkerProfile {
            api_addr: "127.0.0.1:3669".parse().unwrap(),
            os: "linux".to_string(),
            arch: "x86_64".to_string(),
            labels: HashMap::new(),
            cpu_total: 4,
            memory_total_bytes: 8_000_000_000,
        }
    }

    /// PEM text with a single body line, e.g. `pem_block("CERTIFICATE", "AAAA")`.
    fn pem_block(label: &str, body: &str) -> String {
        format!("-----BEGIN {label}-----\n{body}\n-----END {label}-----\n")
    }

    #[tokio::test]
    async fn worker_client_starts_empty_with_no_servers() {
        let tmp = tempfile::tempdir().unwrap();
        let provider: Arc<dyn WorkerStatusProvider> = Arc::new(DummyStatusProvider);
        // Keep both receivers bound so the channels stay open for the test.
        let (client, _assignment_rx, _command_rx) = WorkerClientImpl::new(
            Vec::new(),
            None,
            dummy_profile(),
            tmp.path().to_path_buf(),
            provider,
        );

        assert_eq!(client.assigned_node_id(), 0);
        assert!(client.known_peers().await.is_empty());
        assert_eq!(client.current_leader_addr().await, None);
    }

    #[test]
    fn worker_identity_persists_to_disk() {
        let tmp = tempfile::tempdir().unwrap();
        let original = WorkerIdentity {
            cert_pem: pem_block("CERTIFICATE", "AAAA"),
            key_pem: pem_block("PRIVATE KEY", "BBBB"),
            ca_chain_pem: pem_block("CERTIFICATE", "CCCC"),
        };

        persist_identity(tmp.path(), &original).expect("persist");
        let reloaded = load_identity(tmp.path()).expect("load").expect("present");
        assert_eq!(reloaded, original);

        // Every identity file must be owner-only on Unix.
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let (cert, key, ca) = identity_paths(tmp.path());
            for path in [cert, key, ca] {
                let mode = std::fs::metadata(&path).unwrap().permissions().mode() & 0o777;
                assert_eq!(mode, 0o600, "{path:?}");
            }
        }
    }

    #[test]
    fn base64_roundtrip_basic() {
        // Test vectors from RFC 4648 section 10.
        let vectors: [(&[u8], &str); 7] = [
            (b"", ""),
            (b"f", "Zg=="),
            (b"fo", "Zm8="),
            (b"foo", "Zm9v"),
            (b"foob", "Zm9vYg=="),
            (b"fooba", "Zm9vYmE="),
            (b"foobar", "Zm9vYmFy"),
        ];
        for (input, expected) in vectors {
            assert_eq!(base64_encode(input), expected);
        }
    }

    #[test]
    fn der_to_pem_wraps_with_label() {
        let pem = der_to_pem(&[0x30, 0x82, 0x01, 0x00], "CERTIFICATE");
        assert!(pem.starts_with("-----BEGIN CERTIFICATE-----\n"));
        assert!(pem.trim_end().ends_with("-----END CERTIFICATE-----"));
    }

    #[test]
    fn parse_addr_from_url_handles_http_prefix() {
        let addr = |text: &str| -> Option<SocketAddr> { Some(text.parse().unwrap()) };
        assert_eq!(parse_addr_from_url("http://127.0.0.1:3669"), addr("127.0.0.1:3669"));
        assert_eq!(parse_addr_from_url("https://10.0.0.1:443/"), addr("10.0.0.1:443"));
        assert_eq!(parse_addr_from_url("not-a-url"), None);
    }
}
824}