Skip to main content

trace_weft_server/storage/
postgres.rs

1use anyhow::Result;
2use sqlx::{PgPool, Postgres, postgres::PgArguments, postgres::PgPoolOptions, query::Query};
3use trace_weft_core::SpanRecord;
4use trace_weft_recorder::TraceStore;
5
6pub struct PostgresRecorder {
7    pub pool: PgPool,
8}
9
10impl PostgresRecorder {
11    pub async fn new(db_url: &str) -> Result<Self> {
12        let pool = PgPoolOptions::new()
13            .max_connections(5)
14            .connect(db_url)
15            .await?;
16
17        // Initialize schema (simplified for demo)
18        let q = sqlx::query(
19            r#"
20            CREATE TABLE IF NOT EXISTS spans (
21                trace_id TEXT NOT NULL,
22                span_id TEXT NOT NULL PRIMARY KEY,
23                parent_span_id TEXT,
24                run_id TEXT NOT NULL,
25                session_id TEXT,
26                user_id_hash TEXT,
27                span_kind TEXT NOT NULL,
28                name TEXT NOT NULL,
29                start_time BIGINT NOT NULL,
30                end_time BIGINT,
31                status TEXT NOT NULL,
32                status_message TEXT,
33                error_type TEXT,
34                error_message_redacted TEXT,
35                attributes TEXT NOT NULL,
36                otel_attributes TEXT NOT NULL,
37                openinference_attributes TEXT NOT NULL,
38                memory_state TEXT,
39                input_ref TEXT,
40                output_ref TEXT,
41                prompt_template_id TEXT,
42                prompt_version TEXT,
43                model_provider TEXT,
44                model_name TEXT,
45                tool_name TEXT,
46                tool_schema_hash TEXT,
47                retrieval_query_hash TEXT,
48                retrieved_document_refs TEXT NOT NULL,
49                token_usage TEXT,
50                cost_estimate TEXT,
51                latency_ms BIGINT,
52                retry_count INTEGER,
53                cache_hit BOOLEAN,
54                redaction_policy TEXT NOT NULL,
55                schema_version TEXT NOT NULL,
56                project_id TEXT
57            );
58            CREATE INDEX IF NOT EXISTS idx_spans_trace_id ON spans(trace_id);
59            CREATE INDEX IF NOT EXISTS idx_spans_run_id ON spans(run_id);
60            CREATE INDEX IF NOT EXISTS idx_spans_project_id ON spans(project_id);
61            "#,
62        );
63        let q: Query<'_, Postgres, PgArguments> = q;
64        q.execute(&pool).await?;
65
66        Ok(Self { pool })
67    }
68}
69
70#[async_trait::async_trait]
71impl TraceStore for PostgresRecorder {
72    async fn record_span(&self, span: SpanRecord) -> Result<()> {
73        let trace_id = span.trace_id.0.to_string();
74        let span_id = span.span_id.0.to_string();
75        let parent_span_id = span.parent_span_id.map(|id| id.0.to_string());
76        let run_id = span.run_id.0.to_string();
77        let session_id = span.session_id.map(|id| id.0.to_string());
78        let span_kind = serde_json::to_string(&span.span_kind)?
79            .trim_matches('"')
80            .to_string();
81        let status = serde_json::to_string(&span.status)?
82            .trim_matches('"')
83            .to_string();
84
85        let attributes = serde_json::to_string(&span.attributes)?;
86        let otel_attributes = serde_json::to_string(&span.otel_attributes)?;
87        let openinference_attributes = serde_json::to_string(&span.openinference_attributes)?;
88        let memory_state = span
89            .memory_state
90            .map(|s| serde_json::to_string(&s).unwrap());
91
92        let input_ref = span.input_ref.map(|r| serde_json::to_string(&r).unwrap());
93        let output_ref = span.output_ref.map(|r| serde_json::to_string(&r).unwrap());
94        let retrieved_document_refs = serde_json::to_string(&span.retrieved_document_refs)?;
95        let token_usage = span.token_usage.map(|u| serde_json::to_string(&u).unwrap());
96        let cost_estimate = span
97            .cost_estimate
98            .map(|c| serde_json::to_string(&c).unwrap());
99        let redaction_policy = serde_json::to_string(&span.redaction_policy)?
100            .trim_matches('"')
101            .to_string();
102
103        let q = sqlx::query(
104            r#"
105            INSERT INTO spans (
106                trace_id, span_id, parent_span_id, run_id, session_id, user_id_hash,
107                span_kind, name, start_time, end_time, status, status_message, error_type, error_message_redacted,
108                attributes, otel_attributes, openinference_attributes, memory_state,
109                input_ref, output_ref, prompt_template_id, prompt_version,
110                model_provider, model_name, tool_name, tool_schema_hash, retrieval_query_hash,
111                retrieved_document_refs, token_usage, cost_estimate, latency_ms, retry_count, cache_hit,
112                redaction_policy, schema_version, project_id
113            ) VALUES (
114                $1, $2, $3, $4, $5, $6,
115                $7, $8, $9, $10, $11, $12, $13, $14,
116                $15, $16, $17, $18,
117                $19, $20, $21, $22,
118                $23, $24, $25, $26, $27,
119                $28, $29, $30, $31, $32, $33,
120                $34, $35, $36
121            )
122            ON CONFLICT (span_id) DO NOTHING
123            "#,
124        );
125
126        let q: Query<'_, Postgres, PgArguments> = q;
127        q.bind(trace_id)
128            .bind(span_id)
129            .bind(parent_span_id)
130            .bind(run_id)
131            .bind(session_id)
132            .bind(span.user_id_hash)
133            .bind(span_kind)
134            .bind(span.name)
135            .bind(span.start_time as i64)
136            .bind(span.end_time.map(|t| t as i64))
137            .bind(status)
138            .bind(span.status_message)
139            .bind(span.error_type)
140            .bind(span.error_message_redacted)
141            .bind(attributes)
142            .bind(otel_attributes)
143            .bind(openinference_attributes)
144            .bind(memory_state)
145            .bind(input_ref)
146            .bind(output_ref)
147            .bind(span.prompt_template_id)
148            .bind(span.prompt_version)
149            .bind(span.model_provider)
150            .bind(span.model_name)
151            .bind(span.tool_name)
152            .bind(span.tool_schema_hash)
153            .bind(span.retrieval_query_hash)
154            .bind(retrieved_document_refs)
155            .bind(token_usage)
156            .bind(cost_estimate)
157            .bind(span.latency_ms.map(|t| t as i64))
158            .bind(span.retry_count.map(|c| c as i32))
159            .bind(span.cache_hit)
160            .bind(redaction_policy)
161            .bind(span.schema_version)
162            .bind(span.project_id)
163            .execute(&self.pool)
164            .await?;
165
166        Ok(())
167    }
168}