Skip to main content

rustbox_sdk/
lib.rs

1//! Official Rust client for [Rustbox](https://rustbox.orkait.com).
2//!
3//! ```no_run
4//! use rustbox_sdk::{Rustbox, SubmitRequest};
5//! # async fn run() -> Result<(), rustbox_sdk::RustboxError> {
6//! let client = Rustbox::new(&std::env::var("RUSTBOX_API_KEY").unwrap())?;
7//! let res = client.run(&SubmitRequest {
8//!     language: "python".into(),
9//!     code: "print('hi')".into(),
10//!     ..Default::default()
11//! }).await?;
12//! println!("{}", res["verdict"]);
13//! # Ok(()) }
14//! ```
15
16use std::time::Duration;
17
18use serde::Serialize;
19use thiserror::Error;
20use tokio::time::sleep;
21
/// SDK version. Sent in `User-Agent` as `rustbox-sdk-rust/<VERSION>`.
pub const VERSION: &str = "0.1.0";

/// Production endpoint. Used as the base URL of a client built by `Rustbox::new`
/// unless it is overridden with `Rustbox::with_base_url`.
pub const DEFAULT_BASE_URL: &str = "https://rustbox-api.orkait.com";

/// Per-request timeout installed on the HTTP client built by `Rustbox::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(65);
/// Retry budget a freshly constructed client starts with (retries, not total attempts).
const DEFAULT_MAX_RETRIES: u32 = 2;
30
31/// Execution profile.
32///
33/// - `Profile::Judge` (default): short evaluation runs.
34/// - `Profile::Agent`: longer jobs with egress proxy + per-key byte
35///   budgets. Requires a non-trial API key.
36#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
37#[serde(rename_all = "lowercase")]
38pub enum Profile {
39    Judge,
40    Agent,
41}
42
43#[derive(Serialize, Debug, Clone, Default)]
44pub struct SubmitRequest {
45    pub language: String,
46    pub code: String,
47    pub stdin: String,
48    /// Optional profile override. None falls back to server-side default ("judge").
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub profile: Option<Profile>,
51}
52
53/// Errors returned by the Rustbox client.
54#[derive(Debug, Error)]
55pub enum RustboxError {
56    #[error("api_key required")]
57    MissingApiKey,
58    #[error("invalid base_url")]
59    InvalidBaseUrl,
60    #[error("invalid or missing API key (HTTP {0})")]
61    Auth(u16),
62    #[error("rate limit exceeded (HTTP 429)")]
63    RateLimit,
64    #[error("server error (HTTP {0})")]
65    Server(u16),
66    #[error("API error (HTTP {status}): {body}")]
67    Api { status: u16, body: String },
68    #[error("request timed out")]
69    Timeout,
70    #[error(transparent)]
71    Transport(#[from] reqwest::Error),
72    #[error("response decode failed: {0}")]
73    Decode(String),
74}
75
/// Per-call settings for `Rustbox::submit` that travel outside the JSON body.
#[derive(Clone, Debug, Default)]
pub struct SubmitOptions {
    /// Value for the `Idempotency-Key` header; the header is only sent when
    /// this is `Some`. Safe to retry POST /api/submit when set.
    pub idempotency_key: Option<String>,
}
82
83pub struct Rustbox {
84    api_key: String,
85    base_url: String,
86    client: reqwest::Client,
87    max_retries: u32,
88}
89
90impl Rustbox {
91    /// Construct a Rustbox client. `api_key` is required (must be non-empty).
92    /// Base URL defaults to `DEFAULT_BASE_URL`; override with `with_base_url`.
93    pub fn new(api_key: &str) -> Result<Self, RustboxError> {
94        if api_key.is_empty() {
95            return Err(RustboxError::MissingApiKey);
96        }
97        let client = reqwest::Client::builder()
98            .timeout(DEFAULT_TIMEOUT)
99            .build()
100            .map_err(RustboxError::Transport)?;
101        Ok(Self {
102            api_key: api_key.to_string(),
103            base_url: DEFAULT_BASE_URL.to_string(),
104            client,
105            max_retries: DEFAULT_MAX_RETRIES,
106        })
107    }
108
109    /// Override the API base URL. Use for staging.
110    /// Trailing slashes are trimmed.
111    pub fn with_base_url(mut self, base_url: &str) -> Result<Self, RustboxError> {
112        if base_url.is_empty() {
113            return Err(RustboxError::InvalidBaseUrl);
114        }
115        self.base_url = base_url.trim_end_matches('/').to_string();
116        Ok(self)
117    }
118
119    /// Override the per-request timeout. Set `Duration::ZERO` to disable.
120    pub fn with_timeout(mut self, timeout: Duration) -> Result<Self, RustboxError> {
121        let mut builder = reqwest::Client::builder();
122        if !timeout.is_zero() {
123            builder = builder.timeout(timeout);
124        }
125        self.client = builder.build().map_err(RustboxError::Transport)?;
126        Ok(self)
127    }
128
129    /// Override the retry budget on transient (5xx, network) failures.
130    pub fn with_max_retries(mut self, n: u32) -> Self {
131        self.max_retries = n;
132        self
133    }
134
135    pub fn base_url(&self) -> &str {
136        &self.base_url
137    }
138
139    fn backoff_delay(&self, attempt: u32) -> Duration {
140        Duration::from_millis((100u64 * (1u64 << attempt.min(8))).min(5_000))
141    }
142
143    async fn send_with_retry(
144        &self,
145        build: impl Fn() -> reqwest::RequestBuilder,
146    ) -> Result<reqwest::Response, RustboxError> {
147        let mut last_err: Option<RustboxError> = None;
148        for attempt in 0..=self.max_retries {
149            let req = build()
150                .header("X-API-Key", &self.api_key)
151                .header("User-Agent", format!("rustbox-sdk-rust/{VERSION}"));
152            match req.send().await {
153                Ok(resp) => {
154                    if resp.status().as_u16() >= 500 && attempt < self.max_retries {
155                        sleep(self.backoff_delay(attempt)).await;
156                        continue;
157                    }
158                    return Ok(resp);
159                }
160                Err(e) => {
161                    let is_timeout = e.is_timeout();
162                    last_err = Some(if is_timeout {
163                        RustboxError::Timeout
164                    } else {
165                        RustboxError::Transport(e)
166                    });
167                    if attempt >= self.max_retries {
168                        return Err(last_err.unwrap());
169                    }
170                    sleep(self.backoff_delay(attempt)).await;
171                }
172            }
173        }
174        Err(last_err.unwrap_or(RustboxError::Decode("retry exhausted".into())))
175    }
176
177    async fn handle(&self, resp: reqwest::Response) -> Result<serde_json::Value, RustboxError> {
178        let status = resp.status();
179        let code = status.as_u16();
180        if status.is_success() || code == 408 {
181            return resp
182                .json()
183                .await
184                .map_err(|e| RustboxError::Decode(e.to_string()));
185        }
186        match code {
187            401 | 403 => Err(RustboxError::Auth(code)),
188            429 => Err(RustboxError::RateLimit),
189            500..=599 => Err(RustboxError::Server(code)),
190            _ => {
191                let body = resp.text().await.unwrap_or_default();
192                Err(RustboxError::Api { status: code, body })
193            }
194        }
195    }
196
197    pub async fn submit(
198        &self,
199        req: &SubmitRequest,
200        wait: bool,
201        opts: SubmitOptions,
202    ) -> Result<serde_json::Value, RustboxError> {
203        let url = format!("{}/api/submit?wait={}", self.base_url, wait);
204        let body = serde_json::to_vec(req).map_err(|e| RustboxError::Decode(e.to_string()))?;
205
206        let resp = self
207            .send_with_retry(|| {
208                let mut rb = self
209                    .client
210                    .post(&url)
211                    .header("Content-Type", "application/json")
212                    .body(body.clone());
213                if let Some(ref key) = opts.idempotency_key {
214                    rb = rb.header("Idempotency-Key", key);
215                }
216                rb
217            })
218            .await?;
219        self.handle(resp).await
220    }
221
222    pub async fn get_result(&self, id: &str) -> Result<serde_json::Value, RustboxError> {
223        let url = format!("{}/api/result/{}", self.base_url, id);
224        let resp = self.send_with_retry(|| self.client.get(&url)).await?;
225        self.handle(resp).await
226    }
227
228    pub async fn get_languages(&self) -> Result<Vec<String>, RustboxError> {
229        let url = format!("{}/api/languages", self.base_url);
230        let resp = self.send_with_retry(|| self.client.get(&url)).await?;
231        let val = self.handle(resp).await?;
232        serde_json::from_value(val).map_err(|e| RustboxError::Decode(e.to_string()))
233    }
234
235    pub async fn get_health(&self) -> Result<serde_json::Value, RustboxError> {
236        let url = format!("{}/api/health", self.base_url);
237        let resp = self.send_with_retry(|| self.client.get(&url)).await?;
238        self.handle(resp).await
239    }
240
241    pub async fn get_ready(&self) -> Result<serde_json::Value, RustboxError> {
242        let url = format!("{}/api/health/ready", self.base_url);
243        let resp = self.send_with_retry(|| self.client.get(&url)).await?;
244        self.handle(resp).await
245    }
246
247    /// Submit + wait (sync) + auto-poll fallback. Auto-generates an
248    /// Idempotency-Key so the underlying POST is safe to retry.
249    pub async fn run(&self, req: &SubmitRequest) -> Result<serde_json::Value, RustboxError> {
250        let opts = SubmitOptions {
251            idempotency_key: Some(idempotency_id()),
252        };
253        let mut res = self.submit(req, true, opts).await?;
254        if res.get("verdict").is_some() {
255            return Ok(res);
256        }
257
258        let id = match res.get("id").and_then(|v| v.as_str()) {
259            Some(i) => i.to_string(),
260            None => return Ok(res),
261        };
262
263        for i in 0..45 {
264            let delay_ms = (40.0 * (1.5_f64).powi(i)).min(600.0) as u64;
265            sleep(Duration::from_millis(delay_ms)).await;
266
267            res = self.get_result(&id).await?;
268            if res.get("verdict").is_some() {
269                return Ok(res);
270            }
271        }
272        Ok(res)
273    }
274}
275
276fn idempotency_id() -> String {
277    uuid::Uuid::new_v4().to_string()
278}
279
#[cfg(test)]
mod tests {
    //! Unit tests. HTTP behavior is exercised against a local `wiremock`
    //! server; call-count expectations set with `.expect(n)` are verified
    //! when the `MockServer` is dropped at the end of each test.

    use super::*;
    use wiremock::matchers::{header_exists, header_regex, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Minimal submission used by every request test.
    fn req() -> SubmitRequest {
        SubmitRequest {
            language: "python".into(),
            code: "print(1)".into(),
            stdin: "".into(),
            profile: None,
        }
    }

    /// Client pointed at the given mock server, with default settings.
    fn client_for(server: &MockServer) -> Rustbox {
        Rustbox::new("test")
            .unwrap()
            .with_base_url(&server.uri())
            .unwrap()
    }

    #[tokio::test]
    async fn new_should_default_base_url_to_production() {
        let client = Rustbox::new("k").unwrap();
        assert_eq!(client.base_url(), DEFAULT_BASE_URL);
    }

    #[tokio::test]
    async fn new_should_return_err_when_api_key_empty() {
        let r = Rustbox::new("");
        assert!(matches!(r, Err(RustboxError::MissingApiKey)));
    }

    #[tokio::test]
    async fn with_base_url_should_override_default_and_trim_slash() {
        let client = Rustbox::new("k")
            .unwrap()
            .with_base_url("https://custom.example.com/")
            .unwrap();
        assert_eq!(client.base_url(), "https://custom.example.com");
    }

    #[tokio::test]
    async fn with_base_url_should_return_err_when_empty() {
        let r = Rustbox::new("k").unwrap().with_base_url("");
        assert!(matches!(r, Err(RustboxError::InvalidBaseUrl)));
    }

    #[tokio::test]
    async fn run_should_return_verdict_on_first_response_when_complete() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/submit"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"id": "1", "verdict": "AC"})),
            )
            .mount(&mock_server)
            .await;

        let res = client_for(&mock_server).run(&req()).await.unwrap();
        assert_eq!(res.get("verdict").unwrap().as_str().unwrap(), "AC");
    }

    #[tokio::test]
    async fn run_should_send_idempotency_key_header() {
        let mock_server = MockServer::start().await;
        // The mock only matches when the header is present; without it the
        // server answers 404 and `run` fails, so the `unwrap` below panics.
        Mock::given(method("POST"))
            .and(path("/api/submit"))
            .and(header_exists("idempotency-key"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"id": "1", "verdict": "AC"})),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        let res = client_for(&mock_server).run(&req()).await.unwrap();
        assert_eq!(res.get("verdict").unwrap().as_str().unwrap(), "AC");
    }

    #[tokio::test]
    async fn run_should_poll_until_verdict_when_initial_returns_408() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/submit"))
            .respond_with(ResponseTemplate::new(408).set_body_json(serde_json::json!({"id": "1"})))
            .mount(&mock_server)
            .await;

        Mock::given(method("GET"))
            .and(path("/api/result/1"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"id": "1", "verdict": "TLE"})),
            )
            .mount(&mock_server)
            .await;

        let res = client_for(&mock_server).run(&req()).await.unwrap();
        assert_eq!(res.get("verdict").unwrap().as_str().unwrap(), "TLE");
    }

    #[tokio::test]
    async fn submit_should_return_auth_err_on_401() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/submit"))
            .respond_with(ResponseTemplate::new(401))
            // Auth failures are not transient: exactly one request.
            .expect(1)
            .mount(&mock_server)
            .await;

        let err = client_for(&mock_server)
            .submit(&req(), false, SubmitOptions::default())
            .await
            .unwrap_err();
        assert!(matches!(err, RustboxError::Auth(401)));
    }

    #[tokio::test]
    async fn submit_should_return_rate_limit_on_429() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/submit"))
            .respond_with(ResponseTemplate::new(429))
            // 429 is surfaced immediately rather than retried.
            .expect(1)
            .mount(&mock_server)
            .await;

        let err = client_for(&mock_server)
            .submit(&req(), false, SubmitOptions::default())
            .await
            .unwrap_err();
        assert!(matches!(err, RustboxError::RateLimit));
    }

    #[tokio::test]
    async fn submit_should_return_server_err_on_503_after_retries() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/submit"))
            .respond_with(ResponseTemplate::new(503))
            // One initial attempt plus the single allowed retry.
            .expect(2)
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server).with_max_retries(1);
        let err = client
            .submit(&req(), false, SubmitOptions::default())
            .await
            .unwrap_err();
        assert!(matches!(err, RustboxError::Server(503)));
    }

    #[tokio::test]
    async fn submit_should_send_user_agent_header() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/submit"))
            .and(header_regex("user-agent", r"^rustbox-sdk-rust/"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"id": "1", "verdict": "AC"})),
            )
            .mount(&mock_server)
            .await;

        let res = client_for(&mock_server)
            .submit(&req(), false, SubmitOptions::default())
            .await
            .unwrap();
        assert_eq!(res.get("verdict").unwrap().as_str().unwrap(), "AC");
    }
}