1use crate::TraceStore;
2use anyhow::Result;
3use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
4use std::path::PathBuf;
5use trace_weft_core::{EventRecord, SpanRecord};
6
7pub struct SqliteRecorder {
8 pool: SqlitePool,
9}
10
11impl SqliteRecorder {
12 pub async fn new(db_path: PathBuf) -> Result<Self> {
13 if let Some(parent) = db_path.parent() {
14 tokio::fs::create_dir_all(parent).await?;
15 }
16
17 let db_url = format!("sqlite://{}?mode=rwc", db_path.to_string_lossy());
18
19 let pool = SqlitePoolOptions::new().connect(&db_url).await?;
20
21 Self::from_pool(pool).await
22 }
23
24 pub async fn from_pool(pool: SqlitePool) -> Result<Self> {
25 sqlx::migrate!("./migrations").run(&pool).await?;
27
28 Ok(Self { pool })
29 }
30}
31
32#[async_trait::async_trait]
33impl TraceStore for SqliteRecorder {
34 async fn record_span(&self, span: SpanRecord) -> Result<()> {
35 let trace_id = span.trace_id.0.to_string();
36 let span_id = span.span_id.0.to_string();
37 let parent_span_id = span.parent_span_id.map(|id| id.0.to_string());
38 let run_id = span.run_id.0.to_string();
39 let session_id = span.session_id.map(|id| id.0.to_string());
40 let span_kind = serde_json::to_string(&span.span_kind)?
41 .trim_matches('"')
42 .to_string();
43 let status = serde_json::to_string(&span.status)?
44 .trim_matches('"')
45 .to_string();
46
47 let attributes = serde_json::to_string(&span.attributes)?;
48 let otel_attributes = serde_json::to_string(&span.otel_attributes)?;
49 let openinference_attributes = serde_json::to_string(&span.openinference_attributes)?;
50 let memory_state = span
51 .memory_state
52 .map(|s| serde_json::to_string(&s).unwrap());
53
54 let input_ref = span.input_ref.map(|r| serde_json::to_string(&r).unwrap());
55 let output_ref = span.output_ref.map(|r| serde_json::to_string(&r).unwrap());
56 let retrieved_document_refs = serde_json::to_string(&span.retrieved_document_refs)?;
57 let token_usage = span.token_usage.map(|u| serde_json::to_string(&u).unwrap());
58 let cost_estimate = span
59 .cost_estimate
60 .map(|c| serde_json::to_string(&c).unwrap());
61 let redaction_policy = serde_json::to_string(&span.redaction_policy)?
62 .trim_matches('"')
63 .to_string();
64
65 sqlx::query(
66 r#"
67 INSERT INTO spans (
68 trace_id, span_id, parent_span_id, run_id, session_id, user_id_hash,
69 span_kind, name, start_time, end_time, status, status_message, error_type, error_message_redacted,
70 attributes, otel_attributes, openinference_attributes, memory_state,
71 input_ref, output_ref, prompt_template_id, prompt_version,
72 model_provider, model_name, tool_name, tool_schema_hash, retrieval_query_hash,
73 retrieved_document_refs, token_usage, cost_estimate, latency_ms, retry_count, cache_hit,
74 redaction_policy, schema_version, project_id
75 ) VALUES (
76 ?, ?, ?, ?, ?, ?,
77 ?, ?, ?, ?, ?, ?, ?, ?,
78 ?, ?, ?, ?,
79 ?, ?, ?, ?,
80 ?, ?, ?, ?, ?,
81 ?, ?, ?, ?, ?, ?,
82 ?, ?, ?
83 )
84 -- A span may be recorded twice with the same span_id (e.g. a HITL
85 -- breakpoint: first PendingApproval, then Ok once resolved). Upsert
86 -- so the resolved state replaces the pending row instead of failing
87 -- the primary key. For ordinary single-write spans the conflict
88 -- arm never fires.
89 ON CONFLICT(span_id) DO UPDATE SET
90 trace_id=excluded.trace_id, parent_span_id=excluded.parent_span_id,
91 run_id=excluded.run_id, session_id=excluded.session_id,
92 user_id_hash=excluded.user_id_hash, span_kind=excluded.span_kind,
93 name=excluded.name, start_time=excluded.start_time, end_time=excluded.end_time,
94 status=excluded.status, status_message=excluded.status_message,
95 error_type=excluded.error_type, error_message_redacted=excluded.error_message_redacted,
96 attributes=excluded.attributes, otel_attributes=excluded.otel_attributes,
97 openinference_attributes=excluded.openinference_attributes,
98 memory_state=excluded.memory_state, input_ref=excluded.input_ref,
99 output_ref=excluded.output_ref, prompt_template_id=excluded.prompt_template_id,
100 prompt_version=excluded.prompt_version, model_provider=excluded.model_provider,
101 model_name=excluded.model_name, tool_name=excluded.tool_name,
102 tool_schema_hash=excluded.tool_schema_hash,
103 retrieval_query_hash=excluded.retrieval_query_hash,
104 retrieved_document_refs=excluded.retrieved_document_refs,
105 token_usage=excluded.token_usage, cost_estimate=excluded.cost_estimate,
106 latency_ms=excluded.latency_ms, retry_count=excluded.retry_count,
107 cache_hit=excluded.cache_hit, redaction_policy=excluded.redaction_policy,
108 schema_version=excluded.schema_version, project_id=excluded.project_id
109 "#,
110 )
111 .bind(trace_id).bind(span_id).bind(parent_span_id).bind(run_id).bind(session_id).bind(span.user_id_hash)
112 .bind(span_kind).bind(span.name).bind(span.start_time as i64).bind(span.end_time.map(|t| t as i64)).bind(status).bind(span.status_message).bind(span.error_type).bind(span.error_message_redacted)
113 .bind(attributes).bind(otel_attributes).bind(openinference_attributes).bind(memory_state)
114 .bind(input_ref).bind(output_ref).bind(span.prompt_template_id).bind(span.prompt_version)
115 .bind(span.model_provider).bind(span.model_name).bind(span.tool_name).bind(span.tool_schema_hash).bind(span.retrieval_query_hash)
116 .bind(retrieved_document_refs).bind(token_usage).bind(cost_estimate).bind(span.latency_ms.map(|t| t as i64)).bind(span.retry_count).bind(span.cache_hit)
117 .bind(redaction_policy).bind(span.schema_version).bind(span.project_id)
118 .execute(&self.pool)
119 .await?;
120
121 Ok(())
122 }
123
124 async fn record_event(&self, event: EventRecord) -> Result<()> {
125 let event_kind = serde_json::to_string(&event.event_kind)?
126 .trim_matches('"')
127 .to_string();
128 let attributes = serde_json::to_string(&event.attributes)?;
129
130 sqlx::query(
131 r#"
132 INSERT INTO events (
133 event_id, trace_id, run_id, parent_span_id, seq,
134 event_kind, name, timestamp, attributes, schema_version
135 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
136 "#,
137 )
138 .bind(event.event_id.0.to_string())
139 .bind(event.trace_id.0.to_string())
140 .bind(event.run_id.0.to_string())
141 .bind(event.parent_span_id.map(|id| id.0.to_string()))
142 .bind(event.seq as i64)
143 .bind(event_kind)
144 .bind(event.name)
145 .bind(event.timestamp as i64)
146 .bind(attributes)
147 .bind(event.schema_version)
148 .execute(&self.pool)
149 .await?;
150
151 Ok(())
152 }
153}