Skip to main content

provider_agent/
ws_client.rs

1//! Long-lived WSS client. Performs the signed handshake (`auth_challenge` →
2//! `auth_response` → `auth_ok`) and runs heartbeat/job loops thereafter.
3//! Reconnect with exponential backoff per `plan/V2_AGENT_SPEC.md` §9.
4//!
//! Job dispatch and capability discovery (tasks #15/#16) are wired in: after
//! authenticating we run backend discovery, send the resulting `capabilities`
//! payload, then drive heartbeats while routing inbound `job` / `job_cancel`
//! messages to the job executor. Other inbound message types are only logged.
9
10use std::time::Duration;
11
12use anyhow::{Context, Result, anyhow, bail};
13use base64::Engine as _;
14use base64::engine::general_purpose::STANDARD as B64;
15use futures_util::{SinkExt, StreamExt};
16use rand::Rng;
17use serde_json::{Value, json};
18use thiserror::Error;
19use tokio::sync::mpsc;
20use tokio_tungstenite::tungstenite::Message;
21use tokio_tungstenite::tungstenite::protocol::CloseFrame;
22use tracing::{debug, error, info, warn};
23
/// Classification of how an authenticated coordinator session came to an
/// end. `run` uses it purely to pick the log level of the reconnect
/// message: the three graceful variants are logged at INFO, while
/// [`CloseKind::Unexpected`] is logged at WARN.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CloseKind {
    /// The server sent a `Close` frame with code 1001 and the reason
    /// `"drain:deploy"` — the coordinator is being rolled over by a deploy.
    DrainDeploy,
    /// The server sent a `Close` frame with code 1001 and the reason
    /// `"drain:operator"` — a disconnect requested by an operator / admin.
    DrainOperator,
    /// A `Close` frame that carried no payload at all, or a code-1001 close
    /// whose reason string this build does not recognise (older or newer
    /// server). Still considered an expected, graceful shutdown.
    GracefulOther,
    /// Everything else: a `Close` frame with a code other than 1001, the
    /// stream ending without a `Close` frame, or the outbound channel
    /// closing.
    Unexpected,
}
44
45/// Inspect a `Close` frame and decide which `CloseKind` it represents.
46/// The reason string is the wire contract with `services/api/src/
47/// marketplace/connection.rs::ShutdownReason::as_wire`.
48fn classify_close(frame: Option<&CloseFrame>) -> CloseKind {
49    let Some(f) = frame else {
50        return CloseKind::GracefulOther;
51    };
52    let code: u16 = f.code.into();
53    if code != 1001 {
54        return CloseKind::Unexpected;
55    }
56    match f.reason.as_ref() {
57        "drain:deploy" => CloseKind::DrainDeploy,
58        "drain:operator" => CloseKind::DrainOperator,
59        _ => CloseKind::GracefulOther,
60    }
61}
62
63use crate::backend::{Job, WireFormat};
64use crate::config::Config;
65use crate::discovery;
66use crate::heartbeat;
67use crate::identity::Identity;
68use crate::job_executor::JobExecutor;
69
/// Version string sent to the coordinator as `agent_version` in the
/// `auth_response` message; taken from `CARGO_PKG_VERSION` at compile time.
const AGENT_VERSION: &str = env!("CARGO_PKG_VERSION");
71
72/// Connect-once outcome. We split pre-auth from post-auth failures so a deploy
73/// of the coordinator (which drops a steady-state session) doesn't escalate
74/// the reconnect backoff the same way a real outage does.
75#[derive(Debug, Error)]
76enum ConnectError {
77    /// Disconnected before reaching steady state (dial, TLS, handshake, auth
78    /// response). Treated as a real failure; reconnect backoff escalates.
79    #[error("pre-auth: {0:#}")]
80    PreAuth(anyhow::Error),
81    /// Disconnected after `auth_ok` and discovery succeeded — i.e. an
82    /// established session ended. Treated as a planned cycle (coordinator
83    /// restart, network blip); reconnect backoff resets.
84    #[error("post-auth: {0:#}")]
85    PostAuth(anyhow::Error),
86}
87
88/// Connect, authenticate, and run forever with reconnect.
89pub async fn run(cfg: Config, mut identity: Identity) -> Result<()> {
90    let mut backoff_ms: u64 = 1000;
91    let mut consecutive_failures: u32 = 0;
92
93    loop {
94        match connect_once(&cfg, &mut identity).await {
95            Ok(kind) => {
96                // Clean disconnect (server closed). Restart with the post-success backoff schedule.
97                match kind {
98                    CloseKind::DrainDeploy => {
99                        info!("coordinator drained for deploy; reconnecting")
100                    }
101                    CloseKind::DrainOperator => {
102                        info!("coordinator requested disconnect; reconnecting")
103                    }
104                    CloseKind::GracefulOther => {
105                        info!("coordinator going away; reconnecting")
106                    }
107                    CloseKind::Unexpected => {
108                        warn!("coordinator connection closed unexpectedly; reconnecting")
109                    }
110                }
111                consecutive_failures = 0;
112                backoff_ms = 1000;
113            }
114            Err(ConnectError::PostAuth(err)) => {
115                // Steady-state session ended — almost always a coordinator
116                // deploy or transient network issue. Reset like a clean close
117                // so we don't stretch the gap during a blue-green rotation.
118                warn!(?err, "coordinator session ended; reconnecting");
119                consecutive_failures = 0;
120                backoff_ms = 1000;
121            }
122            Err(ConnectError::PreAuth(err)) => {
123                consecutive_failures += 1;
124                error!(?err, attempts = consecutive_failures, "coordinator connection failed");
125                if consecutive_failures == 10 {
126                    error!("coordinator unreachable after 10 attempts; will keep retrying");
127                }
128            }
129        }
130
131        let jitter: f64 = rand::thread_rng().gen_range(0.8..1.2);
132        let sleep_ms = ((backoff_ms as f64) * jitter) as u64;
133        tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
134        backoff_ms = (backoff_ms.saturating_mul(2)).min(60_000);
135    }
136}
137
138async fn connect_once(cfg: &Config, identity: &mut Identity) -> Result<CloseKind, ConnectError> {
139    info!(url = %cfg.coordinator.url, "dialing coordinator");
140    let (ws, _resp) = tokio_tungstenite::connect_async(&cfg.coordinator.url)
141        .await
142        .with_context(|| format!("connecting to {}", cfg.coordinator.url))
143        .map_err(ConnectError::PreAuth)?;
144    let (mut sink, mut stream) = ws.split();
145
146    // 1. Receive auth_challenge
147    let challenge = recv_json(&mut stream).await.map_err(ConnectError::PreAuth)?;
148    if challenge.get("type").and_then(Value::as_str) != Some("auth_challenge") {
149        return Err(ConnectError::PreAuth(anyhow!(
150            "expected auth_challenge, got {challenge}"
151        )));
152    }
153    let nonce_b64 = challenge
154        .get("nonce")
155        .and_then(Value::as_str)
156        .ok_or_else(|| ConnectError::PreAuth(anyhow!("auth_challenge missing nonce")))?;
157    let nonce = B64
158        .decode(nonce_b64.as_bytes())
159        .context("decoding challenge nonce")
160        .map_err(ConnectError::PreAuth)?;
161
162    // 2. Sign nonce, send auth_response
163    let mut auth_response = json!({
164        "type": "auth_response",
165        "pubkey": identity.public_key_b64(),
166        "signature": identity.sign_b64(&nonce),
167        "agent_version": AGENT_VERSION,
168    });
169    if identity.provider_id.is_none() {
170        if let Some(code) = cfg.coordinator.enrollment_code.as_deref() {
171            auth_response["enrollment_code"] = json!(code);
172        }
173    }
174    sink.send(Message::Text(auth_response.to_string().into()))
175        .await
176        .map_err(|e| ConnectError::PreAuth(e.into()))?;
177
178    // 3. Await auth_ok
179    let ack = recv_json(&mut stream).await.map_err(ConnectError::PreAuth)?;
180    match ack.get("type").and_then(Value::as_str) {
181        Some("auth_ok") => {}
182        Some("auth_failed") => {
183            let reason = ack.get("reason").and_then(Value::as_str).unwrap_or("unknown");
184            return Err(ConnectError::PreAuth(anyhow!(
185                "coordinator rejected auth: {reason}"
186            )));
187        }
188        other => {
189            return Err(ConnectError::PreAuth(anyhow!(
190                "expected auth_ok, got type={other:?}"
191            )));
192        }
193    }
194    if let Some(pid) = ack.get("provider_id").and_then(Value::as_str) {
195        if identity.provider_id.as_deref() != Some(pid) {
196            info!(provider_id = pid, "persisting provider_id from coordinator");
197            identity
198                .set_provider_id(pid.to_string())
199                .map_err(ConnectError::PreAuth)?;
200        }
201    }
202    info!("authenticated with coordinator");
203
204    // 4. Run backend discovery and send a real `capabilities` payload.
205    //    Built fresh on every successful auth_ok so reconnect-after-coordinator-
206    //    restart re-registers the model list (Redis state may be cold).
207    let discovery_result = discovery::run(cfg).await;
208    info!(
209        models = discovery_result.capability_models.len(),
210        backends = discovery_result.backends.len(),
211        "discovery complete"
212    );
213    let capabilities = discovery_result.to_capabilities(cfg);
214    // Past this point, the coordinator has accepted us and we've started the
215    // steady-state pump. Any further error is a PostAuth — treat as a planned
216    // cycle so a deploy doesn't escalate the reconnect backoff.
217    sink.send(Message::Text(capabilities.to_string().into()))
218        .await
219        .map_err(|e| ConnectError::PostAuth(e.into()))?;
220    debug!("sent capabilities");
221
222    // 5. Spawn the heartbeat loop. We funnel both heartbeat and any future
223    //    outbound traffic through an mpsc to keep the WS sink single-owner.
224    let (out_tx, mut out_rx) = mpsc::channel::<Message>(64);
225    let hb_handle = tokio::spawn(heartbeat::spawn_loop(out_tx.clone()));
226
227    // The discovered backends become the dispatch table owned by the executor.
228    let executor = JobExecutor::new(
229        discovery_result.backends,
230        cfg.limits.max_concurrent,
231        out_tx.clone(),
232    );
233
234    // 6. Read loop / write pump.
235    let result: Result<CloseKind> = async {
236        loop {
237            tokio::select! {
238                outbound = out_rx.recv() => {
239                    match outbound {
240                        Some(msg) => sink.send(msg).await?,
241                        None => return Ok(CloseKind::Unexpected),
242                    }
243                }
244                inbound = stream.next() => {
245                    match inbound {
246                        Some(Ok(Message::Text(txt))) => {
247                            debug!(%txt, "ws inbound text");
248                            handle_inbound_text(&executor, &txt).await;
249                        }
250                        Some(Ok(Message::Ping(p))) => sink.send(Message::Pong(p)).await?,
251                        Some(Ok(Message::Close(frame))) => return Ok(classify_close(frame.as_ref())),
252                        Some(Ok(_)) => {}
253                        Some(Err(e)) => return Err(anyhow!("ws read error: {e}")),
254                        None => return Ok(CloseKind::Unexpected),
255                    }
256                }
257            }
258        }
259    }
260    .await;
261
262    hb_handle.abort();
263    result.map_err(ConnectError::PostAuth)
264}
265
266/// Parse an inbound coordinator frame and route `job` / `job_cancel` to the
267/// executor. Other types (`config_update`, etc.) are debug-logged for now;
268/// adding handlers here is non-invasive.
269async fn handle_inbound_text(executor: &JobExecutor, txt: &str) {
270    let v: Value = match serde_json::from_str(txt) {
271        Ok(v) => v,
272        Err(e) => {
273            warn!(error = %e, "ws inbound: invalid json");
274            return;
275        }
276    };
277    match v.get("type").and_then(Value::as_str) {
278        Some("job") => match parse_job(&v) {
279            Ok(job) => executor.dispatch(job).await,
280            Err(e) => warn!(error = %e, "ws inbound: malformed job"),
281        },
282        Some("job_cancel") => {
283            if let Some(id) = v.get("job_id").and_then(Value::as_str) {
284                match id.parse::<uuid::Uuid>() {
285                    Ok(job_id) => executor.cancel(job_id).await,
286                    Err(e) => warn!(error = %e, "ws inbound: bad job_id in job_cancel"),
287                }
288            }
289        }
290        Some(other) => debug!(kind = other, "ws inbound: unhandled message type"),
291        None => warn!("ws inbound: missing 'type'"),
292    }
293}
294
295fn parse_job(v: &Value) -> Result<Job> {
296    let job_id = v
297        .get("job_id")
298        .and_then(Value::as_str)
299        .ok_or_else(|| anyhow!("job missing job_id"))?
300        .parse::<uuid::Uuid>()
301        .context("job_id parse")?;
302    let model_id = v
303        .get("model_id")
304        .and_then(Value::as_str)
305        .ok_or_else(|| anyhow!("job missing model_id"))?
306        .to_string();
307    let request = v
308        .get("request")
309        .cloned()
310        .ok_or_else(|| anyhow!("job missing request"))?;
311    let format = match v.get("format").and_then(Value::as_str).unwrap_or("openai") {
312        "anthropic" => WireFormat::Anthropic,
313        _ => WireFormat::Openai,
314    };
315    let deadline_ms = v
316        .get("deadline_ms")
317        .and_then(Value::as_u64)
318        .unwrap_or(60_000) as u32;
319    Ok(Job { job_id, model_id, request, format, deadline_ms })
320}
321
322async fn recv_json<S>(stream: &mut S) -> Result<Value>
323where
324    S: StreamExt<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
325        + Unpin,
326{
327    loop {
328        let msg = stream
329            .next()
330            .await
331            .ok_or_else(|| anyhow!("ws closed before message received"))?
332            .context("ws read")?;
333        match msg {
334            Message::Text(txt) => {
335                return serde_json::from_str(&txt).context("parsing ws JSON");
336            }
337            Message::Binary(_) => bail!("unexpected binary frame during handshake"),
338            Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
339            Message::Close(_) => bail!("ws closed during handshake"),
340        }
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::{CloseKind, classify_close};
    use tokio_tungstenite::tungstenite::protocol::CloseFrame;
    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;

    /// Build a close frame from `code` / `reason` and classify it.
    fn classify(code: u16, reason: &'static str) -> CloseKind {
        let close = CloseFrame {
            code: CloseCode::from(code),
            reason: reason.into(),
        };
        classify_close(Some(&close))
    }

    #[test]
    fn classify_close_drain_deploy_is_recognised() {
        assert_eq!(classify(1001, "drain:deploy"), CloseKind::DrainDeploy);
    }

    #[test]
    fn classify_close_drain_operator_is_recognised() {
        assert_eq!(classify(1001, "drain:operator"), CloseKind::DrainOperator);
    }

    #[test]
    fn classify_close_1001_with_unknown_reason_is_graceful_other() {
        // A reason from an older server, or a newer one this build has not
        // learned yet, must still count as graceful so a version mismatch
        // during a deploy does not produce WARN logs.
        assert_eq!(
            classify(1001, "coordinator shutting down"),
            CloseKind::GracefulOther
        );
        assert_eq!(classify(1001, ""), CloseKind::GracefulOther);
    }

    #[test]
    fn classify_close_missing_frame_is_graceful_other() {
        // A bare Close without payload is still a clean close, only
        // without annotation.
        assert_eq!(classify_close(None), CloseKind::GracefulOther);
    }

    #[test]
    fn classify_close_non_1001_code_is_unexpected() {
        // Protocol errors, abnormal closes and the like are not planned
        // drains and must keep their WARN classification.
        assert_eq!(classify(1002, "protocol error"), CloseKind::Unexpected);
        assert_eq!(classify(1006, "abnormal"), CloseKind::Unexpected);
    }
}
409