Skip to main content

trace_weft_server/
lib.rs

1pub mod auth;
2pub mod storage;
3
4use auth::{Auth, AuthConfig};
5use axum::{
6    Json, Router,
7    extract::{Path, State},
8    http::{HeaderMap, HeaderValue, Method, StatusCode, header},
9    routing::{get, post},
10};
11use sqlx::{PgPool, Row, SqlitePool, postgres::PgPoolOptions, sqlite::SqlitePoolOptions};
12use std::net::SocketAddr;
13use std::path::PathBuf;
14use std::sync::Arc;
15use tower_http::cors::{AllowOrigin, CorsLayer};
16use trace_weft_core::SpanRecord;
17use trace_weft_recorder::TraceStore;
18
/// Connection pool for whichever SQL backend the server was started against.
///
/// The variant is picked once at startup from the database URL; the query
/// handlers `match` on it to run the SQL text written for that backend.
#[derive(Clone)]
pub enum DbPool {
    /// Pool over a SQLite database.
    Sqlite(SqlitePool),
    /// Pool over a PostgreSQL server.
    Postgres(PgPool),
}
24
/// Shared state handed to every request handler through axum's `State`
/// extractor.
#[derive(Clone)]
pub struct AppState {
    /// SQL pool queried directly by the read handlers.
    pub pool: DbPool,
    /// Blob storage backend.
    pub blob_store: Arc<dyn trace_weft_core::BlobStore>,
    /// Span sink that batch ingest writes each received span to.
    pub trace_store: Arc<dyn TraceStore>,
    /// Optional ClickHouse analytics sink; `None` disables the analytics
    /// step of batch ingest.
    pub clickhouse: Option<Arc<storage::analytics::ClickHouseAnalytics>>,
    /// Authentication configuration the handlers consult via `authorize`.
    pub auth: Arc<AuthConfig>,
}
33
34/// Start the server with the **production-secure** auth default
35/// ([`AuthConfig::from_env`]): unauthenticated requests are rejected unless
36/// `TRACE_WEFT_API_KEYS`/`TRACE_WEFT_DEV_MODE` are configured. Runs until the
37/// process ends.
38pub async fn start_server(db_url: &str, port: u16, blob_dir: PathBuf) -> anyhow::Result<()> {
39    start_server_with_shutdown(
40        db_url,
41        port,
42        blob_dir,
43        AuthConfig::from_env(),
44        std::future::pending::<()>(),
45    )
46    .await
47}
48
49/// Start a **local-first** dev server: the auth bypass defaults on when no keys
50/// are configured (see [`AuthConfig::from_env_local_first`]), so the local UI
51/// works without keys. Used by `trace-weft dev`.
52pub async fn start_dev_server(db_url: &str, port: u16, blob_dir: PathBuf) -> anyhow::Result<()> {
53    start_server_with_shutdown(
54        db_url,
55        port,
56        blob_dir,
57        AuthConfig::from_env_local_first(),
58        std::future::pending::<()>(),
59    )
60    .await
61}
62
63/// Start the server with an explicit [`AuthConfig`], stopping gracefully when
64/// `shutdown` resolves. Used by the desktop app to start/stop the embedded
65/// server on demand and to drain it cleanly on app exit.
66pub async fn start_server_with_shutdown(
67    db_url: &str,
68    port: u16,
69    blob_dir: PathBuf,
70    auth: AuthConfig,
71    shutdown: impl std::future::Future<Output = ()> + Send + 'static,
72) -> anyhow::Result<()> {
73    let pool = if db_url.starts_with("postgres://") || db_url.starts_with("postgresql://") {
74        let pg_pool = PgPoolOptions::new().connect(db_url).await?;
75        DbPool::Postgres(pg_pool)
76    } else {
77        // Assume sqlite file path or sqlite:// url
78        let url = if db_url.starts_with("sqlite://") {
79            db_url.to_string()
80        } else {
81            if let Some(parent) = std::path::Path::new(db_url).parent() {
82                tokio::fs::create_dir_all(parent).await?;
83            }
84            format!("sqlite://{}?mode=rwc", db_url)
85        };
86        let sq_pool = SqlitePoolOptions::new().connect(&url).await?;
87        DbPool::Sqlite(sq_pool)
88    };
89
90    let blob_store = Arc::new(storage::blob::LocalBlobStore::new(blob_dir));
91
92    let trace_store: Arc<dyn TraceStore> = match &pool {
93        DbPool::Postgres(pg_pool) => Arc::new(storage::postgres::PostgresRecorder {
94            pool: pg_pool.clone(),
95        }),
96        DbPool::Sqlite(sq_pool) => {
97            Arc::new(trace_weft_recorder::sqlite::SqliteRecorder::from_pool(sq_pool.clone()).await?)
98        }
99    };
100
101    // Enterprise Analytics (Stubbed connection if env var is present)
102    let clickhouse = if let Ok(ch_url) = std::env::var("TRACE_WEFT_CH_URL") {
103        tracing::info!("Initializing ClickHouse analytics connected to {}", ch_url);
104        Some(Arc::new(storage::analytics::ClickHouseAnalytics::new(
105            &ch_url, "default", "", "default",
106        )))
107    } else {
108        None
109    };
110
111    let state = AppState {
112        pool,
113        blob_store,
114        trace_store,
115        clickhouse,
116        auth: Arc::new(auth),
117    };
118
119    let app = build_router(state);
120
121    let addr = SocketAddr::from(([127, 0, 0, 1], port));
122    tracing::info!("Server listening on http://{}", addr);
123
124    let listener = tokio::net::TcpListener::bind(addr).await?;
125    axum::serve(listener, app)
126        .with_graceful_shutdown(shutdown)
127        .await?;
128
129    Ok(())
130}
131
132/// Build the TraceWeft API router over the given application state.
133pub fn build_router(state: AppState) -> Router {
134    Router::new()
135        .route("/api/traces", get(list_traces))
136        .route("/api/traces/{trace_id}", get(get_trace))
137        .route("/api/evals", get(list_evals))
138        .route("/api/v1/batch", post(batch_ingest))
139        .route("/api/hitl/pending", get(get_pending_approvals))
140        .route("/api/hitl/resolve", post(resolve_approval))
141        .layer(local_cors())
142        .with_state(state)
143}
144
145/// CORS for a local-first server: only the local dev UI and the desktop webview
146/// may read API responses. A permissive policy would let any website the user
147/// visits script `127.0.0.1:<port>` and exfiltrate locally-stored prompts and
148/// tool outputs (and, for JSON `POST`s, drive HITL/ingest via CSRF). Restricting
149/// the allowed origins makes the browser block both.
150fn local_cors() -> CorsLayer {
151    CorsLayer::new()
152        .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
153        .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE])
154        .allow_origin(AllowOrigin::predicate(|origin: &HeaderValue, _req| {
155            origin.to_str().map(is_allowed_origin).unwrap_or(false)
156        }))
157}
158
/// Decide whether a request `Origin` may read API responses.
///
/// Accepted origins are the Tauri webview origins (`tauri://localhost`,
/// `http://tauri.localhost`) and plain-HTTP loopback (`http://localhost`,
/// `http://127.0.0.1`), optionally followed by `:<port>` for the Vite dev
/// server and direct browser access. Everything else is rejected.
///
/// The port must be a real port number: a bare prefix match would also accept
/// strings such as `http://localhost:5173.evil.com` or
/// `http://localhost:@evil.com`, so the text after the colon is validated.
fn is_allowed_origin(origin: &str) -> bool {
    if matches!(origin, "tauri://localhost" | "http://tauri.localhost") {
        return true;
    }
    ["http://localhost", "http://127.0.0.1"]
        .iter()
        .any(|base| match origin.strip_prefix(*base) {
            // Exact match: default port.
            Some("") => true,
            // Explicit port: must be `:<digits>` and nothing more.
            Some(rest) => rest.strip_prefix(':').is_some_and(is_port_number),
            None => false,
        })
}

/// True when `text` is a non-empty, digits-only decimal that fits in a `u16`.
/// The digit check is needed because `u16::from_str` also accepts a leading `+`.
fn is_port_number(text: &str) -> bool {
    !text.is_empty() && text.bytes().all(|b| b.is_ascii_digit()) && text.parse::<u16>().is_ok()
}
169
170/// Resolve the request's API key to a tenant, or `401` when none is valid and
171/// the dev bypass is off.
172fn authorize(state: &AppState, headers: &HeaderMap) -> Result<Auth, StatusCode> {
173    state
174        .auth
175        .authenticate(headers)
176        .ok_or(StatusCode::UNAUTHORIZED)
177}
178
179async fn batch_ingest(
180    headers: HeaderMap,
181    State(state): State<AppState>,
182    Json(mut spans): Json<Vec<SpanRecord>>,
183) -> Result<StatusCode, StatusCode> {
184    let auth = authorize(&state, &headers)?;
185    // The server is authoritative on tenancy: stamp the authenticated project
186    // onto every span so a client cannot assert someone else's project_id.
187    let project_id = auth.project().map(|p| p.to_string());
188    for span in &mut spans {
189        span.project_id = project_id.clone();
190    }
191
192    tracing::info!(
193        "Received batch of {} spans for project {:?}",
194        spans.len(),
195        project_id
196    );
197
198    // 1. Ingest metadata into Postgres
199    for span in &spans {
200        // In a real app, this should be a bulk insert
201        if let Err(e) = state.trace_store.record_span(span.clone()).await {
202            tracing::error!("Failed to record span: {}", e);
203            return Err(StatusCode::INTERNAL_SERVER_ERROR);
204        }
205    }
206
207    // 2. Stream to ClickHouse for analytics
208    if let Some(ch) = &state.clickhouse
209        && let Err(e) = ch.ingest_batch(&spans).await
210    {
211        tracing::warn!("Failed to stream to ClickHouse: {}", e);
212    }
213
214    Ok(StatusCode::ACCEPTED)
215}
216
217/// Log a database error and surface it as a 500. Used by every query handler so
218/// failures are recorded rather than silently flattened to an empty body.
219fn db_error<E: std::fmt::Display>(e: E) -> StatusCode {
220    tracing::error!("database query failed: {e}");
221    StatusCode::INTERNAL_SERVER_ERROR
222}
223
224/// Decode a JSON column we wrote ourselves. A parse failure means the row is
225/// corrupt, so we surface a 500 instead of masking it with an empty object —
226/// silently substituting `{}` would hide data loss from the caller.
227fn parse_json_column(raw: &str) -> Result<serde_json::Value, StatusCode> {
228    serde_json::from_str(raw).map_err(|e| {
229        tracing::error!("corrupt JSON in spans column: {e}");
230        StatusCode::INTERNAL_SERVER_ERROR
231    })
232}
233
234/// Decode a nullable JSON column, preserving SQL `NULL` as JSON `null`.
235fn parse_opt_json_column(raw: Option<String>) -> Result<serde_json::Value, StatusCode> {
236    match raw {
237        Some(s) => parse_json_column(&s),
238        None => Ok(serde_json::Value::Null),
239    }
240}
241
242// The SQLite and Postgres `spans` tables share an identical column layout, so a
243// single row shape maps to JSON for either backend. These macros expand the
244// same extraction against `SqliteRow` or `PgRow` (the `?` inside propagates to
245// the calling handler), keeping the two dialects from drifting apart.
246
/// One row of the trace-summary aggregate (see `list_traces`), as JSON.
///
/// `$row` is a reference to a `SqliteRow` or `PgRow`. Columns are read with
/// `try_get`, so a missing column or a value that does not decode to the
/// expected type (for example a `NULL` `run_id`) is logged through `db_error`
/// and propagated with `?` as a `500` from the calling handler; `Row::get`
/// would panic in the middle of the request instead.
macro_rules! trace_summary_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let run_id: String = row.try_get("run_id").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let end_time: Option<i64> = row.try_get("end_time").map_err(db_error)?;
        let span_count: i64 = row.try_get("span_count").map_err(db_error)?;
        let has_error: i64 = row.try_get("has_error").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "run_id": run_id,
            "start_time": start_time,
            "end_time": end_time,
            "span_count": span_count,
            // A trace is errored if any of its spans errored, otherwise ok.
            "status": if has_error != 0 { "error" } else { "ok" },
        })
    }};
}
268
/// One evaluator span row (see `list_evals`), as JSON.
///
/// `$row` is a reference to a `SqliteRow` or `PgRow`. Columns are read with
/// `try_get`, so a missing column or undecodable value is logged through
/// `db_error` and propagated with `?` as a `500` from the calling handler
/// rather than panicking as `Row::get` would. The `attributes` column is
/// stored as JSON text and parsed; corrupt JSON is likewise a `500`.
macro_rules! eval_row_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let span_id: String = row.try_get("span_id").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let status: String = row.try_get("status").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "span_id": span_id,
            "name": name,
            "start_time": start_time,
            "status": status,
            "attributes": parse_json_column(&attributes)?,
        })
    }};
}
289
/// One full span row (see `get_trace`), as JSON.
///
/// `$row` is a reference to a `SqliteRow` or `PgRow`. Columns are read with
/// `try_get`, so a missing column or undecodable value is logged through
/// `db_error` and propagated with `?` as a `500` from the calling handler
/// rather than panicking as `Row::get` would. `attributes` is parsed as JSON;
/// the nullable `input_ref` / `output_ref` columns are parsed when present and
/// emitted as JSON `null` when SQL `NULL`.
macro_rules! span_detail_json {
    ($row:expr) => {{
        let row = $row;
        let trace_id: String = row.try_get("trace_id").map_err(db_error)?;
        let span_id: String = row.try_get("span_id").map_err(db_error)?;
        let parent_span_id: Option<String> = row.try_get("parent_span_id").map_err(db_error)?;
        let span_kind: String = row.try_get("span_kind").map_err(db_error)?;
        let name: String = row.try_get("name").map_err(db_error)?;
        let start_time: i64 = row.try_get("start_time").map_err(db_error)?;
        let end_time: Option<i64> = row.try_get("end_time").map_err(db_error)?;
        let status: String = row.try_get("status").map_err(db_error)?;
        let attributes: String = row.try_get("attributes").map_err(db_error)?;
        let latency_ms: Option<i64> = row.try_get("latency_ms").map_err(db_error)?;
        let input_ref: Option<String> = row.try_get("input_ref").map_err(db_error)?;
        let output_ref: Option<String> = row.try_get("output_ref").map_err(db_error)?;
        serde_json::json!({
            "trace_id": trace_id,
            "span_id": span_id,
            "parent_span_id": parent_span_id,
            "span_kind": span_kind,
            "name": name,
            "start_time": start_time,
            "end_time": end_time,
            "status": status,
            "attributes": parse_json_column(&attributes)?,
            "latency_ms": latency_ms,
            "input_ref": parse_opt_json_column(input_ref)?,
            "output_ref": parse_opt_json_column(output_ref)?,
        })
    }};
}
322
// Project scoping: each query filters on `project_id` against the bound
// `project` value. A real tenant binds its project id; the dev bypass binds
// SQL `NULL`, and the `OR <param> IS NULL` arm then matches every row so
// local-first runs see all traces. Postgres reuses one `$1`; SQLite repeats the
// positional `?`, so the project value is bound twice there.
//
// The aggregate is portable: every span of a trace shares a run_id (so
// MIN(run_id) is deterministic), the error rollup is CAST to BIGINT so both
// engines decode it as i64, and only grouped/aggregated columns are selected so
// Postgres (which rejects bare columns under GROUP BY) is happy.

/// SQLite: one summary row per trace, newest first, capped at 50.
/// Binds, in order: project, project.
const LIST_TRACES_SQL_SQLITE: &str = r#"
    SELECT trace_id, MIN(run_id) AS run_id, MIN(start_time) AS start_time,
           MAX(end_time) AS end_time, COUNT(span_id) AS span_count,
           CAST(MAX(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS BIGINT) AS has_error
    FROM spans
    WHERE (project_id = ? OR ? IS NULL)
    GROUP BY trace_id
    ORDER BY start_time DESC
    LIMIT 50
"#;

/// Postgres: same aggregate as the SQLite variant. Binds: `$1` = project.
const LIST_TRACES_SQL_PG: &str = r#"
    SELECT trace_id, MIN(run_id) AS run_id, MIN(start_time) AS start_time,
           MAX(end_time) AS end_time, COUNT(span_id) AS span_count,
           CAST(MAX(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS BIGINT) AS has_error
    FROM spans
    WHERE (project_id = $1 OR $1 IS NULL)
    GROUP BY trace_id
    ORDER BY start_time DESC
    LIMIT 50
"#;

/// SQLite: the 50 newest evaluator spans (either spelling of the span kind).
/// Binds, in order: project, project.
const LIST_EVALS_SQL_SQLITE: &str = r#"
    SELECT trace_id, span_id, name, start_time, status, attributes
    FROM spans
    WHERE (span_kind = 'evaluator' OR span_kind = 'Evaluator')
      AND (project_id = ? OR ? IS NULL)
    ORDER BY start_time DESC
    LIMIT 50
"#;

/// Postgres: same selection as the SQLite variant. Binds: `$1` = project.
const LIST_EVALS_SQL_PG: &str = r#"
    SELECT trace_id, span_id, name, start_time, status, attributes
    FROM spans
    WHERE (span_kind = 'evaluator' OR span_kind = 'Evaluator')
      AND (project_id = $1 OR $1 IS NULL)
    ORDER BY start_time DESC
    LIMIT 50
"#;

/// SQLite: every span of one trace, oldest first.
/// Binds, in order: trace id, project, project.
const GET_TRACE_SQL_SQLITE: &str = "SELECT * FROM spans WHERE trace_id = ? AND (project_id = ? OR ? IS NULL) ORDER BY start_time ASC";

/// Postgres: every span of one trace, oldest first.
/// Binds: `$1` = trace id, `$2` = project.
const GET_TRACE_SQL_PG: &str = "SELECT * FROM spans WHERE trace_id = $1 AND (project_id = $2 OR $2 IS NULL) ORDER BY start_time ASC";
376
377async fn list_traces(
378    headers: HeaderMap,
379    State(state): State<AppState>,
380) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
381    let project = authorize(&state, &headers)?.project().map(str::to_string);
382    let mut traces = Vec::new();
383    match &state.pool {
384        DbPool::Sqlite(pool) => {
385            let rows = sqlx::query(LIST_TRACES_SQL_SQLITE)
386                .bind(project.clone())
387                .bind(project)
388                .fetch_all(pool)
389                .await
390                .map_err(db_error)?;
391            for row in &rows {
392                traces.push(trace_summary_json!(row));
393            }
394        }
395        DbPool::Postgres(pool) => {
396            let rows = sqlx::query(LIST_TRACES_SQL_PG)
397                .bind(project)
398                .fetch_all(pool)
399                .await
400                .map_err(db_error)?;
401            for row in &rows {
402                traces.push(trace_summary_json!(row));
403            }
404        }
405    }
406    Ok(Json(traces))
407}
408
409async fn list_evals(
410    headers: HeaderMap,
411    State(state): State<AppState>,
412) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
413    let project = authorize(&state, &headers)?.project().map(str::to_string);
414    let mut evals = Vec::new();
415    match &state.pool {
416        DbPool::Sqlite(pool) => {
417            let rows = sqlx::query(LIST_EVALS_SQL_SQLITE)
418                .bind(project.clone())
419                .bind(project)
420                .fetch_all(pool)
421                .await
422                .map_err(db_error)?;
423            for row in &rows {
424                evals.push(eval_row_json!(row));
425            }
426        }
427        DbPool::Postgres(pool) => {
428            let rows = sqlx::query(LIST_EVALS_SQL_PG)
429                .bind(project)
430                .fetch_all(pool)
431                .await
432                .map_err(db_error)?;
433            for row in &rows {
434                evals.push(eval_row_json!(row));
435            }
436        }
437    }
438    Ok(Json(evals))
439}
440
441async fn get_trace(
442    Path(trace_id): Path<String>,
443    headers: HeaderMap,
444    State(state): State<AppState>,
445) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
446    let project = authorize(&state, &headers)?.project().map(str::to_string);
447    let mut spans = Vec::new();
448    match &state.pool {
449        DbPool::Sqlite(pool) => {
450            let rows = sqlx::query(GET_TRACE_SQL_SQLITE)
451                .bind(trace_id)
452                .bind(project.clone())
453                .bind(project)
454                .fetch_all(pool)
455                .await
456                .map_err(db_error)?;
457            for row in &rows {
458                spans.push(span_detail_json!(row));
459            }
460        }
461        DbPool::Postgres(pool) => {
462            let rows = sqlx::query(GET_TRACE_SQL_PG)
463                .bind(trace_id)
464                .bind(project)
465                .fetch_all(pool)
466                .await
467                .map_err(db_error)?;
468            for row in &rows {
469                spans.push(span_detail_json!(row));
470            }
471        }
472    }
473    Ok(Json(spans))
474}
475
476use serde::Deserialize;
477use trace_weft::hitl::HitlResponse;
478
479async fn get_pending_approvals(
480    headers: HeaderMap,
481    State(state): State<AppState>,
482) -> Result<Json<Vec<String>>, StatusCode> {
483    authorize(&state, &headers)?;
484    Ok(Json(trace_weft::hitl::get_pending_approvals()))
485}
486
/// JSON body accepted by `POST /api/hitl/resolve`.
#[derive(Deserialize)]
struct ResolveRequest {
    /// Identifier of the pending approval to resolve.
    span_id: String,
    /// `"approve"` approves; any other string is handled as a rejection.
    action: String,
    /// Payload attached to an approval; defaults to `{}` when absent.
    value: Option<serde_json::Value>,
    /// Reason attached to a rejection; defaults to "Rejected by user".
    reason: Option<String>,
}
494
495async fn resolve_approval(
496    headers: HeaderMap,
497    State(state): State<AppState>,
498    Json(req): Json<ResolveRequest>,
499) -> Result<StatusCode, StatusCode> {
500    authorize(&state, &headers)?;
501    let response = if req.action == "approve" {
502        HitlResponse::Approved(req.value.unwrap_or(serde_json::json!({})))
503    } else {
504        HitlResponse::Rejected(req.reason.unwrap_or_else(|| "Rejected by user".to_string()))
505    };
506
507    if trace_weft::hitl::resolve_approval(&req.span_id, response).is_ok() {
508        Ok(StatusCode::OK)
509    } else {
510        Err(StatusCode::NOT_FOUND)
511    }
512}
513
/// Table-driven checks for the CORS origin allow-list.
#[cfg(test)]
mod cors_tests {
    use super::is_allowed_origin;

    /// Origins the local UI and the desktop webview actually use.
    const ALLOWED: &[&str] = &[
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:3000",
        "http://localhost",
        "http://127.0.0.1",
        "tauri://localhost",
        "http://tauri.localhost",
    ];

    /// External sites, look-alike hosts, the wrong scheme, and the opaque
    /// `null` origin.
    const REJECTED: &[&str] = &[
        "https://evil.example.com",
        "http://localhost.evil.com",
        "http://127.0.0.1.evil.com",
        "https://localhost:5173",
        "http://evil.com",
        "null",
    ];

    #[test]
    fn allows_local_ui_and_tauri_origins() {
        for &origin in ALLOWED {
            assert!(is_allowed_origin(origin), "{origin} should be allowed");
        }
    }

    #[test]
    fn rejects_external_and_lookalike_origins() {
        for &origin in REJECTED {
            assert!(!is_allowed_origin(origin), "{origin} should be rejected");
        }
    }
}