1pub mod auth;
2pub mod storage;
3
4use auth::{Auth, AuthConfig};
5use axum::{
6 Json, Router,
7 extract::{Path, State},
8 http::{HeaderMap, HeaderValue, Method, StatusCode, header},
9 routing::{get, post},
10};
11use sqlx::{PgPool, Row, SqlitePool, postgres::PgPoolOptions, sqlite::SqlitePoolOptions};
12use std::net::SocketAddr;
13use std::path::PathBuf;
14use std::sync::Arc;
15use tower_http::cors::{AllowOrigin, CorsLayer};
16use trace_weft_core::SpanRecord;
17use trace_weft_recorder::TraceStore;
18
19#[derive(Clone)]
20pub enum DbPool {
21 Sqlite(SqlitePool),
22 Postgres(PgPool),
23}
24
25#[derive(Clone)]
26pub struct AppState {
27 pub pool: DbPool,
28 pub blob_store: Arc<dyn trace_weft_core::BlobStore>,
29 pub trace_store: Arc<dyn TraceStore>,
30 pub clickhouse: Option<Arc<storage::analytics::ClickHouseAnalytics>>,
31 pub auth: Arc<AuthConfig>,
32}
33
34pub async fn start_server(db_url: &str, port: u16, blob_dir: PathBuf) -> anyhow::Result<()> {
39 start_server_with_shutdown(
40 db_url,
41 port,
42 blob_dir,
43 AuthConfig::from_env(),
44 std::future::pending::<()>(),
45 )
46 .await
47}
48
49pub async fn start_dev_server(db_url: &str, port: u16, blob_dir: PathBuf) -> anyhow::Result<()> {
53 start_server_with_shutdown(
54 db_url,
55 port,
56 blob_dir,
57 AuthConfig::from_env_local_first(),
58 std::future::pending::<()>(),
59 )
60 .await
61}
62
63pub async fn start_server_with_shutdown(
67 db_url: &str,
68 port: u16,
69 blob_dir: PathBuf,
70 auth: AuthConfig,
71 shutdown: impl std::future::Future<Output = ()> + Send + 'static,
72) -> anyhow::Result<()> {
73 let pool = if db_url.starts_with("postgres://") || db_url.starts_with("postgresql://") {
74 let pg_pool = PgPoolOptions::new().connect(db_url).await?;
75 DbPool::Postgres(pg_pool)
76 } else {
77 let url = if db_url.starts_with("sqlite://") {
79 db_url.to_string()
80 } else {
81 if let Some(parent) = std::path::Path::new(db_url).parent() {
82 tokio::fs::create_dir_all(parent).await?;
83 }
84 format!("sqlite://{}?mode=rwc", db_url)
85 };
86 let sq_pool = SqlitePoolOptions::new().connect(&url).await?;
87 DbPool::Sqlite(sq_pool)
88 };
89
90 let blob_store = Arc::new(storage::blob::LocalBlobStore::new(blob_dir));
91
92 let trace_store: Arc<dyn TraceStore> = match &pool {
93 DbPool::Postgres(pg_pool) => Arc::new(storage::postgres::PostgresRecorder {
94 pool: pg_pool.clone(),
95 }),
96 DbPool::Sqlite(sq_pool) => {
97 Arc::new(trace_weft_recorder::sqlite::SqliteRecorder::from_pool(sq_pool.clone()).await?)
98 }
99 };
100
101 let clickhouse = if let Ok(ch_url) = std::env::var("TRACE_WEFT_CH_URL") {
103 tracing::info!("Initializing ClickHouse analytics connected to {}", ch_url);
104 Some(Arc::new(storage::analytics::ClickHouseAnalytics::new(
105 &ch_url, "default", "", "default",
106 )))
107 } else {
108 None
109 };
110
111 let state = AppState {
112 pool,
113 blob_store,
114 trace_store,
115 clickhouse,
116 auth: Arc::new(auth),
117 };
118
119 let app = build_router(state);
120
121 let addr = SocketAddr::from(([127, 0, 0, 1], port));
122 tracing::info!("Server listening on http://{}", addr);
123
124 let listener = tokio::net::TcpListener::bind(addr).await?;
125 axum::serve(listener, app)
126 .with_graceful_shutdown(shutdown)
127 .await?;
128
129 Ok(())
130}
131
132pub fn build_router(state: AppState) -> Router {
134 Router::new()
135 .route("/api/traces", get(list_traces))
136 .route("/api/traces/{trace_id}", get(get_trace))
137 .route("/api/evals", get(list_evals))
138 .route("/api/v1/batch", post(batch_ingest))
139 .route("/api/hitl/pending", get(get_pending_approvals))
140 .route("/api/hitl/resolve", post(resolve_approval))
141 .layer(local_cors())
142 .with_state(state)
143}
144
145fn local_cors() -> CorsLayer {
151 CorsLayer::new()
152 .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
153 .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE])
154 .allow_origin(AllowOrigin::predicate(|origin: &HeaderValue, _req| {
155 origin.to_str().map(is_allowed_origin).unwrap_or(false)
156 }))
157}
158
/// Returns `true` if `origin` (a serialized `Origin` header value) belongs to
/// a UI running on this machine.
///
/// Accepted:
/// * the Tauri webview origins `tauri://localhost` and `http://tauri.localhost`;
/// * `http://localhost` and `http://127.0.0.1`, either bare or followed by
///   `:<port>` where the port is decimal and fits in a `u16`.
///
/// Everything else is rejected, including:
/// * other schemes (`https://localhost:5173`);
/// * lookalike hosts (`http://localhost.evil.com`);
/// * anything after the host that is not a well-formed port (`http://localhost:`,
///   `http://localhost:5173.evil.com`, `http://localhost:+80`,
///   `http://127.0.0.1:8080/`).
fn is_allowed_origin(origin: &str) -> bool {
    const TAURI_ORIGINS: [&str; 2] = ["tauri://localhost", "http://tauri.localhost"];
    const LOOPBACK_ORIGINS: [&str; 2] = ["http://localhost", "http://127.0.0.1"];

    if TAURI_ORIGINS.contains(&origin) {
        return true;
    }

    LOOPBACK_ORIGINS
        .iter()
        .any(|base| match origin.strip_prefix(base) {
            // Bare host: default port.
            Some("") => true,
            // Only `:<digits>` may follow the host. The all-digit check rejects a
            // leading `+` that `u16::from_str` would accept, and the parse rejects
            // empty or out-of-range ports.
            Some(rest) => rest.strip_prefix(':').is_some_and(|port| {
                port.bytes().all(|b| b.is_ascii_digit()) && port.parse::<u16>().is_ok()
            }),
            None => false,
        })
}
169
170fn authorize(state: &AppState, headers: &HeaderMap) -> Result<Auth, StatusCode> {
173 state
174 .auth
175 .authenticate(headers)
176 .ok_or(StatusCode::UNAUTHORIZED)
177}
178
179async fn batch_ingest(
180 headers: HeaderMap,
181 State(state): State<AppState>,
182 Json(mut spans): Json<Vec<SpanRecord>>,
183) -> Result<StatusCode, StatusCode> {
184 let auth = authorize(&state, &headers)?;
185 let project_id = auth.project().map(|p| p.to_string());
188 for span in &mut spans {
189 span.project_id = project_id.clone();
190 }
191
192 tracing::info!(
193 "Received batch of {} spans for project {:?}",
194 spans.len(),
195 project_id
196 );
197
198 for span in &spans {
200 if let Err(e) = state.trace_store.record_span(span.clone()).await {
202 tracing::error!("Failed to record span: {}", e);
203 return Err(StatusCode::INTERNAL_SERVER_ERROR);
204 }
205 }
206
207 if let Some(ch) = &state.clickhouse
209 && let Err(e) = ch.ingest_batch(&spans).await
210 {
211 tracing::warn!("Failed to stream to ClickHouse: {}", e);
212 }
213
214 Ok(StatusCode::ACCEPTED)
215}
216
217fn db_error<E: std::fmt::Display>(e: E) -> StatusCode {
220 tracing::error!("database query failed: {e}");
221 StatusCode::INTERNAL_SERVER_ERROR
222}
223
224fn parse_json_column(raw: &str) -> Result<serde_json::Value, StatusCode> {
228 serde_json::from_str(raw).map_err(|e| {
229 tracing::error!("corrupt JSON in spans column: {e}");
230 StatusCode::INTERNAL_SERVER_ERROR
231 })
232}
233
234fn parse_opt_json_column(raw: Option<String>) -> Result<serde_json::Value, StatusCode> {
236 match raw {
237 Some(s) => parse_json_column(&s),
238 None => Ok(serde_json::Value::Null),
239 }
240}
241
/// Builds the JSON summary for one row produced by `LIST_TRACES_SQL_*`.
///
/// This is a macro rather than a function because it is applied to both
/// SQLite and Postgres rows.
///
/// It must be expanded inside a function returning `Result<_, StatusCode>`.
/// A missing column, or a value that does not decode to the expected type
/// (including an unexpected `NULL`), is logged through `db_error` and becomes
/// a `500`. `Row::get` would panic inside the handler instead.
macro_rules! trace_summary_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let end_time: Option<i64> = row.try_get("end_time").map_err(db_error)?;
        let span_count: i64 = row.try_get("span_count").map_err(db_error)?;
        let has_error: i64 = row.try_get("has_error").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "run_id": run_id,
            "start_time": start_time,
            "end_time": end_time,
            "span_count": span_count,
            // `has_error` is 1 when any span of the trace has status 'error'.
            "status": if has_error != 0 { "error" } else { "ok" },
        })
    }};
}
268
/// Builds the JSON object for one evaluator span row produced by
/// `LIST_EVALS_SQL_*`.
///
/// This is a macro so it can accept both SQLite and Postgres rows. It must be
/// expanded inside a function returning `Result<_, StatusCode>`.
///
/// Each of these becomes a logged `500` instead of a panic:
/// * a missing column, or a value that does not decode to the expected type
///   (reported via `db_error`);
/// * an `attributes` value that is not valid JSON (reported via
///   `parse_json_column`).
macro_rules! eval_row_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let span_id: String = row.try_get("span_id").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let status: String = row.try_get("status").map_err(db_error)?;
        // Stored as JSON text; returned to the client as structured JSON.
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "span_id": span_id,
            "name": name,
            "start_time": start_time,
            "status": status,
            "attributes": parse_json_column(&attributes)?,
        })
    }};
}
289
/// Builds the full JSON object for one span row produced by `GET_TRACE_SQL_*`.
///
/// This is a macro so it can accept both SQLite and Postgres rows. It must be
/// expanded inside a function returning `Result<_, StatusCode>`.
///
/// Each of these becomes a logged `500` instead of a panic:
/// * a missing column, or a value that does not decode to the expected type
///   (reported via `db_error`);
/// * corrupt JSON in `attributes`, `input_ref` or `output_ref`.
///
/// A `NULL` `input_ref` / `output_ref` is emitted as JSON `null`.
macro_rules! span_detail_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let span_id: String = row.try_get("span_id").map_err(db_error)?;
        let parent_span_id: Option<String> = row.try_get("parent_span_id").map_err(db_error)?;
        let span_kind: String = row.try_get("span_kind").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let end_time: Option<i64> = row.try_get("end_time").map_err(db_error)?;
        let status: String = row.try_get("status").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        let latency_ms: Option<i64> = row.try_get("latency_ms").map_err(db_error)?;
        let input_ref: Option<String> = row.try_get("input_ref").map_err(db_error)?;
        let output_ref: Option<String> = row.try_get("output_ref").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "span_id": span_id,
            "parent_span_id": parent_span_id,
            "span_kind": span_kind,
            "name": name,
            "start_time": start_time,
            "end_time": end_time,
            "status": status,
            "attributes": parse_json_column(&attributes)?,
            "latency_ms": latency_ms,
            "input_ref": parse_opt_json_column(input_ref)?,
            "output_ref": parse_opt_json_column(output_ref)?,
        })
    }};
}
322
/// Trace summaries for the trace list (SQLite flavour).
///
/// Returns one row per `trace_id` with:
/// * `MIN(run_id)` and `MIN(start_time)`;
/// * `MAX(end_time)` and the span count;
/// * a `has_error` flag that is 1 if any span has `status = 'error'`.
///
/// Rows are ordered newest first and capped at 50. Both positional `?`
/// placeholders receive the caller's project, so it must be bound twice.
/// A `NULL` project disables the project filter.
const LIST_TRACES_SQL_SQLITE: &str = r#"
    SELECT trace_id, MIN(run_id) AS run_id, MIN(start_time) AS start_time,
    MAX(end_time) AS end_time, COUNT(span_id) AS span_count,
    CAST(MAX(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS BIGINT) AS has_error
    FROM spans
    WHERE (project_id = ? OR ? IS NULL)
    GROUP BY trace_id
    ORDER BY start_time DESC
    LIMIT 50
"#;

/// Postgres flavour of `LIST_TRACES_SQL_SQLITE`.
///
/// `$1` is the project filter (`NULL` = all projects). It is referenced twice
/// but bound once.
const LIST_TRACES_SQL_PG: &str = r#"
    SELECT trace_id, MIN(run_id) AS run_id, MIN(start_time) AS start_time,
    MAX(end_time) AS end_time, COUNT(span_id) AS span_count,
    CAST(MAX(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS BIGINT) AS has_error
    FROM spans
    WHERE (project_id = $1 OR $1 IS NULL)
    GROUP BY trace_id
    ORDER BY start_time DESC
    LIMIT 50
"#;
354
/// Evaluator spans for the evals list (SQLite flavour).
///
/// Selects spans whose `span_kind` is `evaluator` or `Evaluator`, newest
/// first, capped at 50. Both `?` placeholders receive the caller's project
/// (`NULL` = no project filter), so it must be bound twice.
const LIST_EVALS_SQL_SQLITE: &str = r#"
    SELECT trace_id, span_id, name, start_time, status, attributes
    FROM spans
    WHERE (span_kind = 'evaluator' OR span_kind = 'Evaluator')
    AND (project_id = ? OR ? IS NULL)
    ORDER BY start_time DESC
    LIMIT 50
"#;

/// Postgres flavour of `LIST_EVALS_SQL_SQLITE`; `$1` is the project filter
/// (`NULL` = all projects) and is bound once.
const LIST_EVALS_SQL_PG: &str = r#"
    SELECT trace_id, span_id, name, start_time, status, attributes
    FROM spans
    WHERE (span_kind = 'evaluator' OR span_kind = 'Evaluator')
    AND (project_id = $1 OR $1 IS NULL)
    ORDER BY start_time DESC
    LIMIT 50
"#;
372
/// All spans of one trace in ascending `start_time` (SQLite flavour).
///
/// Parameters are the trace id, then the caller's project twice (`NULL` = no
/// project filter).
const GET_TRACE_SQL_SQLITE: &str = "SELECT * FROM spans WHERE trace_id = ? AND (project_id = ? OR ? IS NULL) ORDER BY start_time ASC";

/// Postgres flavour of `GET_TRACE_SQL_SQLITE`: `$1` is the trace id and `$2`
/// the project filter (`NULL` = all projects).
const GET_TRACE_SQL_PG: &str = "SELECT * FROM spans WHERE trace_id = $1 AND (project_id = $2 OR $2 IS NULL) ORDER BY start_time ASC";
376
377async fn list_traces(
378 headers: HeaderMap,
379 State(state): State<AppState>,
380) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
381 let project = authorize(&state, &headers)?.project().map(str::to_string);
382 let mut traces = Vec::new();
383 match &state.pool {
384 DbPool::Sqlite(pool) => {
385 let rows = sqlx::query(LIST_TRACES_SQL_SQLITE)
386 .bind(project.clone())
387 .bind(project)
388 .fetch_all(pool)
389 .await
390 .map_err(db_error)?;
391 for row in &rows {
392 traces.push(trace_summary_json!(row));
393 }
394 }
395 DbPool::Postgres(pool) => {
396 let rows = sqlx::query(LIST_TRACES_SQL_PG)
397 .bind(project)
398 .fetch_all(pool)
399 .await
400 .map_err(db_error)?;
401 for row in &rows {
402 traces.push(trace_summary_json!(row));
403 }
404 }
405 }
406 Ok(Json(traces))
407}
408
409async fn list_evals(
410 headers: HeaderMap,
411 State(state): State<AppState>,
412) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
413 let project = authorize(&state, &headers)?.project().map(str::to_string);
414 let mut evals = Vec::new();
415 match &state.pool {
416 DbPool::Sqlite(pool) => {
417 let rows = sqlx::query(LIST_EVALS_SQL_SQLITE)
418 .bind(project.clone())
419 .bind(project)
420 .fetch_all(pool)
421 .await
422 .map_err(db_error)?;
423 for row in &rows {
424 evals.push(eval_row_json!(row));
425 }
426 }
427 DbPool::Postgres(pool) => {
428 let rows = sqlx::query(LIST_EVALS_SQL_PG)
429 .bind(project)
430 .fetch_all(pool)
431 .await
432 .map_err(db_error)?;
433 for row in &rows {
434 evals.push(eval_row_json!(row));
435 }
436 }
437 }
438 Ok(Json(evals))
439}
440
441async fn get_trace(
442 Path(trace_id): Path<String>,
443 headers: HeaderMap,
444 State(state): State<AppState>,
445) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
446 let project = authorize(&state, &headers)?.project().map(str::to_string);
447 let mut spans = Vec::new();
448 match &state.pool {
449 DbPool::Sqlite(pool) => {
450 let rows = sqlx::query(GET_TRACE_SQL_SQLITE)
451 .bind(trace_id)
452 .bind(project.clone())
453 .bind(project)
454 .fetch_all(pool)
455 .await
456 .map_err(db_error)?;
457 for row in &rows {
458 spans.push(span_detail_json!(row));
459 }
460 }
461 DbPool::Postgres(pool) => {
462 let rows = sqlx::query(GET_TRACE_SQL_PG)
463 .bind(trace_id)
464 .bind(project)
465 .fetch_all(pool)
466 .await
467 .map_err(db_error)?;
468 for row in &rows {
469 spans.push(span_detail_json!(row));
470 }
471 }
472 }
473 Ok(Json(spans))
474}
475
476use serde::Deserialize;
477use trace_weft::hitl::HitlResponse;
478
479async fn get_pending_approvals(
480 headers: HeaderMap,
481 State(state): State<AppState>,
482) -> Result<Json<Vec<String>>, StatusCode> {
483 authorize(&state, &headers)?;
484 Ok(Json(trace_weft::hitl::get_pending_approvals()))
485}
486
487#[derive(Deserialize)]
488struct ResolveRequest {
489 span_id: String,
490 action: String,
491 value: Option<serde_json::Value>,
492 reason: Option<String>,
493}
494
495async fn resolve_approval(
496 headers: HeaderMap,
497 State(state): State<AppState>,
498 Json(req): Json<ResolveRequest>,
499) -> Result<StatusCode, StatusCode> {
500 authorize(&state, &headers)?;
501 let response = if req.action == "approve" {
502 HitlResponse::Approved(req.value.unwrap_or(serde_json::json!({})))
503 } else {
504 HitlResponse::Rejected(req.reason.unwrap_or_else(|| "Rejected by user".to_string()))
505 };
506
507 if trace_weft::hitl::resolve_approval(&req.span_id, response).is_ok() {
508 Ok(StatusCode::OK)
509 } else {
510 Err(StatusCode::NOT_FOUND)
511 }
512}
513
#[cfg(test)]
mod cors_tests {
    use super::is_allowed_origin;

    /// Origins a local browser UI (any port) or the Tauri shell can present.
    const ALLOWED: &[&str] = &[
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:3000",
        "http://localhost",
        "http://127.0.0.1",
        "tauri://localhost",
        "http://tauri.localhost",
    ];

    /// Remote hosts, hosts that merely start with a loopback name, the wrong
    /// scheme, and the opaque `null` origin.
    const REJECTED: &[&str] = &[
        "https://evil.example.com",
        "http://localhost.evil.com",
        "http://127.0.0.1.evil.com",
        "https://localhost:5173",
        "http://evil.com",
        "null",
    ];

    #[test]
    fn allows_local_ui_and_tauri_origins() {
        for &origin in ALLOWED {
            assert!(is_allowed_origin(origin), "{origin} should be allowed");
        }
    }

    #[test]
    fn rejects_external_and_lookalike_origins() {
        for &origin in REJECTED {
            assert!(!is_allowed_origin(origin), "{origin} should be rejected");
        }
    }
}