// trace_weft_server/storage/postgres.rs
use anyhow::{Context, Result};
use sqlx::{PgPool, Postgres, postgres::PgArguments, postgres::PgPoolOptions, query::Query};
use trace_weft_core::SpanRecord;
use trace_weft_recorder::TraceStore;

6pub struct PostgresRecorder {
7 pub pool: PgPool,
8}
9
10impl PostgresRecorder {
11 pub async fn new(db_url: &str) -> Result<Self> {
12 let pool = PgPoolOptions::new()
13 .max_connections(5)
14 .connect(db_url)
15 .await?;
16
17 let q = sqlx::query(
19 r#"
20 CREATE TABLE IF NOT EXISTS spans (
21 trace_id TEXT NOT NULL,
22 span_id TEXT NOT NULL PRIMARY KEY,
23 parent_span_id TEXT,
24 run_id TEXT NOT NULL,
25 session_id TEXT,
26 user_id_hash TEXT,
27 span_kind TEXT NOT NULL,
28 name TEXT NOT NULL,
29 start_time BIGINT NOT NULL,
30 end_time BIGINT,
31 status TEXT NOT NULL,
32 status_message TEXT,
33 error_type TEXT,
34 error_message_redacted TEXT,
35 attributes TEXT NOT NULL,
36 otel_attributes TEXT NOT NULL,
37 openinference_attributes TEXT NOT NULL,
38 memory_state TEXT,
39 input_ref TEXT,
40 output_ref TEXT,
41 prompt_template_id TEXT,
42 prompt_version TEXT,
43 model_provider TEXT,
44 model_name TEXT,
45 tool_name TEXT,
46 tool_schema_hash TEXT,
47 retrieval_query_hash TEXT,
48 retrieved_document_refs TEXT NOT NULL,
49 token_usage TEXT,
50 cost_estimate TEXT,
51 latency_ms BIGINT,
52 retry_count INTEGER,
53 cache_hit BOOLEAN,
54 redaction_policy TEXT NOT NULL,
55 schema_version TEXT NOT NULL,
56 project_id TEXT
57 );
58 CREATE INDEX IF NOT EXISTS idx_spans_trace_id ON spans(trace_id);
59 CREATE INDEX IF NOT EXISTS idx_spans_run_id ON spans(run_id);
60 CREATE INDEX IF NOT EXISTS idx_spans_project_id ON spans(project_id);
61 "#,
62 );
63 let q: Query<'_, Postgres, PgArguments> = q;
64 q.execute(&pool).await?;
65
66 Ok(Self { pool })
67 }
68}
69
70#[async_trait::async_trait]
71impl TraceStore for PostgresRecorder {
72 async fn record_span(&self, span: SpanRecord) -> Result<()> {
73 let trace_id = span.trace_id.0.to_string();
74 let span_id = span.span_id.0.to_string();
75 let parent_span_id = span.parent_span_id.map(|id| id.0.to_string());
76 let run_id = span.run_id.0.to_string();
77 let session_id = span.session_id.map(|id| id.0.to_string());
78 let span_kind = serde_json::to_string(&span.span_kind)?
79 .trim_matches('"')
80 .to_string();
81 let status = serde_json::to_string(&span.status)?
82 .trim_matches('"')
83 .to_string();
84
85 let attributes = serde_json::to_string(&span.attributes)?;
86 let otel_attributes = serde_json::to_string(&span.otel_attributes)?;
87 let openinference_attributes = serde_json::to_string(&span.openinference_attributes)?;
88 let memory_state = span
89 .memory_state
90 .map(|s| serde_json::to_string(&s).unwrap());
91
92 let input_ref = span.input_ref.map(|r| serde_json::to_string(&r).unwrap());
93 let output_ref = span.output_ref.map(|r| serde_json::to_string(&r).unwrap());
94 let retrieved_document_refs = serde_json::to_string(&span.retrieved_document_refs)?;
95 let token_usage = span.token_usage.map(|u| serde_json::to_string(&u).unwrap());
96 let cost_estimate = span
97 .cost_estimate
98 .map(|c| serde_json::to_string(&c).unwrap());
99 let redaction_policy = serde_json::to_string(&span.redaction_policy)?
100 .trim_matches('"')
101 .to_string();
102
103 let q = sqlx::query(
104 r#"
105 INSERT INTO spans (
106 trace_id, span_id, parent_span_id, run_id, session_id, user_id_hash,
107 span_kind, name, start_time, end_time, status, status_message, error_type, error_message_redacted,
108 attributes, otel_attributes, openinference_attributes, memory_state,
109 input_ref, output_ref, prompt_template_id, prompt_version,
110 model_provider, model_name, tool_name, tool_schema_hash, retrieval_query_hash,
111 retrieved_document_refs, token_usage, cost_estimate, latency_ms, retry_count, cache_hit,
112 redaction_policy, schema_version, project_id
113 ) VALUES (
114 $1, $2, $3, $4, $5, $6,
115 $7, $8, $9, $10, $11, $12, $13, $14,
116 $15, $16, $17, $18,
117 $19, $20, $21, $22,
118 $23, $24, $25, $26, $27,
119 $28, $29, $30, $31, $32, $33,
120 $34, $35, $36
121 )
122 ON CONFLICT (span_id) DO NOTHING
123 "#,
124 );
125
126 let q: Query<'_, Postgres, PgArguments> = q;
127 q.bind(trace_id)
128 .bind(span_id)
129 .bind(parent_span_id)
130 .bind(run_id)
131 .bind(session_id)
132 .bind(span.user_id_hash)
133 .bind(span_kind)
134 .bind(span.name)
135 .bind(span.start_time as i64)
136 .bind(span.end_time.map(|t| t as i64))
137 .bind(status)
138 .bind(span.status_message)
139 .bind(span.error_type)
140 .bind(span.error_message_redacted)
141 .bind(attributes)
142 .bind(otel_attributes)
143 .bind(openinference_attributes)
144 .bind(memory_state)
145 .bind(input_ref)
146 .bind(output_ref)
147 .bind(span.prompt_template_id)
148 .bind(span.prompt_version)
149 .bind(span.model_provider)
150 .bind(span.model_name)
151 .bind(span.tool_name)
152 .bind(span.tool_schema_hash)
153 .bind(span.retrieval_query_hash)
154 .bind(retrieved_document_refs)
155 .bind(token_usage)
156 .bind(cost_estimate)
157 .bind(span.latency_ms.map(|t| t as i64))
158 .bind(span.retry_count.map(|c| c as i32))
159 .bind(span.cache_hit)
160 .bind(redaction_policy)
161 .bind(span.schema_version)
162 .bind(span.project_id)
163 .execute(&self.pool)
164 .await?;
165
166 Ok(())
167 }
168}