// veil_sdk/client.rs

1use std::time::{Duration, Instant};
2
3use reqwest::{header, StatusCode};
4use serde::de::DeserializeOwned;
5use tracing::{debug, info, warn};
6
7use crate::{
8    error::{Result, VeilError},
9    types::{
10        Health, Job, JobStatus, Proof, RegisterModelRequest, RegisterModelResponse,
11        SubmitJobRequest, SubmitJobResponse, VerifyResult,
12    },
13};
14
// ── Builder ───────────────────────────────────────────────────────────────────

/// Builder for [`VeilClient`].
///
/// Every setter consumes and returns the builder, so calls chain; nothing is
/// validated until [`build`](VeilClientBuilder::build) is called.
///
/// ```rust
/// use veil_sdk::VeilClient;
/// use std::time::Duration;
///
/// let client = VeilClient::builder()
///     .base_url("http://localhost:8080")
///     .timeout(Duration::from_secs(600))
///     .poll_interval(Duration::from_secs(3))
///     .build()
///     .unwrap();
/// ```
#[derive(Debug)]
#[must_use = "a builder does nothing until `.build()` is called"]
pub struct VeilClientBuilder {
    /// Gateway base URL; the `base_url` setter strips trailing slashes.
    base_url: String,
    /// Overall wall-clock budget for `verify_inference`.
    timeout: Duration,
    /// Delay between `GET /v1/jobs/{id}` polls in `verify_inference`.
    poll_interval: Duration,
}
36
37impl Default for VeilClientBuilder {
38    fn default() -> Self {
39        Self {
40            base_url: "http://localhost:8080".to_string(),
41            timeout: Duration::from_secs(600),
42            poll_interval: Duration::from_secs(3),
43        }
44    }
45}
46
47impl VeilClientBuilder {
48    /// Base URL of the Veil gateway, e.g. `"https://api.mugen.network"`.
49    /// Trailing slashes are stripped automatically.
50    pub fn base_url(mut self, url: impl Into<String>) -> Self {
51        self.base_url = url.into().trim_end_matches('/').to_string();
52        self
53    }
54
55    /// Maximum wall-clock time to wait for a job to reach a terminal state.
56    /// Applies to `verify_inference`. Defaults to 600 seconds.
57    pub fn timeout(mut self, d: Duration) -> Self {
58        self.timeout = d;
59        self
60    }
61
62    /// How often to poll `GET /v1/jobs/{id}` while waiting.
63    /// Defaults to 3 seconds.
64    pub fn poll_interval(mut self, d: Duration) -> Self {
65        self.poll_interval = d;
66        self
67    }
68
69    /// Consume the builder and construct a [`VeilClient`].
70    ///
71    /// # Errors
72    /// Returns [`VeilError::InvalidUrl`] if the base URL cannot be parsed by
73    /// `reqwest`.
74    pub fn build(self) -> Result<VeilClient> {
75        // Validate the URL by attempting to construct a reqwest client
76        // with a test request (parse-only, no network).
77        reqwest::Url::parse(&self.base_url)
78            .map_err(|e| VeilError::InvalidUrl(format!("{}: {e}", self.base_url)))?;
79
80        let http = reqwest::Client::builder()
81            .default_headers({
82                let mut h = header::HeaderMap::new();
83                h.insert(
84                    header::CONTENT_TYPE,
85                    header::HeaderValue::from_static("application/json"),
86                );
87                h.insert(
88                    header::ACCEPT,
89                    header::HeaderValue::from_static("application/json"),
90                );
91                h
92            })
93            // reqwest's own connection timeout — separate from our poll timeout.
94            .connect_timeout(Duration::from_secs(10))
95            .build()
96            .map_err(VeilError::Http)?;
97
98        Ok(VeilClient {
99            http,
100            base_url: self.base_url,
101            timeout: self.timeout,
102            poll_interval: self.poll_interval,
103        })
104    }
105}
106
107// ── Client ────────────────────────────────────────────────────────────────────
108
109/// Async client for the Mugen Veil verifiable inference gateway.
110///
111/// Construct via [`VeilClient::builder()`].
112///
113/// `VeilClient` is cheap to clone — the underlying `reqwest::Client` uses an
114/// `Arc` internally and shares the connection pool across clones.
115#[derive(Debug, Clone)]
116pub struct VeilClient {
117    http: reqwest::Client,
118    base_url: String,
119    timeout: Duration,
120    poll_interval: Duration,
121}
122
123impl VeilClient {
124    /// Begin building a client. See [`VeilClientBuilder`].
125    pub fn builder() -> VeilClientBuilder {
126        VeilClientBuilder::default()
127    }
128
129    // ── Private helpers ───────────────────────────────────────────────────────
130
131    fn url(&self, path: &str) -> String {
132        format!("{}{path}", self.base_url)
133    }
134
135    /// Parse a response, surfacing gateway error bodies as [`VeilError::Api`].
136    async fn parse<T: DeserializeOwned>(&self, res: reqwest::Response) -> Result<T> {
137        let status = res.status();
138        if status.is_success() {
139            Ok(res.json::<T>().await?)
140        } else {
141            // Best-effort extraction of an `error` field from the JSON body.
142            let message = res
143                .json::<serde_json::Value>()
144                .await
145                .ok()
146                .and_then(|v| v.get("error").and_then(|e| e.as_str()).map(String::from))
147                .unwrap_or_else(|| status.to_string());
148
149            Err(VeilError::Api {
150                status: status.as_u16(),
151                message,
152            })
153        }
154    }
155
156    // ── Primitive methods ─────────────────────────────────────────────────────
157
158    /// `GET /healthz` — Returns gateway health information.
159    ///
160    /// Use [`Health::is_healthy()`] to check overall readiness.
161    pub async fn health_check(&self) -> Result<Health> {
162        debug!("GET /healthz");
163        let res = self.http.get(self.url("/healthz")).send().await?;
164        self.parse(res).await
165    }
166
167    /// `POST /v1/jobs` — Submit an inference job and return immediately.
168    ///
169    /// Returns the `job_id`. Use [`get_job`](Self::get_job) to poll status,
170    /// or [`verify_inference`](Self::verify_inference) to submit and wait.
171    ///
172    /// # Arguments
173    /// - `model_id`   — the model name registered with the gateway (e.g. `"tiny_mlp_v1"`)
174    /// - `input_data` — row-major input tensor, e.g. `vec![vec![0.1, 0.2, 0.3, 0.4]]`
175    pub async fn submit_job(
176        &self,
177        model_id: impl Into<String>,
178        input_data: Vec<Vec<f64>>,
179    ) -> Result<String> {
180        let body = SubmitJobRequest {
181            input_data,
182            model_id: model_id.into(),
183        };
184
185        debug!(model_id = %body.model_id, "POST /v1/jobs");
186
187        let res = self
188            .http
189            .post(self.url("/v1/jobs"))
190            .json(&body)
191            .send()
192            .await?;
193
194        let resp: SubmitJobResponse = self.parse(res).await?;
195        info!(job_id = %resp.job_id, "job submitted");
196        Ok(resp.job_id)
197    }
198
199    /// `GET /v1/jobs/{id}` — Poll the status of a job.
200    pub async fn get_job(&self, job_id: &str) -> Result<Job> {
201        debug!(%job_id, "GET /v1/jobs/{job_id}");
202        let res = self
203            .http
204            .get(self.url(&format!("/v1/jobs/{job_id}")))
205            .send()
206            .await?;
207        self.parse(res).await
208    }
209
210    /// `GET /v1/jobs/{id}/proof` — Fetch the raw proof bytes for a completed job.
211    ///
212    /// The gateway returns `HTTP 202` if the job is not yet complete.
213    /// Returns [`VeilError::Api`] with status 202 in that case — callers should
214    /// poll [`get_job`](Self::get_job) first.
215    pub async fn get_proof(&self, job_id: &str) -> Result<Proof> {
216        debug!(%job_id, "GET /v1/jobs/{job_id}/proof");
217        let res = self
218            .http
219            .get(self.url(&format!("/v1/jobs/{job_id}/proof")))
220            .send()
221            .await?;
222        self.parse(res).await
223    }
224
225    /// `POST /v1/models` — Register an ONNX model artifact with the gateway.
226    ///
227    /// Pins the artifact to IPFS and registers it on-chain.
228    /// Requires `PINATA_JWT` to be configured on the gateway.
229    pub async fn register_model(&self, req: RegisterModelRequest) -> Result<RegisterModelResponse> {
230        debug!(name = %req.name, version = %req.version, "POST /v1/models");
231        let res = self
232            .http
233            .post(self.url("/v1/models"))
234            .json(&req)
235            .send()
236            .await?;
237        self.parse(res).await
238    }
239
240    // ── High-level method ─────────────────────────────────────────────────────
241
242    /// Submit an inference job and block until it reaches a terminal state.
243    ///
244    /// Polls `GET /v1/jobs/{id}` at the configured `poll_interval` until the
245    /// job status is one of `done`, `settled`, or `failed`, or until the
246    /// configured `timeout` elapses.
247    ///
248    /// # Arguments
249    /// - `model_id`   — registered model name, e.g. `"tiny_mlp_v1"`
250    /// - `input_data` — row-major input tensor
251    ///
252    /// # Errors
253    /// - [`VeilError::Timeout`]    — polling exceeded the configured timeout
254    /// - [`VeilError::JobFailed`]  — the gateway reported the job as failed
255    /// - [`VeilError::Api`]        — gateway returned a non-2xx response
256    /// - [`VeilError::Http`]       — network-level failure
257    ///
258    /// # Example
259    /// ```rust,no_run
260    /// # use veil_sdk::VeilClient;
261    /// # #[tokio::main] async fn main() -> veil_sdk::error::Result<()> {
262    /// let client = VeilClient::builder()
263    ///     .base_url("http://localhost:8080")
264    ///     .build()?;
265    ///
266    /// let result = client
267    ///     .verify_inference("tiny_mlp_v1", vec![vec![0.1, 0.2, 0.3, 0.4]])
268    ///     .await?;
269    ///
270    /// println!("tx_hash: {:?}", result.tx_hash);
271    /// # Ok(())
272    /// # }
273    /// ```
274    pub async fn verify_inference(
275        &self,
276        model_id: impl Into<String>,
277        input_data: Vec<Vec<f64>>,
278    ) -> Result<VerifyResult> {
279        let model_id = model_id.into();
280        let started = Instant::now();
281
282        // 1. Submit
283        let job_id = self.submit_job(&model_id, input_data).await?;
284        info!(%job_id, %model_id, "job submitted — polling until terminal state");
285
286        // 2. Poll
287        let deadline = started + self.timeout;
288        let mut last_status = String::from("queued");
289
290        loop {
291            tokio::time::sleep(self.poll_interval).await;
292
293            if Instant::now() >= deadline {
294                return Err(VeilError::Timeout {
295                    job_id,
296                    elapsed_ms: started.elapsed().as_millis() as u64,
297                    last_status,
298                });
299            }
300
301            let job = match self.get_job(&job_id).await {
302                Ok(j) => j,
303                Err(e) => {
304                    // Transient network errors during polling are logged and
305                    // retried rather than propagated immediately.
306                    warn!(%job_id, "poll error (will retry): {e}");
307                    continue;
308                }
309            };
310
311            last_status = job.status.to_string();
312            debug!(%job_id, status = %last_status, "poll");
313
314            match &job.status {
315                JobStatus::Failed => {
316                    return Err(VeilError::JobFailed {
317                        job_id,
318                        reason: job.reason,
319                    });
320                }
321                s if s.is_terminal() => {
322                    let elapsed_ms = started.elapsed().as_millis() as u64;
323                    info!(%job_id, status = %last_status, elapsed_ms, "job complete");
324                    return Ok(VerifyResult {
325                        job_id,
326                        status: job.status,
327                        tx_hash: job.tx_hash,
328                        attestation_hash: job.attestation_hash,
329                        elapsed_ms,
330                    });
331                }
332                _ => {} // still in progress — keep polling
333            }
334        }
335    }
336}