Skip to main content

trace_weft_server/storage/
postgres.rs

1use anyhow::Result;
2use sqlx::{PgPool, Postgres, postgres::PgArguments, postgres::PgPoolOptions, query::Query};
3use trace_weft_core::{EventRecord, SpanRecord};
4use trace_weft_recorder::TraceStore;
5
6pub struct PostgresRecorder {
7    pub pool: PgPool,
8}
9
10impl PostgresRecorder {
11    pub async fn new(db_url: &str) -> Result<Self> {
12        let pool = PgPoolOptions::new()
13            .max_connections(5)
14            .connect(db_url)
15            .await?;
16
17        Self::from_pool(pool).await
18    }
19
20    /// Wrap an existing pool, creating the schema on first use. The server
21    /// constructs the recorder from its own pool (so it shares connection
22    /// settings), so schema creation must live here rather than only in
23    /// [`new`] — otherwise a fresh Postgres has no tables.
24    pub async fn from_pool(pool: PgPool) -> Result<Self> {
25        // `raw_sql` runs the whole multi-statement block unprepared; `query`
26        // would prepare it and Postgres rejects multiple commands in one
27        // prepared statement ("cannot insert multiple commands…").
28        sqlx::raw_sql(
29            r#"
30            CREATE TABLE IF NOT EXISTS spans (
31                trace_id TEXT NOT NULL,
32                span_id TEXT NOT NULL PRIMARY KEY,
33                parent_span_id TEXT,
34                run_id TEXT NOT NULL,
35                session_id TEXT,
36                user_id_hash TEXT,
37                span_kind TEXT NOT NULL,
38                name TEXT NOT NULL,
39                start_time BIGINT NOT NULL,
40                end_time BIGINT,
41                status TEXT NOT NULL,
42                status_message TEXT,
43                error_type TEXT,
44                error_message_redacted TEXT,
45                attributes TEXT NOT NULL,
46                otel_attributes TEXT NOT NULL,
47                openinference_attributes TEXT NOT NULL,
48                memory_state TEXT,
49                input_ref TEXT,
50                output_ref TEXT,
51                prompt_template_id TEXT,
52                prompt_version TEXT,
53                model_provider TEXT,
54                model_name TEXT,
55                tool_name TEXT,
56                tool_schema_hash TEXT,
57                retrieval_query_hash TEXT,
58                retrieved_document_refs TEXT NOT NULL,
59                token_usage TEXT,
60                cost_estimate TEXT,
61                latency_ms BIGINT,
62                retry_count BIGINT,
63                cache_hit BOOLEAN,
64                redaction_policy TEXT NOT NULL,
65                schema_version TEXT NOT NULL,
66                project_id TEXT
67            );
68            CREATE INDEX IF NOT EXISTS idx_spans_trace_id ON spans(trace_id);
69            CREATE INDEX IF NOT EXISTS idx_spans_run_id ON spans(run_id);
70            CREATE INDEX IF NOT EXISTS idx_spans_project_id ON spans(project_id);
71
72            CREATE TABLE IF NOT EXISTS events (
73                event_id TEXT NOT NULL PRIMARY KEY,
74                trace_id TEXT NOT NULL,
75                run_id TEXT NOT NULL,
76                parent_span_id TEXT,
77                seq BIGINT NOT NULL,
78                event_kind TEXT NOT NULL,
79                name TEXT NOT NULL,
80                timestamp BIGINT NOT NULL,
81                attributes TEXT NOT NULL,
82                schema_version TEXT NOT NULL
83            );
84            CREATE INDEX IF NOT EXISTS idx_events_trace_id ON events(trace_id);
85            CREATE INDEX IF NOT EXISTS idx_events_parent_span_id ON events(parent_span_id);
86            "#,
87        )
88        .execute(&pool)
89        .await?;
90
91        Ok(Self { pool })
92    }
93}
94
95#[async_trait::async_trait]
96impl TraceStore for PostgresRecorder {
97    async fn record_span(&self, span: SpanRecord) -> Result<()> {
98        let trace_id = span.trace_id.0.to_string();
99        let span_id = span.span_id.0.to_string();
100        let parent_span_id = span.parent_span_id.map(|id| id.0.to_string());
101        let run_id = span.run_id.0.to_string();
102        let session_id = span.session_id.map(|id| id.0.to_string());
103        let span_kind = serde_json::to_string(&span.span_kind)?
104            .trim_matches('"')
105            .to_string();
106        let status = serde_json::to_string(&span.status)?
107            .trim_matches('"')
108            .to_string();
109
110        let attributes = serde_json::to_string(&span.attributes)?;
111        let otel_attributes = serde_json::to_string(&span.otel_attributes)?;
112        let openinference_attributes = serde_json::to_string(&span.openinference_attributes)?;
113        let memory_state = span
114            .memory_state
115            .map(|s| serde_json::to_string(&s).unwrap());
116
117        let input_ref = span.input_ref.map(|r| serde_json::to_string(&r).unwrap());
118        let output_ref = span.output_ref.map(|r| serde_json::to_string(&r).unwrap());
119        let retrieved_document_refs = serde_json::to_string(&span.retrieved_document_refs)?;
120        let token_usage = span.token_usage.map(|u| serde_json::to_string(&u).unwrap());
121        let cost_estimate = span
122            .cost_estimate
123            .map(|c| serde_json::to_string(&c).unwrap());
124        let redaction_policy = serde_json::to_string(&span.redaction_policy)?
125            .trim_matches('"')
126            .to_string();
127
128        let q = sqlx::query(
129            r#"
130            INSERT INTO spans (
131                trace_id, span_id, parent_span_id, run_id, session_id, user_id_hash,
132                span_kind, name, start_time, end_time, status, status_message, error_type, error_message_redacted,
133                attributes, otel_attributes, openinference_attributes, memory_state,
134                input_ref, output_ref, prompt_template_id, prompt_version,
135                model_provider, model_name, tool_name, tool_schema_hash, retrieval_query_hash,
136                retrieved_document_refs, token_usage, cost_estimate, latency_ms, retry_count, cache_hit,
137                redaction_policy, schema_version, project_id
138            ) VALUES (
139                $1, $2, $3, $4, $5, $6,
140                $7, $8, $9, $10, $11, $12, $13, $14,
141                $15, $16, $17, $18,
142                $19, $20, $21, $22,
143                $23, $24, $25, $26, $27,
144                $28, $29, $30, $31, $32, $33,
145                $34, $35, $36
146            )
147            -- A span may be recorded twice with the same span_id (e.g. a HITL
148            -- breakpoint: first PendingApproval, then Ok once resolved). Upsert
149            -- so the resolved state replaces the pending row; `DO NOTHING` would
150            -- silently discard it. For ordinary single-write spans the conflict
151            -- arm never fires.
152            ON CONFLICT (span_id) DO UPDATE SET
153                trace_id=EXCLUDED.trace_id, parent_span_id=EXCLUDED.parent_span_id,
154                run_id=EXCLUDED.run_id, session_id=EXCLUDED.session_id,
155                user_id_hash=EXCLUDED.user_id_hash, span_kind=EXCLUDED.span_kind,
156                name=EXCLUDED.name, start_time=EXCLUDED.start_time, end_time=EXCLUDED.end_time,
157                status=EXCLUDED.status, status_message=EXCLUDED.status_message,
158                error_type=EXCLUDED.error_type, error_message_redacted=EXCLUDED.error_message_redacted,
159                attributes=EXCLUDED.attributes, otel_attributes=EXCLUDED.otel_attributes,
160                openinference_attributes=EXCLUDED.openinference_attributes,
161                memory_state=EXCLUDED.memory_state, input_ref=EXCLUDED.input_ref,
162                output_ref=EXCLUDED.output_ref, prompt_template_id=EXCLUDED.prompt_template_id,
163                prompt_version=EXCLUDED.prompt_version, model_provider=EXCLUDED.model_provider,
164                model_name=EXCLUDED.model_name, tool_name=EXCLUDED.tool_name,
165                tool_schema_hash=EXCLUDED.tool_schema_hash,
166                retrieval_query_hash=EXCLUDED.retrieval_query_hash,
167                retrieved_document_refs=EXCLUDED.retrieved_document_refs,
168                token_usage=EXCLUDED.token_usage, cost_estimate=EXCLUDED.cost_estimate,
169                latency_ms=EXCLUDED.latency_ms, retry_count=EXCLUDED.retry_count,
170                cache_hit=EXCLUDED.cache_hit, redaction_policy=EXCLUDED.redaction_policy,
171                schema_version=EXCLUDED.schema_version, project_id=EXCLUDED.project_id
172            "#,
173        );
174
175        let q: Query<'_, Postgres, PgArguments> = q;
176        q.bind(trace_id)
177            .bind(span_id)
178            .bind(parent_span_id)
179            .bind(run_id)
180            .bind(session_id)
181            .bind(span.user_id_hash)
182            .bind(span_kind)
183            .bind(span.name)
184            .bind(span.start_time as i64)
185            .bind(span.end_time.map(|t| t as i64))
186            .bind(status)
187            .bind(span.status_message)
188            .bind(span.error_type)
189            .bind(span.error_message_redacted)
190            .bind(attributes)
191            .bind(otel_attributes)
192            .bind(openinference_attributes)
193            .bind(memory_state)
194            .bind(input_ref)
195            .bind(output_ref)
196            .bind(span.prompt_template_id)
197            .bind(span.prompt_version)
198            .bind(span.model_provider)
199            .bind(span.model_name)
200            .bind(span.tool_name)
201            .bind(span.tool_schema_hash)
202            .bind(span.retrieval_query_hash)
203            .bind(retrieved_document_refs)
204            .bind(token_usage)
205            .bind(cost_estimate)
206            .bind(span.latency_ms.map(|t| t as i64))
207            .bind(span.retry_count.map(|c| c as i64))
208            .bind(span.cache_hit)
209            .bind(redaction_policy)
210            .bind(span.schema_version)
211            .bind(span.project_id)
212            .execute(&self.pool)
213            .await?;
214
215        Ok(())
216    }
217
218    async fn record_event(&self, event: EventRecord) -> Result<()> {
219        let event_kind = serde_json::to_string(&event.event_kind)?
220            .trim_matches('"')
221            .to_string();
222        let attributes = serde_json::to_string(&event.attributes)?;
223
224        let q = sqlx::query(
225            r#"
226            INSERT INTO events (
227                event_id, trace_id, run_id, parent_span_id, seq,
228                event_kind, name, timestamp, attributes, schema_version
229            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
230            ON CONFLICT (event_id) DO NOTHING
231            "#,
232        );
233
234        let q: Query<'_, Postgres, PgArguments> = q;
235        q.bind(event.event_id.0.to_string())
236            .bind(event.trace_id.0.to_string())
237            .bind(event.run_id.0.to_string())
238            .bind(event.parent_span_id.map(|id| id.0.to_string()))
239            .bind(event.seq as i64)
240            .bind(event_kind)
241            .bind(event.name)
242            .bind(event.timestamp as i64)
243            .bind(attributes)
244            .bind(event.schema_version)
245            .execute(&self.pool)
246            .await?;
247
248        Ok(())
249    }
250}