Skip to main content

synth_ai_core/localapi/
client.rs

//! Task app HTTP client.
//!
//! Client for communicating with task apps via the standard task-app API contract
//! (`/health`, `/info`, `/task_info`, `/rollout`, `/done`, and `/env/{name}/...`).
4
5use super::types::{HealthResponse, InfoResponse, RolloutRequest, RolloutResponse, TaskInfo};
6use crate::errors::CoreError;
7use crate::http::HttpClient;
8use serde_json::Value;
9
/// Per-request timeout, in seconds, used by [`TaskAppClient::new`]
/// (five minutes).
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
12
13/// Client for communicating with task apps.
14pub struct TaskAppClient {
15    client: HttpClient,
16    base_url: String,
17}
18
19impl TaskAppClient {
20    /// Create a new task app client.
21    pub fn new(base_url: &str, api_key: Option<&str>) -> Self {
22        let key = api_key.unwrap_or("no-auth");
23        let client = HttpClient::new(base_url, key, DEFAULT_TIMEOUT_SECS)
24            .expect("Failed to create HTTP client");
25        Self {
26            client,
27            base_url: base_url.trim_end_matches('/').to_string(),
28        }
29    }
30
31    /// Create a client with custom timeout.
32    pub fn with_timeout(base_url: &str, api_key: Option<&str>, timeout_secs: u64) -> Self {
33        let key = api_key.unwrap_or("no-auth");
34        let client = HttpClient::new(base_url, key, timeout_secs)
35            .expect("Failed to create HTTP client");
36        Self {
37            client,
38            base_url: base_url.trim_end_matches('/').to_string(),
39        }
40    }
41
42    /// Get the base URL.
43    pub fn base_url(&self) -> &str {
44        &self.base_url
45    }
46
47    /// Check task app health.
48    pub async fn health(&self) -> Result<HealthResponse, CoreError> {
49        let response: Value = self.client.get("/health", None).await?;
50        serde_json::from_value(response)
51            .map_err(|e| CoreError::Internal(format!("Failed to parse health response: {}", e)))
52    }
53
54    /// Check if the task app is healthy.
55    pub async fn is_healthy(&self) -> bool {
56        self.health().await.map(|r| r.healthy).unwrap_or(false)
57    }
58
59    /// Get service info.
60    pub async fn info(&self) -> Result<InfoResponse, CoreError> {
61        let response: Value = self.client.get("/info", None).await?;
62        serde_json::from_value(response)
63            .map_err(|e| CoreError::Internal(format!("Failed to parse info response: {}", e)))
64    }
65
66    /// Get task info for specific seeds.
67    pub async fn task_info(&self, seeds: Option<&[i64]>) -> Result<Vec<TaskInfo>, CoreError> {
68        let path = match seeds {
69            Some(s) if !s.is_empty() => {
70                let query: String = s
71                    .iter()
72                    .map(|seed| format!("seed={}", seed))
73                    .collect::<Vec<_>>()
74                    .join("&");
75                format!("/task_info?{}", query)
76            }
77            _ => "/task_info".to_string(),
78        };
79
80        let response: Value = self.client.get(&path, None).await?;
81
82        // Response can be a single TaskInfo or array
83        if response.is_array() {
84            serde_json::from_value(response)
85                .map_err(|e| CoreError::Internal(format!("Failed to parse task_info array: {}", e)))
86        } else if response.get("taskset").is_some() {
87            // Taskset descriptor response (no seeds provided)
88            Ok(vec![])
89        } else {
90            let info: TaskInfo = serde_json::from_value(response)
91                .map_err(|e| CoreError::Internal(format!("Failed to parse task_info: {}", e)))?;
92            Ok(vec![info])
93        }
94    }
95
96    /// Get taskset description (no seeds).
97    pub async fn taskset_info(&self) -> Result<Value, CoreError> {
98        let response: Value = self.client.get("/task_info", None).await?;
99        Ok(response)
100    }
101
102    /// Execute a rollout.
103    pub async fn rollout(&self, request: &RolloutRequest) -> Result<RolloutResponse, CoreError> {
104        let body = serde_json::to_value(request)
105            .map_err(|e| CoreError::Internal(format!("Failed to serialize rollout request: {}", e)))?;
106
107        let response: Value = self.client.post_json("/rollout", &body).await?;
108
109        serde_json::from_value(response)
110            .map_err(|e| CoreError::Internal(format!("Failed to parse rollout response: {}", e)))
111    }
112
113    /// Signal that the job is done.
114    pub async fn done(&self) -> Result<Value, CoreError> {
115        let response: Value = self.client.post_json("/done", &serde_json::json!({})).await?;
116        Ok(response)
117    }
118
119    /// Raw GET request to any endpoint.
120    pub async fn get(&self, path: &str) -> Result<Value, CoreError> {
121        let response: Value = self.client.get(path, None).await?;
122        Ok(response)
123    }
124
125    /// Raw POST request to any endpoint.
126    pub async fn post(&self, path: &str, body: &Value) -> Result<Value, CoreError> {
127        let response: Value = self.client.post_json(path, body).await?;
128        Ok(response)
129    }
130
131    /// Wait for the task app to become healthy.
132    pub async fn wait_for_healthy(
133        &self,
134        timeout_seconds: f64,
135        poll_interval_seconds: f64,
136    ) -> Result<(), CoreError> {
137        let start = std::time::Instant::now();
138        let timeout = std::time::Duration::from_secs_f64(timeout_seconds);
139        let interval = std::time::Duration::from_secs_f64(poll_interval_seconds);
140
141        loop {
142            if start.elapsed() >= timeout {
143                return Err(CoreError::Timeout(format!(
144                    "Task app at {} did not become healthy within {} seconds",
145                    self.base_url, timeout_seconds
146                )));
147            }
148
149            match self.health().await {
150                Ok(health) if health.healthy => return Ok(()),
151                Ok(_) | Err(_) => {
152                    tokio::time::sleep(interval).await;
153                }
154            }
155        }
156    }
157}
158
159/// Environment client for RL-style interactions.
160pub struct EnvClient<'a> {
161    client: &'a TaskAppClient,
162}
163
164impl<'a> EnvClient<'a> {
165    /// Create a new environment client.
166    pub fn new(client: &'a TaskAppClient) -> Self {
167        Self { client }
168    }
169
170    /// Initialize an environment.
171    pub async fn initialize(&self, env_name: &str, payload: &Value) -> Result<Value, CoreError> {
172        let path = format!("/env/{}/initialize", env_name);
173        self.client.post(&path, payload).await
174    }
175
176    /// Take a step in the environment.
177    pub async fn step(&self, env_name: &str, payload: &Value) -> Result<Value, CoreError> {
178        let path = format!("/env/{}/step", env_name);
179        self.client.post(&path, payload).await
180    }
181
182    /// Terminate the environment.
183    pub async fn terminate(&self, env_name: &str, payload: &Value) -> Result<Value, CoreError> {
184        let path = format!("/env/{}/terminate", env_name);
185        self.client.post(&path, payload).await
186    }
187
188    /// Reset the environment.
189    pub async fn reset(&self, env_name: &str, payload: &Value) -> Result<Value, CoreError> {
190        let path = format!("/env/{}/reset", env_name);
191        self.client.post(&path, payload).await
192    }
193}
194
195impl TaskAppClient {
196    /// Get an environment client for RL-style interactions.
197    pub fn env(&self) -> EnvClient<'_> {
198        EnvClient::new(self)
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = TaskAppClient::new("https://task-app.example.com", Some("sk-test"));
        assert_eq!(client.base_url(), "https://task-app.example.com");
    }

    #[test]
    fn test_client_url_normalization() {
        let client = TaskAppClient::new("https://task-app.example.com/", Some("sk-test"));
        assert_eq!(client.base_url(), "https://task-app.example.com");
    }

    #[test]
    fn test_client_url_normalization_multiple_slashes() {
        // `trim_end_matches` strips every trailing slash, not just one.
        let client = TaskAppClient::new("https://task-app.example.com///", None);
        assert_eq!(client.base_url(), "https://task-app.example.com");
    }

    #[test]
    fn test_with_timeout_normalizes_url() {
        let client = TaskAppClient::with_timeout("https://task-app.example.com/", None, 5);
        assert_eq!(client.base_url(), "https://task-app.example.com");
    }

    #[test]
    fn test_env_client() {
        let client = TaskAppClient::new("https://task-app.example.com", None);
        let env = client.env();
        // The environment handle must borrow the very same client.
        assert!(std::ptr::eq(env.client, &client));
    }
}