1use std::time::Duration;
11
12use anyhow::{Context, Result, anyhow, bail};
13use base64::Engine as _;
14use base64::engine::general_purpose::STANDARD as B64;
15use futures_util::{SinkExt, StreamExt};
16use rand::Rng;
17use serde_json::{Value, json};
18use thiserror::Error;
19use tokio::sync::mpsc;
20use tokio_tungstenite::tungstenite::Message;
21use tokio_tungstenite::tungstenite::protocol::CloseFrame;
22use tracing::{debug, error, info, warn};
23
/// How a coordinator session came to an end, as far as the agent can tell.
///
/// Produced by [`classify_close`] from the WebSocket Close frame, and also returned directly by
/// the session loop when the connection ends without one.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CloseKind {
    /// Close code 1001 with the reason `drain:deploy`.
    DrainDeploy,
    /// Close code 1001 with the reason `drain:operator`.
    DrainOperator,
    /// Close code 1001 with any other reason, or a Close message that carried no frame at all.
    GracefulOther,
    /// Any close code other than 1001, or a connection that ended without a Close message.
    Unexpected,
}
44
45fn classify_close(frame: Option<&CloseFrame>) -> CloseKind {
49 let Some(f) = frame else {
50 return CloseKind::GracefulOther;
51 };
52 let code: u16 = f.code.into();
53 if code != 1001 {
54 return CloseKind::Unexpected;
55 }
56 match f.reason.as_ref() {
57 "drain:deploy" => CloseKind::DrainDeploy,
58 "drain:operator" => CloseKind::DrainOperator,
59 _ => CloseKind::GracefulOther,
60 }
61}
62
63use crate::backend::{Job, WireFormat};
64use crate::config::Config;
65use crate::discovery;
66use crate::heartbeat;
67use crate::identity::Identity;
68use crate::job_executor::JobExecutor;
69
/// Version string sent to the coordinator as `agent_version` in the auth response.
///
/// Captured at compile time from the `CARGO_PKG_VERSION` environment variable.
const AGENT_VERSION: &str = env!("CARGO_PKG_VERSION");
71
72#[derive(Debug, Error)]
76enum ConnectError {
77 #[error("pre-auth: {0:#}")]
80 PreAuth(anyhow::Error),
81 #[error("post-auth: {0:#}")]
85 PostAuth(anyhow::Error),
86}
87
88pub async fn run(cfg: Config, mut identity: Identity) -> Result<()> {
90 let mut backoff_ms: u64 = 1000;
91 let mut consecutive_failures: u32 = 0;
92
93 loop {
94 match connect_once(&cfg, &mut identity).await {
95 Ok(kind) => {
96 match kind {
98 CloseKind::DrainDeploy => {
99 info!("coordinator drained for deploy; reconnecting")
100 }
101 CloseKind::DrainOperator => {
102 info!("coordinator requested disconnect; reconnecting")
103 }
104 CloseKind::GracefulOther => {
105 info!("coordinator going away; reconnecting")
106 }
107 CloseKind::Unexpected => {
108 warn!("coordinator connection closed unexpectedly; reconnecting")
109 }
110 }
111 consecutive_failures = 0;
112 backoff_ms = 1000;
113 }
114 Err(ConnectError::PostAuth(err)) => {
115 warn!(?err, "coordinator session ended; reconnecting");
119 consecutive_failures = 0;
120 backoff_ms = 1000;
121 }
122 Err(ConnectError::PreAuth(err)) => {
123 consecutive_failures += 1;
124 error!(?err, attempts = consecutive_failures, "coordinator connection failed");
125 if consecutive_failures == 10 {
126 error!("coordinator unreachable after 10 attempts; will keep retrying");
127 }
128 }
129 }
130
131 let jitter: f64 = rand::thread_rng().gen_range(0.8..1.2);
132 let sleep_ms = ((backoff_ms as f64) * jitter) as u64;
133 tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
134 backoff_ms = (backoff_ms.saturating_mul(2)).min(60_000);
135 }
136}
137
138async fn connect_once(cfg: &Config, identity: &mut Identity) -> Result<CloseKind, ConnectError> {
139 info!(url = %cfg.coordinator.url, "dialing coordinator");
140 let (ws, _resp) = tokio_tungstenite::connect_async(&cfg.coordinator.url)
141 .await
142 .with_context(|| format!("connecting to {}", cfg.coordinator.url))
143 .map_err(ConnectError::PreAuth)?;
144 let (mut sink, mut stream) = ws.split();
145
146 let challenge = recv_json(&mut stream).await.map_err(ConnectError::PreAuth)?;
148 if challenge.get("type").and_then(Value::as_str) != Some("auth_challenge") {
149 return Err(ConnectError::PreAuth(anyhow!(
150 "expected auth_challenge, got {challenge}"
151 )));
152 }
153 let nonce_b64 = challenge
154 .get("nonce")
155 .and_then(Value::as_str)
156 .ok_or_else(|| ConnectError::PreAuth(anyhow!("auth_challenge missing nonce")))?;
157 let nonce = B64
158 .decode(nonce_b64.as_bytes())
159 .context("decoding challenge nonce")
160 .map_err(ConnectError::PreAuth)?;
161
162 let mut auth_response = json!({
164 "type": "auth_response",
165 "pubkey": identity.public_key_b64(),
166 "signature": identity.sign_b64(&nonce),
167 "agent_version": AGENT_VERSION,
168 });
169 if identity.provider_id.is_none() {
170 if let Some(code) = cfg.coordinator.enrollment_code.as_deref() {
171 auth_response["enrollment_code"] = json!(code);
172 }
173 }
174 sink.send(Message::Text(auth_response.to_string().into()))
175 .await
176 .map_err(|e| ConnectError::PreAuth(e.into()))?;
177
178 let ack = recv_json(&mut stream).await.map_err(ConnectError::PreAuth)?;
180 match ack.get("type").and_then(Value::as_str) {
181 Some("auth_ok") => {}
182 Some("auth_failed") => {
183 let reason = ack.get("reason").and_then(Value::as_str).unwrap_or("unknown");
184 return Err(ConnectError::PreAuth(anyhow!(
185 "coordinator rejected auth: {reason}"
186 )));
187 }
188 other => {
189 return Err(ConnectError::PreAuth(anyhow!(
190 "expected auth_ok, got type={other:?}"
191 )));
192 }
193 }
194 if let Some(pid) = ack.get("provider_id").and_then(Value::as_str) {
195 if identity.provider_id.as_deref() != Some(pid) {
196 info!(provider_id = pid, "persisting provider_id from coordinator");
197 identity
198 .set_provider_id(pid.to_string())
199 .map_err(ConnectError::PreAuth)?;
200 }
201 }
202 info!("authenticated with coordinator");
203
204 let discovery_result = discovery::run(cfg).await;
208 info!(
209 models = discovery_result.capability_models.len(),
210 backends = discovery_result.backends.len(),
211 "discovery complete"
212 );
213 let capabilities = discovery_result.to_capabilities(cfg);
214 sink.send(Message::Text(capabilities.to_string().into()))
218 .await
219 .map_err(|e| ConnectError::PostAuth(e.into()))?;
220 debug!("sent capabilities");
221
222 let (out_tx, mut out_rx) = mpsc::channel::<Message>(64);
225 let hb_handle = tokio::spawn(heartbeat::spawn_loop(out_tx.clone()));
226
227 let executor = JobExecutor::new(
229 discovery_result.backends,
230 cfg.limits.max_concurrent,
231 out_tx.clone(),
232 );
233
234 let result: Result<CloseKind> = async {
236 loop {
237 tokio::select! {
238 outbound = out_rx.recv() => {
239 match outbound {
240 Some(msg) => sink.send(msg).await?,
241 None => return Ok(CloseKind::Unexpected),
242 }
243 }
244 inbound = stream.next() => {
245 match inbound {
246 Some(Ok(Message::Text(txt))) => {
247 debug!(%txt, "ws inbound text");
248 handle_inbound_text(&executor, &txt).await;
249 }
250 Some(Ok(Message::Ping(p))) => sink.send(Message::Pong(p)).await?,
251 Some(Ok(Message::Close(frame))) => return Ok(classify_close(frame.as_ref())),
252 Some(Ok(_)) => {}
253 Some(Err(e)) => return Err(anyhow!("ws read error: {e}")),
254 None => return Ok(CloseKind::Unexpected),
255 }
256 }
257 }
258 }
259 }
260 .await;
261
262 hb_handle.abort();
263 result.map_err(ConnectError::PostAuth)
264}
265
266async fn handle_inbound_text(executor: &JobExecutor, txt: &str) {
270 let v: Value = match serde_json::from_str(txt) {
271 Ok(v) => v,
272 Err(e) => {
273 warn!(error = %e, "ws inbound: invalid json");
274 return;
275 }
276 };
277 match v.get("type").and_then(Value::as_str) {
278 Some("job") => match parse_job(&v) {
279 Ok(job) => executor.dispatch(job).await,
280 Err(e) => warn!(error = %e, "ws inbound: malformed job"),
281 },
282 Some("job_cancel") => {
283 if let Some(id) = v.get("job_id").and_then(Value::as_str) {
284 match id.parse::<uuid::Uuid>() {
285 Ok(job_id) => executor.cancel(job_id).await,
286 Err(e) => warn!(error = %e, "ws inbound: bad job_id in job_cancel"),
287 }
288 }
289 }
290 Some(other) => debug!(kind = other, "ws inbound: unhandled message type"),
291 None => warn!("ws inbound: missing 'type'"),
292 }
293}
294
295fn parse_job(v: &Value) -> Result<Job> {
296 let job_id = v
297 .get("job_id")
298 .and_then(Value::as_str)
299 .ok_or_else(|| anyhow!("job missing job_id"))?
300 .parse::<uuid::Uuid>()
301 .context("job_id parse")?;
302 let model_id = v
303 .get("model_id")
304 .and_then(Value::as_str)
305 .ok_or_else(|| anyhow!("job missing model_id"))?
306 .to_string();
307 let request = v
308 .get("request")
309 .cloned()
310 .ok_or_else(|| anyhow!("job missing request"))?;
311 let format = match v.get("format").and_then(Value::as_str).unwrap_or("openai") {
312 "anthropic" => WireFormat::Anthropic,
313 _ => WireFormat::Openai,
314 };
315 let deadline_ms = v
316 .get("deadline_ms")
317 .and_then(Value::as_u64)
318 .unwrap_or(60_000) as u32;
319 Ok(Job { job_id, model_id, request, format, deadline_ms })
320}
321
322async fn recv_json<S>(stream: &mut S) -> Result<Value>
323where
324 S: StreamExt<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
325 + Unpin,
326{
327 loop {
328 let msg = stream
329 .next()
330 .await
331 .ok_or_else(|| anyhow!("ws closed before message received"))?
332 .context("ws read")?;
333 match msg {
334 Message::Text(txt) => {
335 return serde_json::from_str(&txt).context("parsing ws JSON");
336 }
337 Message::Binary(_) => bail!("unexpected binary frame during handshake"),
338 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
339 Message::Close(_) => bail!("ws closed during handshake"),
340 }
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::{CloseKind, classify_close};
    use tokio_tungstenite::tungstenite::protocol::CloseFrame;
    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;

    /// Builds a close frame from a raw code and reason and returns its classification.
    fn classify(code: u16, reason: &'static str) -> CloseKind {
        let close = CloseFrame {
            code: CloseCode::from(code),
            reason: reason.into(),
        };
        classify_close(Some(&close))
    }

    #[test]
    fn classify_close_drain_deploy_is_recognised() {
        assert_eq!(classify(1001, "drain:deploy"), CloseKind::DrainDeploy);
    }

    #[test]
    fn classify_close_drain_operator_is_recognised() {
        assert_eq!(classify(1001, "drain:operator"), CloseKind::DrainOperator);
    }

    #[test]
    fn classify_close_1001_with_unknown_reason_is_graceful_other() {
        // Both an unrecognised reason and an empty one fall back to the generic graceful kind.
        for reason in ["coordinator shutting down", ""] {
            assert_eq!(
                classify(1001, reason),
                CloseKind::GracefulOther,
                "reason {reason:?}"
            );
        }
    }

    #[test]
    fn classify_close_missing_frame_is_graceful_other() {
        assert_eq!(classify_close(None), CloseKind::GracefulOther);
    }

    #[test]
    fn classify_close_non_1001_code_is_unexpected() {
        // The reason text is irrelevant once the code is not 1001.
        for (code, reason) in [(1002, "protocol error"), (1006, "abnormal")] {
            assert_eq!(classify(code, reason), CloseKind::Unexpected, "code {code}");
        }
    }
}
409