// trace_weft_recorder/sqlite.rs

use crate::TraceStore;
use anyhow::Result;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
use std::path::PathBuf;
use trace_weft_core::{EventRecord, SpanRecord};

7pub struct SqliteRecorder {
8    pool: SqlitePool,
9}
10
11impl SqliteRecorder {
12    pub async fn new(db_path: PathBuf) -> Result<Self> {
13        if let Some(parent) = db_path.parent() {
14            tokio::fs::create_dir_all(parent).await?;
15        }
16
17        let db_url = format!("sqlite://{}?mode=rwc", db_path.to_string_lossy());
18
19        let pool = SqlitePoolOptions::new().connect(&db_url).await?;
20
21        Self::from_pool(pool).await
22    }
23
24    pub async fn from_pool(pool: SqlitePool) -> Result<Self> {
25        // Run migrations
26        sqlx::migrate!("./migrations").run(&pool).await?;
27
28        Ok(Self { pool })
29    }
30}
31
32#[async_trait::async_trait]
33impl TraceStore for SqliteRecorder {
34    async fn record_span(&self, span: SpanRecord) -> Result<()> {
35        let trace_id = span.trace_id.0.to_string();
36        let span_id = span.span_id.0.to_string();
37        let parent_span_id = span.parent_span_id.map(|id| id.0.to_string());
38        let run_id = span.run_id.0.to_string();
39        let session_id = span.session_id.map(|id| id.0.to_string());
40        let span_kind = serde_json::to_string(&span.span_kind)?
41            .trim_matches('"')
42            .to_string();
43        let status = serde_json::to_string(&span.status)?
44            .trim_matches('"')
45            .to_string();
46
47        let attributes = serde_json::to_string(&span.attributes)?;
48        let otel_attributes = serde_json::to_string(&span.otel_attributes)?;
49        let openinference_attributes = serde_json::to_string(&span.openinference_attributes)?;
50        let memory_state = span
51            .memory_state
52            .map(|s| serde_json::to_string(&s).unwrap());
53
54        let input_ref = span.input_ref.map(|r| serde_json::to_string(&r).unwrap());
55        let output_ref = span.output_ref.map(|r| serde_json::to_string(&r).unwrap());
56        let retrieved_document_refs = serde_json::to_string(&span.retrieved_document_refs)?;
57        let token_usage = span.token_usage.map(|u| serde_json::to_string(&u).unwrap());
58        let cost_estimate = span
59            .cost_estimate
60            .map(|c| serde_json::to_string(&c).unwrap());
61        let redaction_policy = serde_json::to_string(&span.redaction_policy)?
62            .trim_matches('"')
63            .to_string();
64
65        sqlx::query(
66            r#"
67            INSERT INTO spans (
68                trace_id, span_id, parent_span_id, run_id, session_id, user_id_hash,
69                span_kind, name, start_time, end_time, status, status_message, error_type, error_message_redacted,
70                attributes, otel_attributes, openinference_attributes, memory_state,
71                input_ref, output_ref, prompt_template_id, prompt_version,
72                model_provider, model_name, tool_name, tool_schema_hash, retrieval_query_hash,
73                retrieved_document_refs, token_usage, cost_estimate, latency_ms, retry_count, cache_hit,
74                redaction_policy, schema_version, project_id
75            ) VALUES (
76                ?, ?, ?, ?, ?, ?,
77                ?, ?, ?, ?, ?, ?, ?, ?,
78                ?, ?, ?, ?,
79                ?, ?, ?, ?,
80                ?, ?, ?, ?, ?,
81                ?, ?, ?, ?, ?, ?,
82                ?, ?, ?
83            )
84            -- A span may be recorded twice with the same span_id (e.g. a HITL
85            -- breakpoint: first PendingApproval, then Ok once resolved). Upsert
86            -- so the resolved state replaces the pending row instead of failing
87            -- the primary key. For ordinary single-write spans the conflict
88            -- arm never fires.
89            ON CONFLICT(span_id) DO UPDATE SET
90                trace_id=excluded.trace_id, parent_span_id=excluded.parent_span_id,
91                run_id=excluded.run_id, session_id=excluded.session_id,
92                user_id_hash=excluded.user_id_hash, span_kind=excluded.span_kind,
93                name=excluded.name, start_time=excluded.start_time, end_time=excluded.end_time,
94                status=excluded.status, status_message=excluded.status_message,
95                error_type=excluded.error_type, error_message_redacted=excluded.error_message_redacted,
96                attributes=excluded.attributes, otel_attributes=excluded.otel_attributes,
97                openinference_attributes=excluded.openinference_attributes,
98                memory_state=excluded.memory_state, input_ref=excluded.input_ref,
99                output_ref=excluded.output_ref, prompt_template_id=excluded.prompt_template_id,
100                prompt_version=excluded.prompt_version, model_provider=excluded.model_provider,
101                model_name=excluded.model_name, tool_name=excluded.tool_name,
102                tool_schema_hash=excluded.tool_schema_hash,
103                retrieval_query_hash=excluded.retrieval_query_hash,
104                retrieved_document_refs=excluded.retrieved_document_refs,
105                token_usage=excluded.token_usage, cost_estimate=excluded.cost_estimate,
106                latency_ms=excluded.latency_ms, retry_count=excluded.retry_count,
107                cache_hit=excluded.cache_hit, redaction_policy=excluded.redaction_policy,
108                schema_version=excluded.schema_version, project_id=excluded.project_id
109            "#,
110        )
111        .bind(trace_id).bind(span_id).bind(parent_span_id).bind(run_id).bind(session_id).bind(span.user_id_hash)
112        .bind(span_kind).bind(span.name).bind(span.start_time as i64).bind(span.end_time.map(|t| t as i64)).bind(status).bind(span.status_message).bind(span.error_type).bind(span.error_message_redacted)
113        .bind(attributes).bind(otel_attributes).bind(openinference_attributes).bind(memory_state)
114        .bind(input_ref).bind(output_ref).bind(span.prompt_template_id).bind(span.prompt_version)
115        .bind(span.model_provider).bind(span.model_name).bind(span.tool_name).bind(span.tool_schema_hash).bind(span.retrieval_query_hash)
116        .bind(retrieved_document_refs).bind(token_usage).bind(cost_estimate).bind(span.latency_ms.map(|t| t as i64)).bind(span.retry_count).bind(span.cache_hit)
117        .bind(redaction_policy).bind(span.schema_version).bind(span.project_id)
118        .execute(&self.pool)
119        .await?;
120
121        Ok(())
122    }
123
124    async fn record_event(&self, event: EventRecord) -> Result<()> {
125        let event_kind = serde_json::to_string(&event.event_kind)?
126            .trim_matches('"')
127            .to_string();
128        let attributes = serde_json::to_string(&event.attributes)?;
129
130        sqlx::query(
131            r#"
132            INSERT INTO events (
133                event_id, trace_id, run_id, parent_span_id, seq,
134                event_kind, name, timestamp, attributes, schema_version
135            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
136            "#,
137        )
138        .bind(event.event_id.0.to_string())
139        .bind(event.trace_id.0.to_string())
140        .bind(event.run_id.0.to_string())
141        .bind(event.parent_span_id.map(|id| id.0.to_string()))
142        .bind(event.seq as i64)
143        .bind(event_kind)
144        .bind(event.name)
145        .bind(event.timestamp as i64)
146        .bind(attributes)
147        .bind(event.schema_version)
148        .execute(&self.pool)
149        .await?;
150
151        Ok(())
152    }
153}