1use anyhow::Result;
2use sqlx::{PgPool, Postgres, postgres::PgArguments, postgres::PgPoolOptions, query::Query};
3use trace_weft_core::{EventRecord, SpanRecord};
4use trace_weft_recorder::TraceStore;
5
6pub struct PostgresRecorder {
7 pub pool: PgPool,
8}
9
10impl PostgresRecorder {
11 pub async fn new(db_url: &str) -> Result<Self> {
12 let pool = PgPoolOptions::new()
13 .max_connections(5)
14 .connect(db_url)
15 .await?;
16
17 Self::from_pool(pool).await
18 }
19
20 pub async fn from_pool(pool: PgPool) -> Result<Self> {
25 sqlx::raw_sql(
29 r#"
30 CREATE TABLE IF NOT EXISTS spans (
31 trace_id TEXT NOT NULL,
32 span_id TEXT NOT NULL PRIMARY KEY,
33 parent_span_id TEXT,
34 run_id TEXT NOT NULL,
35 session_id TEXT,
36 user_id_hash TEXT,
37 span_kind TEXT NOT NULL,
38 name TEXT NOT NULL,
39 start_time BIGINT NOT NULL,
40 end_time BIGINT,
41 status TEXT NOT NULL,
42 status_message TEXT,
43 error_type TEXT,
44 error_message_redacted TEXT,
45 attributes TEXT NOT NULL,
46 otel_attributes TEXT NOT NULL,
47 openinference_attributes TEXT NOT NULL,
48 memory_state TEXT,
49 input_ref TEXT,
50 output_ref TEXT,
51 prompt_template_id TEXT,
52 prompt_version TEXT,
53 model_provider TEXT,
54 model_name TEXT,
55 tool_name TEXT,
56 tool_schema_hash TEXT,
57 retrieval_query_hash TEXT,
58 retrieved_document_refs TEXT NOT NULL,
59 token_usage TEXT,
60 cost_estimate TEXT,
61 latency_ms BIGINT,
62 retry_count BIGINT,
63 cache_hit BOOLEAN,
64 redaction_policy TEXT NOT NULL,
65 schema_version TEXT NOT NULL,
66 project_id TEXT
67 );
68 CREATE INDEX IF NOT EXISTS idx_spans_trace_id ON spans(trace_id);
69 CREATE INDEX IF NOT EXISTS idx_spans_run_id ON spans(run_id);
70 CREATE INDEX IF NOT EXISTS idx_spans_project_id ON spans(project_id);
71
72 CREATE TABLE IF NOT EXISTS events (
73 event_id TEXT NOT NULL PRIMARY KEY,
74 trace_id TEXT NOT NULL,
75 run_id TEXT NOT NULL,
76 parent_span_id TEXT,
77 seq BIGINT NOT NULL,
78 event_kind TEXT NOT NULL,
79 name TEXT NOT NULL,
80 timestamp BIGINT NOT NULL,
81 attributes TEXT NOT NULL,
82 schema_version TEXT NOT NULL
83 );
84 CREATE INDEX IF NOT EXISTS idx_events_trace_id ON events(trace_id);
85 CREATE INDEX IF NOT EXISTS idx_events_parent_span_id ON events(parent_span_id);
86 "#,
87 )
88 .execute(&pool)
89 .await?;
90
91 Ok(Self { pool })
92 }
93}
94
95#[async_trait::async_trait]
96impl TraceStore for PostgresRecorder {
97 async fn record_span(&self, span: SpanRecord) -> Result<()> {
98 let trace_id = span.trace_id.0.to_string();
99 let span_id = span.span_id.0.to_string();
100 let parent_span_id = span.parent_span_id.map(|id| id.0.to_string());
101 let run_id = span.run_id.0.to_string();
102 let session_id = span.session_id.map(|id| id.0.to_string());
103 let span_kind = serde_json::to_string(&span.span_kind)?
104 .trim_matches('"')
105 .to_string();
106 let status = serde_json::to_string(&span.status)?
107 .trim_matches('"')
108 .to_string();
109
110 let attributes = serde_json::to_string(&span.attributes)?;
111 let otel_attributes = serde_json::to_string(&span.otel_attributes)?;
112 let openinference_attributes = serde_json::to_string(&span.openinference_attributes)?;
113 let memory_state = span
114 .memory_state
115 .map(|s| serde_json::to_string(&s).unwrap());
116
117 let input_ref = span.input_ref.map(|r| serde_json::to_string(&r).unwrap());
118 let output_ref = span.output_ref.map(|r| serde_json::to_string(&r).unwrap());
119 let retrieved_document_refs = serde_json::to_string(&span.retrieved_document_refs)?;
120 let token_usage = span.token_usage.map(|u| serde_json::to_string(&u).unwrap());
121 let cost_estimate = span
122 .cost_estimate
123 .map(|c| serde_json::to_string(&c).unwrap());
124 let redaction_policy = serde_json::to_string(&span.redaction_policy)?
125 .trim_matches('"')
126 .to_string();
127
128 let q = sqlx::query(
129 r#"
130 INSERT INTO spans (
131 trace_id, span_id, parent_span_id, run_id, session_id, user_id_hash,
132 span_kind, name, start_time, end_time, status, status_message, error_type, error_message_redacted,
133 attributes, otel_attributes, openinference_attributes, memory_state,
134 input_ref, output_ref, prompt_template_id, prompt_version,
135 model_provider, model_name, tool_name, tool_schema_hash, retrieval_query_hash,
136 retrieved_document_refs, token_usage, cost_estimate, latency_ms, retry_count, cache_hit,
137 redaction_policy, schema_version, project_id
138 ) VALUES (
139 $1, $2, $3, $4, $5, $6,
140 $7, $8, $9, $10, $11, $12, $13, $14,
141 $15, $16, $17, $18,
142 $19, $20, $21, $22,
143 $23, $24, $25, $26, $27,
144 $28, $29, $30, $31, $32, $33,
145 $34, $35, $36
146 )
147 -- A span may be recorded twice with the same span_id (e.g. a HITL
148 -- breakpoint: first PendingApproval, then Ok once resolved). Upsert
149 -- so the resolved state replaces the pending row; `DO NOTHING` would
150 -- silently discard it. For ordinary single-write spans the conflict
151 -- arm never fires.
152 ON CONFLICT (span_id) DO UPDATE SET
153 trace_id=EXCLUDED.trace_id, parent_span_id=EXCLUDED.parent_span_id,
154 run_id=EXCLUDED.run_id, session_id=EXCLUDED.session_id,
155 user_id_hash=EXCLUDED.user_id_hash, span_kind=EXCLUDED.span_kind,
156 name=EXCLUDED.name, start_time=EXCLUDED.start_time, end_time=EXCLUDED.end_time,
157 status=EXCLUDED.status, status_message=EXCLUDED.status_message,
158 error_type=EXCLUDED.error_type, error_message_redacted=EXCLUDED.error_message_redacted,
159 attributes=EXCLUDED.attributes, otel_attributes=EXCLUDED.otel_attributes,
160 openinference_attributes=EXCLUDED.openinference_attributes,
161 memory_state=EXCLUDED.memory_state, input_ref=EXCLUDED.input_ref,
162 output_ref=EXCLUDED.output_ref, prompt_template_id=EXCLUDED.prompt_template_id,
163 prompt_version=EXCLUDED.prompt_version, model_provider=EXCLUDED.model_provider,
164 model_name=EXCLUDED.model_name, tool_name=EXCLUDED.tool_name,
165 tool_schema_hash=EXCLUDED.tool_schema_hash,
166 retrieval_query_hash=EXCLUDED.retrieval_query_hash,
167 retrieved_document_refs=EXCLUDED.retrieved_document_refs,
168 token_usage=EXCLUDED.token_usage, cost_estimate=EXCLUDED.cost_estimate,
169 latency_ms=EXCLUDED.latency_ms, retry_count=EXCLUDED.retry_count,
170 cache_hit=EXCLUDED.cache_hit, redaction_policy=EXCLUDED.redaction_policy,
171 schema_version=EXCLUDED.schema_version, project_id=EXCLUDED.project_id
172 "#,
173 );
174
175 let q: Query<'_, Postgres, PgArguments> = q;
176 q.bind(trace_id)
177 .bind(span_id)
178 .bind(parent_span_id)
179 .bind(run_id)
180 .bind(session_id)
181 .bind(span.user_id_hash)
182 .bind(span_kind)
183 .bind(span.name)
184 .bind(span.start_time as i64)
185 .bind(span.end_time.map(|t| t as i64))
186 .bind(status)
187 .bind(span.status_message)
188 .bind(span.error_type)
189 .bind(span.error_message_redacted)
190 .bind(attributes)
191 .bind(otel_attributes)
192 .bind(openinference_attributes)
193 .bind(memory_state)
194 .bind(input_ref)
195 .bind(output_ref)
196 .bind(span.prompt_template_id)
197 .bind(span.prompt_version)
198 .bind(span.model_provider)
199 .bind(span.model_name)
200 .bind(span.tool_name)
201 .bind(span.tool_schema_hash)
202 .bind(span.retrieval_query_hash)
203 .bind(retrieved_document_refs)
204 .bind(token_usage)
205 .bind(cost_estimate)
206 .bind(span.latency_ms.map(|t| t as i64))
207 .bind(span.retry_count.map(|c| c as i64))
208 .bind(span.cache_hit)
209 .bind(redaction_policy)
210 .bind(span.schema_version)
211 .bind(span.project_id)
212 .execute(&self.pool)
213 .await?;
214
215 Ok(())
216 }
217
218 async fn record_event(&self, event: EventRecord) -> Result<()> {
219 let event_kind = serde_json::to_string(&event.event_kind)?
220 .trim_matches('"')
221 .to_string();
222 let attributes = serde_json::to_string(&event.attributes)?;
223
224 let q = sqlx::query(
225 r#"
226 INSERT INTO events (
227 event_id, trace_id, run_id, parent_span_id, seq,
228 event_kind, name, timestamp, attributes, schema_version
229 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
230 ON CONFLICT (event_id) DO NOTHING
231 "#,
232 );
233
234 let q: Query<'_, Postgres, PgArguments> = q;
235 q.bind(event.event_id.0.to_string())
236 .bind(event.trace_id.0.to_string())
237 .bind(event.run_id.0.to_string())
238 .bind(event.parent_span_id.map(|id| id.0.to_string()))
239 .bind(event.seq as i64)
240 .bind(event_kind)
241 .bind(event.name)
242 .bind(event.timestamp as i64)
243 .bind(attributes)
244 .bind(event.schema_version)
245 .execute(&self.pool)
246 .await?;
247
248 Ok(())
249 }
250}