trackio_rs/
client.rs

1use once_cell::sync::OnceCell;
2use parking_lot::Mutex;
3use reqwest::blocking::Client as Http;
4use reqwest::StatusCode;
5use serde::Serialize;
6use std::env;
7use std::time::Duration;
8
9/// A lightweight Trackio REST client for posting metrics to local or remote Trackio dashboards.
10#[derive(Debug)]
11pub struct Client {
12    base_url: String,
13    project: String,
14    run: String,
15    write_token: Option<String>,
16
17    http: Http,
18    cached_bulk_path: OnceCell<String>,
19
20    // batching
21    buf: Mutex<Vec<LogItem>>,
22    max_batch: usize,
23    #[allow(dead_code)]
24    flush_interval: Duration,
25}
26
27#[derive(Debug, Clone, Serialize)]
28struct BulkPayload<'a> {
29    project: &'a str,
30    run: &'a str,
31    #[serde(rename = "metrics_list")]
32    metrics_list: Vec<serde_json::Value>,
33    steps: Vec<i64>,
34    timestamps: Vec<String>,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    config: Option<serde_json::Value>,
37}
38
39#[derive(Debug, Clone)]
40pub struct LogItem {
41    pub metrics: serde_json::Value,
42    pub step: Option<i64>,
43    pub timestamp: Option<String>,
44}
45
46impl Client {
47    /// Create a new Trackio client using environment variables for configuration.
48    ///
49    /// Recognized env vars:
50    /// - `TRACKIO_SERVER_URL` (default: http://127.0.0.1:7860)
51    /// - `TRACKIO_PROJECT`
52    /// - `TRACKIO_RUN`
53    /// - `TRACKIO_WRITE_TOKEN`
54    /// - `TRACKIO_TIMEOUT_MS`
55    /// - `TRACKIO_MAX_BATCH`
56    /// - `TRACKIO_FLUSH_INTERVAL_MS`
57    pub fn new() -> Self {
58        let base = env::var("TRACKIO_SERVER_URL").unwrap_or_else(|_| "http://127.0.0.1:7860".into());
59        let project = env::var("TRACKIO_PROJECT").unwrap_or_default();
60        let run = env::var("TRACKIO_RUN").unwrap_or_default();
61        let write_token = env::var("TRACKIO_WRITE_TOKEN").ok();
62
63        let timeout_ms = env::var("TRACKIO_TIMEOUT_MS")
64            .ok()
65            .and_then(|s| s.parse::<u64>().ok())
66            .unwrap_or(5000);
67
68        let max_batch = env::var("TRACKIO_MAX_BATCH")
69            .ok()
70            .and_then(|s| s.parse::<usize>().ok())
71            .unwrap_or(128);
72
73        let flush_interval = env::var("TRACKIO_FLUSH_INTERVAL_MS")
74            .ok()
75            .and_then(|s| s.parse::<u64>().ok())
76            .map(Duration::from_millis)
77            .unwrap_or(Duration::from_millis(200));
78
79        Self {
80            base_url: base,
81            project,
82            run,
83            write_token,
84            http: Http::builder()
85                .timeout(Duration::from_millis(timeout_ms))
86                .build()
87                .expect("failed to build HTTP client"),
88            cached_bulk_path: OnceCell::new(),
89            buf: Mutex::new(Vec::with_capacity(max_batch)),
90            max_batch,
91            flush_interval,
92        }
93    }
94
95    pub fn with_project(mut self, p: &str) -> Self {
96        self.project = p.into();
97        self
98    }
99
100    pub fn with_run(mut self, r: &str) -> Self {
101        self.run = r.into();
102        self
103    }
104
105    pub fn with_base_url(mut self, u: &str) -> Self {
106        self.base_url = u.into();
107        self
108    }
109
110    /// Logs a single metric dictionary into the in-memory buffer.
111    /// Auto-flushes when `max_batch` is reached.
112    pub fn log(&self, metrics: serde_json::Value, step: Option<i64>, ts: Option<String>) {
113        let mut buf = self.buf.lock();
114        buf.push(LogItem {
115            metrics,
116            step,
117            timestamp: ts,
118        });
119        if buf.len() >= self.max_batch {
120            drop(buf);
121            let _ = self.flush(); // best-effort flush
122        }
123    }
124
125    /// Flush all buffered metrics to the Trackio server.
126    pub fn flush(&self) -> Result<(), TrackioError> {
127        let items = {
128            let mut buf = self.buf.lock();
129            if buf.is_empty() {
130                return Ok(());
131            }
132            let out = buf.clone();
133            buf.clear();
134            out
135        };
136
137        let mut metrics_list = Vec::with_capacity(items.len());
138        let mut steps = Vec::with_capacity(items.len());
139        let mut timestamps = Vec::with_capacity(items.len());
140
141        for it in items {
142            metrics_list.push(it.metrics);
143            steps.push(it.step.unwrap_or(-1));
144            timestamps.push(it.timestamp.unwrap_or_else(|| "".into()));
145        }
146
147        let payload = BulkPayload {
148            project: &self.project,
149            run: &self.run,
150            metrics_list,
151            steps,
152            timestamps,
153            config: None,
154        };
155
156        // Discover a working bulk endpoint once.
157        let path = self.cached_bulk_path.get_or_try_init(|| {
158            if self.try_post("/api/bulk_log", &payload).is_ok() {
159                return Ok("/api/bulk_log".to_string());
160            }
161            if self.try_post("/gradio_api/bulk_log", &payload).is_ok() {
162                return Ok("/gradio_api/bulk_log".to_string());
163            }
164            Err(TrackioError::NoBulkEndpoint)
165        })?;
166
167        self.try_post(path, &payload)
168    }
169
170    /// Internal helper to send JSON POST and map non-2xx responses.
171    fn try_post<P: AsRef<str>, T: Serialize>(
172        &self,
173        path: P,
174        payload: &T,
175    ) -> Result<(), TrackioError> {
176        let url = format!("{}{}", self.base_url, path.as_ref());
177        let mut req = self.http.post(url).json(payload);
178        if let Some(tok) = &self.write_token {
179            req = req.header("X-Trackio-Write-Token", tok);
180        }
181        let resp = req.send().map_err(TrackioError::Http)?;
182        if !resp.status().is_success() {
183            let status = resp.status();
184            let body = resp.text().unwrap_or_default();
185            if status == StatusCode::NOT_FOUND {
186                return Err(TrackioError::NotFound(body));
187            }
188            return Err(TrackioError::Status(status.as_u16(), body));
189        }
190        Ok(())
191    }
192
193    /// Flush remaining metrics and stop background tasks (if any).
194    pub fn close(&self) -> Result<(), TrackioError> {
195        self.flush()
196    }
197}
198
199#[derive(thiserror::Error, Debug)]
200pub enum TrackioError {
201    #[error("no Trackio bulk endpoint found")]
202    NoBulkEndpoint,
203    #[error("HTTP error: {0}")]
204    Http(#[from] reqwest::Error),
205    #[error("404 Not Found: {0}")]
206    NotFound(String),
207    #[error("HTTP {0}: {1}")]
208    Status(u16, String),
209}