1use std::collections::HashSet;
2use std::future::Future;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use axum::extract::State;
7use axum::http::{HeaderMap, StatusCode};
8use axum::response::{IntoResponse, Response};
9use axum::routing::{get, post};
10use axum::{Json, Router};
11use futures_util::future::BoxFuture;
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Map, Value};
14
15use crate::types::{Result, SynthError};
16
17pub type RolloutHandler =
18 Arc<dyn Fn(RolloutRequest) -> BoxFuture<'static, std::result::Result<RolloutResponse, LocalApiError>>
19 + Send
20 + Sync>;
21
22#[derive(Debug, Clone)]
23pub struct LocalApiError {
24 pub status: StatusCode,
25 pub message: String,
26}
27
28impl LocalApiError {
29 pub fn bad_request(message: impl Into<String>) -> Self {
30 Self {
31 status: StatusCode::BAD_REQUEST,
32 message: message.into(),
33 }
34 }
35
36 pub fn internal(message: impl Into<String>) -> Self {
37 Self {
38 status: StatusCode::INTERNAL_SERVER_ERROR,
39 message: message.into(),
40 }
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct TaskDescriptor {
46 pub id: String,
47 pub name: String,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 pub description: Option<String>,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 pub version: Option<String>,
52 #[serde(flatten)]
53 pub extra: Map<String, Value>,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, Default)]
57pub struct DatasetInfo {
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub id: Option<String>,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 pub name: Option<String>,
62 #[serde(skip_serializing_if = "Option::is_none")]
63 pub version: Option<String>,
64 #[serde(skip_serializing_if = "Option::is_none")]
65 pub splits: Option<Vec<String>>,
66 #[serde(skip_serializing_if = "Option::is_none")]
67 pub default_split: Option<String>,
68 #[serde(skip_serializing_if = "Option::is_none")]
69 pub description: Option<String>,
70 #[serde(flatten)]
71 pub extra: Map<String, Value>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, Default)]
75pub struct InferenceInfo {
76 #[serde(skip_serializing_if = "Option::is_none")]
77 pub model: Option<String>,
78 #[serde(skip_serializing_if = "Option::is_none")]
79 pub inference_url: Option<String>,
80 #[serde(flatten)]
81 pub extra: Map<String, Value>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, Default)]
85pub struct LimitsInfo {
86 #[serde(skip_serializing_if = "Option::is_none")]
87 pub max_turns: Option<i64>,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 pub max_response_tokens: Option<i64>,
90 #[serde(skip_serializing_if = "Option::is_none")]
91 pub timeout_seconds: Option<i64>,
92 #[serde(flatten)]
93 pub extra: Map<String, Value>,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct TaskInfo {
98 pub task: TaskDescriptor,
99 pub dataset: DatasetInfo,
100 pub inference: InferenceInfo,
101 pub limits: LimitsInfo,
102 #[serde(skip_serializing_if = "Option::is_none")]
103 pub task_metadata: Option<Value>,
104 #[serde(flatten)]
105 pub extra: Map<String, Value>,
106}
107
108impl TaskInfo {
109 pub fn minimal(app_id: impl Into<String>, name: impl Into<String>, description: impl Into<String>) -> Self {
110 let task = TaskDescriptor {
111 id: app_id.into(),
112 name: name.into(),
113 description: Some(description.into()),
114 version: None,
115 extra: Map::new(),
116 };
117 Self {
118 task,
119 dataset: DatasetInfo::default(),
120 inference: InferenceInfo::default(),
121 limits: LimitsInfo::default(),
122 task_metadata: None,
123 extra: Map::new(),
124 }
125 }
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct RolloutRequest {
130 pub trace_correlation_id: String,
131 pub env: Value,
132 pub policy: Value,
133 #[serde(skip_serializing_if = "Option::is_none")]
134 pub on_done: Option<String>,
135 #[serde(skip_serializing_if = "Option::is_none")]
136 pub safety: Option<Value>,
137 #[serde(skip_serializing_if = "Option::is_none")]
138 pub training_session_id: Option<String>,
139 #[serde(skip_serializing_if = "Option::is_none")]
140 pub synth_base_url: Option<String>,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 pub context_overrides: Option<Value>,
143 #[serde(skip_serializing_if = "Option::is_none")]
144 pub override_bundle_id: Option<String>,
145 #[serde(flatten)]
146 pub extra: Map<String, Value>,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct RolloutMetrics {
151 pub outcome_reward: f64,
152 #[serde(skip_serializing_if = "Option::is_none")]
153 pub event_rewards: Option<Vec<f64>>,
154 #[serde(skip_serializing_if = "Option::is_none")]
155 pub outcome_objectives: Option<Map<String, Value>>,
156 #[serde(skip_serializing_if = "Option::is_none")]
157 pub event_objectives: Option<Vec<Map<String, Value>>>,
158 #[serde(skip_serializing_if = "Option::is_none")]
159 pub instance_objectives: Option<Vec<Map<String, Value>>>,
160 #[serde(default, skip_serializing_if = "Map::is_empty")]
161 pub details: Map<String, Value>,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct RolloutResponse {
166 pub trace_correlation_id: String,
167 pub reward_info: RolloutMetrics,
168 #[serde(skip_serializing_if = "Option::is_none")]
169 pub trace: Option<Value>,
170 #[serde(skip_serializing_if = "Option::is_none")]
171 pub inference_url: Option<String>,
172 #[serde(skip_serializing_if = "Option::is_none")]
173 pub artifact: Option<Value>,
174 #[serde(skip_serializing_if = "Option::is_none")]
175 pub success_status: Option<String>,
176 #[serde(skip_serializing_if = "Option::is_none")]
177 pub status_detail: Option<String>,
178 #[serde(skip_serializing_if = "Option::is_none")]
179 pub override_application_results: Option<Value>,
180 #[serde(flatten)]
181 pub extra: Map<String, Value>,
182}
183
184#[derive(Clone)]
185pub struct LocalApiConfig {
186 pub task_info: TaskInfo,
187 pub rollout: RolloutHandler,
188 pub require_api_key: bool,
189 pub api_keys: Vec<String>,
190}
191
192impl LocalApiConfig {
193 pub fn new<F, Fut>(
194 app_id: impl Into<String>,
195 name: impl Into<String>,
196 description: impl Into<String>,
197 handler: F,
198 ) -> Self
199 where
200 F: Fn(RolloutRequest) -> Fut + Send + Sync + 'static,
201 Fut: Future<Output = std::result::Result<RolloutResponse, LocalApiError>> + Send + 'static,
202 {
203 let rollout: RolloutHandler = Arc::new(move |req| Box::pin(handler(req)));
204 let mut api_keys = Vec::new();
205 if let Ok(val) = std::env::var("ENVIRONMENT_API_KEY") {
206 api_keys.push(val);
207 }
208 Self {
209 task_info: TaskInfo::minimal(app_id, name, description),
210 rollout,
211 require_api_key: true,
212 api_keys,
213 }
214 }
215}
216
217#[derive(Clone)]
218pub struct LocalApiApp {
219 router: Router,
220}
221
222pub fn create_local_api(config: LocalApiConfig) -> LocalApiApp {
223 let state = Arc::new(config);
224
225 let router = Router::new()
226 .route("/", get(root))
227 .route("/health", get(health))
228 .route("/task_info", get(task_info))
229 .route("/info", get(info))
230 .route("/rollout", post(rollout))
231 .with_state(state);
232
233 LocalApiApp { router }
234}
235
236impl LocalApiApp {
237 pub fn router(&self) -> Router {
238 self.router.clone()
239 }
240
241 pub async fn run(self, addr: SocketAddr) -> Result<()> {
242 axum::Server::bind(&addr)
243 .serve(self.router.into_make_service())
244 .await
245 .map_err(|err| SynthError::UnexpectedResponse(err.to_string()))
246 }
247}
248
249async fn root() -> Response {
250 Json(json!({"status": "ok"})).into_response()
251}
252
253async fn health(
254 State(config): State<Arc<LocalApiConfig>>,
255 headers: HeaderMap,
256) -> Response {
257 if let Err(resp) = authorize(&config, &headers) {
258 return resp;
259 }
260 Json(json!({ "healthy": true })).into_response()
261}
262
263async fn task_info(
264 State(config): State<Arc<LocalApiConfig>>,
265 headers: HeaderMap,
266) -> Response {
267 if let Err(resp) = authorize(&config, &headers) {
268 return resp;
269 }
270 Json(config.task_info.clone()).into_response()
271}
272
273async fn info(
274 State(config): State<Arc<LocalApiConfig>>,
275 headers: HeaderMap,
276) -> Response {
277 if let Err(resp) = authorize(&config, &headers) {
278 return resp;
279 }
280 let task = config.task_info.task.clone();
281 let version = task.version.clone();
282 let service = json!({
283 "task": task,
284 "version": version,
285 });
286 let payload = json!({
287 "service": service,
288 "dataset": config.task_info.dataset,
289 "rubrics": null,
290 "inference": config.task_info.inference,
291 "limits": config.task_info.limits,
292 });
293 Json(payload).into_response()
294}
295
296async fn rollout(
297 State(config): State<Arc<LocalApiConfig>>,
298 headers: HeaderMap,
299 Json(request): Json<RolloutRequest>,
300) -> impl IntoResponse {
301 if let Err(resp) = authorize(&config, &headers) {
302 return resp;
303 }
304 let handler = config.rollout.clone();
305 match handler(request).await {
306 Ok(resp) => (StatusCode::OK, Json(resp)).into_response(),
307 Err(err) => (
308 err.status,
309 Json(json!({ "error": err.message })),
310 )
311 .into_response(),
312 }
313}
314
315fn authorize(config: &LocalApiConfig, headers: &HeaderMap) -> std::result::Result<(), axum::response::Response> {
316 if !config.require_api_key {
317 return Ok(());
318 }
319 let allowed = api_key_set(config);
320 if allowed.is_empty() {
321 let resp = (
322 StatusCode::SERVICE_UNAVAILABLE,
323 Json(json!({ "error": "ENVIRONMENT_API_KEY is not configured" })),
324 )
325 .into_response();
326 return Err(resp);
327 }
328 let provided = header_keys(headers);
329 if provided.iter().any(|key| allowed.contains(key)) {
330 return Ok(());
331 }
332 let resp = (
333 StatusCode::UNAUTHORIZED,
334 Json(json!({ "error": "API key missing or invalid" })),
335 )
336 .into_response();
337 Err(resp)
338}
339
340fn api_key_set(config: &LocalApiConfig) -> HashSet<String> {
341 let mut set = HashSet::new();
342 for key in &config.api_keys {
343 if !key.is_empty() {
344 set.insert(key.clone());
345 }
346 }
347 if let Ok(aliases) = std::env::var("ENVIRONMENT_API_KEY_ALIASES") {
348 for part in aliases.split(',') {
349 let trimmed = part.trim();
350 if !trimmed.is_empty() {
351 set.insert(trimmed.to_string());
352 }
353 }
354 }
355 set
356}
357
358fn header_keys(headers: &HeaderMap) -> Vec<String> {
359 let mut keys = Vec::new();
360 for header in ["x-api-key", "x-api-keys", "authorization"] {
361 if let Some(value) = headers.get(header) {
362 if let Ok(text) = value.to_str() {
363 if header == "authorization" && text.to_lowercase().starts_with("bearer ") {
364 keys.extend(split_keys(&text[7..]));
365 } else {
366 keys.extend(split_keys(text));
367 }
368 }
369 }
370 }
371 keys
372}
373
/// Splits a comma-separated list into trimmed, non-empty owned strings.
fn split_keys(input: &str) -> Vec<String> {
    let mut keys = Vec::new();
    for piece in input.split(',') {
        let piece = piece.trim();
        if piece.is_empty() {
            continue;
        }
        keys.push(piece.to_owned());
    }
    keys
}