Skip to main content

synth_ai_core/container/
client.rs

//! Container HTTP client.
//!
//! Client for communicating with containers via the standard API contract.
4
5use super::types::{HealthResponse, InfoResponse, RolloutRequest, RolloutResponse, TaskInfo};
6use crate::errors::CoreError;
7use crate::http::HttpClient;
8use serde_json::Value;
9
/// Request timeout, in seconds, applied when a client is built without an
/// explicit timeout (five minutes).
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
12
13/// Client for communicating with containers.
14pub struct ContainerClient {
15    client: HttpClient,
16    base_url: String,
17}
18
19impl ContainerClient {
20    /// Create a new container client.
21    pub fn new(base_url: &str, api_key: Option<&str>) -> Self {
22        let key = api_key.unwrap_or("no-auth");
23        let client = HttpClient::new(base_url, key, DEFAULT_TIMEOUT_SECS)
24            .expect("Failed to create HTTP client");
25        Self {
26            client,
27            base_url: base_url.trim_end_matches('/').to_string(),
28        }
29    }
30
31    /// Create a client with custom timeout.
32    pub fn with_timeout(base_url: &str, api_key: Option<&str>, timeout_secs: u64) -> Self {
33        let key = api_key.unwrap_or("no-auth");
34        let client =
35            HttpClient::new(base_url, key, timeout_secs).expect("Failed to create HTTP client");
36        Self {
37            client,
38            base_url: base_url.trim_end_matches('/').to_string(),
39        }
40    }
41
42    /// Get the base URL.
43    pub fn base_url(&self) -> &str {
44        &self.base_url
45    }
46
47    /// Check container health.
48    pub async fn health(&self) -> Result<HealthResponse, CoreError> {
49        let response: Value = self.client.get("/health", None).await?;
50        serde_json::from_value(response)
51            .map_err(|e| CoreError::Internal(format!("Failed to parse health response: {}", e)))
52    }
53
54    /// Check if the container is healthy.
55    pub async fn is_healthy(&self) -> bool {
56        self.health().await.map(|r| r.healthy).unwrap_or(false)
57    }
58
59    /// Get service info.
60    pub async fn info(&self) -> Result<InfoResponse, CoreError> {
61        let response: Value = self.client.get("/info", None).await?;
62        serde_json::from_value(response)
63            .map_err(|e| CoreError::Internal(format!("Failed to parse info response: {}", e)))
64    }
65
66    /// Get task info for specific seeds.
67    pub async fn task_info(&self, seeds: Option<&[i64]>) -> Result<Vec<TaskInfo>, CoreError> {
68        let path = match seeds {
69            Some(s) if !s.is_empty() => {
70                let query: String = s
71                    .iter()
72                    .map(|seed| format!("seed={}", seed))
73                    .collect::<Vec<_>>()
74                    .join("&");
75                format!("/task_info?{}", query)
76            }
77            _ => "/task_info".to_string(),
78        };
79
80        let response: Value = self.client.get(&path, None).await?;
81
82        // Response can be a single TaskInfo or array
83        if response.is_array() {
84            serde_json::from_value(response)
85                .map_err(|e| CoreError::Internal(format!("Failed to parse task_info array: {}", e)))
86        } else if response.get("taskset").is_some() {
87            // Taskset descriptor response (no seeds provided)
88            Ok(vec![])
89        } else {
90            let info: TaskInfo = serde_json::from_value(response)
91                .map_err(|e| CoreError::Internal(format!("Failed to parse task_info: {}", e)))?;
92            Ok(vec![info])
93        }
94    }
95
96    /// Get taskset description (no seeds).
97    pub async fn taskset_info(&self) -> Result<Value, CoreError> {
98        let response: Value = self.client.get("/task_info", None).await?;
99        Ok(response)
100    }
101
102    /// Execute a rollout.
103    pub async fn rollout(&self, request: &RolloutRequest) -> Result<RolloutResponse, CoreError> {
104        let body = serde_json::to_value(request).map_err(|e| {
105            CoreError::Internal(format!("Failed to serialize rollout request: {}", e))
106        })?;
107
108        let response: Value = self.client.post_json("/rollout", &body).await?;
109
110        serde_json::from_value(response)
111            .map_err(|e| CoreError::Internal(format!("Failed to parse rollout response: {}", e)))
112    }
113
114    /// Signal that the job is done.
115    pub async fn done(&self) -> Result<Value, CoreError> {
116        let response: Value = self
117            .client
118            .post_json("/done", &serde_json::json!({}))
119            .await?;
120        Ok(response)
121    }
122
123    /// Raw GET request to any endpoint.
124    pub async fn get(&self, path: &str) -> Result<Value, CoreError> {
125        let response: Value = self.client.get(path, None).await?;
126        Ok(response)
127    }
128
129    /// Raw POST request to any endpoint.
130    pub async fn post(&self, path: &str, body: &Value) -> Result<Value, CoreError> {
131        let response: Value = self.client.post_json(path, body).await?;
132        Ok(response)
133    }
134
135    /// Wait for the container to become healthy.
136    pub async fn wait_for_healthy(
137        &self,
138        timeout_seconds: f64,
139        poll_interval_seconds: f64,
140    ) -> Result<(), CoreError> {
141        let start = std::time::Instant::now();
142        let timeout = std::time::Duration::from_secs_f64(timeout_seconds);
143        let interval = std::time::Duration::from_secs_f64(poll_interval_seconds);
144
145        loop {
146            if start.elapsed() >= timeout {
147                return Err(CoreError::Timeout(format!(
148                    "Container at {} did not become healthy within {} seconds",
149                    self.base_url, timeout_seconds
150                )));
151            }
152
153            match self.health().await {
154                Ok(health) if health.healthy => return Ok(()),
155                Ok(_) | Err(_) => {
156                    tokio::time::sleep(interval).await;
157                }
158            }
159        }
160    }
161}
162
163/// Environment client for RL-style interactions.
164pub struct EnvClient<'a> {
165    client: &'a ContainerClient,
166}
167
168impl<'a> EnvClient<'a> {
169    /// Create a new environment client.
170    pub fn new(client: &'a ContainerClient) -> Self {
171        Self { client }
172    }
173
174    /// Initialize an environment.
175    pub async fn initialize(&self, env_name: &str, payload: &Value) -> Result<Value, CoreError> {
176        let path = format!("/env/{}/initialize", env_name);
177        self.client.post(&path, payload).await
178    }
179
180    /// Take a step in the environment.
181    pub async fn step(&self, env_name: &str, payload: &Value) -> Result<Value, CoreError> {
182        let path = format!("/env/{}/step", env_name);
183        self.client.post(&path, payload).await
184    }
185
186    /// Terminate the environment.
187    pub async fn terminate(&self, env_name: &str, payload: &Value) -> Result<Value, CoreError> {
188        let path = format!("/env/{}/terminate", env_name);
189        self.client.post(&path, payload).await
190    }
191
192    /// Reset the environment.
193    pub async fn reset(&self, env_name: &str, payload: &Value) -> Result<Value, CoreError> {
194        let path = format!("/env/{}/reset", env_name);
195        self.client.post(&path, payload).await
196    }
197}
198
199impl ContainerClient {
200    /// Get an environment client for RL-style interactions.
201    pub fn env(&self) -> EnvClient<'_> {
202        EnvClient::new(self)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical base URL (no trailing slash) expected back from `base_url()`.
    const BASE_URL: &str = "https://container.example.com";
    /// Same URL with a trailing slash, used to exercise normalization.
    const BASE_URL_WITH_SLASH: &str = "https://container.example.com/";
    /// API key passed by the constructor tests.
    const API_KEY: &str = "sk-test";

    #[test]
    fn test_client_creation() {
        let subject = ContainerClient::new(BASE_URL, Some(API_KEY));
        assert_eq!(subject.base_url(), BASE_URL);
    }

    #[test]
    fn test_client_url_normalization() {
        let subject = ContainerClient::new(BASE_URL_WITH_SLASH, Some(API_KEY));
        assert_eq!(subject.base_url(), BASE_URL);
    }

    #[test]
    fn test_env_client() {
        // Compile-time check only: an `EnvClient` can be borrowed from a
        // client that was built without an API key.
        let subject = ContainerClient::new(BASE_URL, None);
        let _env = subject.env();
    }
}