systemprompt_database/admin/
introspection.rs1use std::sync::Arc;
2
3use anyhow::Result;
4use sqlx::postgres::PgPool;
5use sqlx::Row;
6
7use crate::models::{ColumnInfo, DatabaseInfo, IndexInfo, TableInfo};
8
9#[derive(Debug)]
10pub struct DatabaseAdminService {
11 pool: Arc<PgPool>,
12}
13
14impl DatabaseAdminService {
15 pub const fn new(pool: Arc<PgPool>) -> Self {
16 Self { pool }
17 }
18
19 pub async fn list_tables(&self) -> Result<Vec<TableInfo>> {
20 let rows = sqlx::query(
21 r"
22 SELECT
23 t.table_name as name,
24 COALESCE(s.n_live_tup, 0) as row_count,
25 COALESCE(pg_total_relation_size(quote_ident(t.table_name)::regclass), 0) as size_bytes
26 FROM information_schema.tables t
27 LEFT JOIN pg_stat_user_tables s ON t.table_name = s.relname
28 WHERE t.table_schema = 'public'
29 ORDER BY t.table_name
30 ",
31 )
32 .fetch_all(&*self.pool)
33 .await?;
34
35 let tables = rows
36 .iter()
37 .map(|row| {
38 let name: String = row.get("name");
39 let row_count: i64 = row.get("row_count");
40 let size_bytes: i64 = row.get("size_bytes");
41 TableInfo {
42 name,
43 row_count,
44 size_bytes,
45 columns: vec![],
46 }
47 })
48 .collect();
49
50 Ok(tables)
51 }
52
53 pub async fn describe_table(&self, table_name: &str) -> Result<(Vec<ColumnInfo>, i64)> {
54 if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
55 return Err(anyhow::anyhow!("Table '{}' not found", table_name));
56 }
57
58 let rows = sqlx::query(
59 "SELECT column_name, data_type, is_nullable, column_default FROM \
60 information_schema.columns WHERE table_name = $1 ORDER BY ordinal_position",
61 )
62 .bind(table_name)
63 .fetch_all(&*self.pool)
64 .await?;
65
66 if rows.is_empty() {
67 return Err(anyhow::anyhow!("Table '{}' not found", table_name));
68 }
69
70 let pk_rows = sqlx::query(
71 r"
72 SELECT a.attname as column_name
73 FROM pg_index i
74 JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
75 WHERE i.indrelid = $1::regclass AND i.indisprimary
76 ",
77 )
78 .bind(table_name)
79 .fetch_all(&*self.pool)
80 .await
81 .unwrap_or_else(|_| Vec::new());
82
83 let pk_columns: Vec<String> = pk_rows
84 .iter()
85 .map(|row| row.get::<String, _>("column_name"))
86 .collect();
87
88 let columns = rows
89 .iter()
90 .map(|row| {
91 let name: String = row.get("column_name");
92 let data_type: String = row.get("data_type");
93 let nullable_str: String = row.get("is_nullable");
94 let nullable = nullable_str.to_uppercase() == "YES";
95 let default: Option<String> = row.get("column_default");
96 let primary_key = pk_columns.contains(&name);
97
98 ColumnInfo {
99 name,
100 data_type,
101 nullable,
102 primary_key,
103 default,
104 }
105 })
106 .collect();
107
108 let row_count = self.count_rows(table_name).await?;
109
110 Ok((columns, row_count))
111 }
112
113 pub async fn get_table_indexes(&self, table_name: &str) -> Result<Vec<IndexInfo>> {
114 if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
115 return Err(anyhow::anyhow!("Table '{}' not found", table_name));
116 }
117
118 let rows = sqlx::query(
119 r"
120 SELECT
121 i.relname as index_name,
122 ix.indisunique as is_unique,
123 array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns
124 FROM pg_class t
125 JOIN pg_index ix ON t.oid = ix.indrelid
126 JOIN pg_class i ON i.oid = ix.indexrelid
127 JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
128 WHERE t.relname = $1 AND t.relkind = 'r'
129 GROUP BY i.relname, ix.indisunique
130 ORDER BY i.relname
131 ",
132 )
133 .bind(table_name)
134 .fetch_all(&*self.pool)
135 .await?;
136
137 let indexes = rows
138 .iter()
139 .map(|row| {
140 let name: String = row.get("index_name");
141 let unique: bool = row.get("is_unique");
142 let columns: Vec<String> = row.get("columns");
143 IndexInfo {
144 name,
145 columns,
146 unique,
147 }
148 })
149 .collect();
150
151 Ok(indexes)
152 }
153
154 pub async fn count_rows(&self, table_name: &str) -> Result<i64> {
155 if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
156 return Err(anyhow::anyhow!("Table '{}' not found", table_name));
157 }
158
159 let quoted_table = quote_identifier(table_name);
160 let count_query = format!("SELECT COUNT(*) as count FROM {quoted_table}");
161 let row_count: i64 = sqlx::query_scalar(&count_query)
162 .fetch_one(&*self.pool)
163 .await?;
164
165 Ok(row_count)
166 }
167
168 pub async fn get_database_info(&self) -> Result<DatabaseInfo> {
169 let version: String = sqlx::query_scalar("SELECT version()")
170 .fetch_one(&*self.pool)
171 .await?;
172
173 let size: i64 = sqlx::query_scalar("SELECT pg_database_size(current_database())")
174 .fetch_one(&*self.pool)
175 .await?;
176
177 let tables = self.list_tables().await?;
178
179 Ok(DatabaseInfo {
180 path: "PostgreSQL".to_string(),
181 size: u64::try_from(size).unwrap_or(0),
182 version,
183 tables,
184 })
185 }
186
187 pub fn get_expected_tables() -> Vec<&'static str> {
188 vec![
189 "users",
190 "user_sessions",
191 "user_contexts",
192 "agent_tasks",
193 "agent_skills",
194 "task_messages",
195 "task_artifacts",
196 "task_execution_steps",
197 "artifact_parts",
198 "message_parts",
199 "ai_requests",
200 "ai_request_messages",
201 "ai_request_tool_calls",
202 "mcp_tool_executions",
203 "logs",
204 "analytics_events",
205 "oauth_clients",
206 "oauth_auth_codes",
207 "oauth_refresh_tokens",
208 "scheduled_jobs",
209 "services",
210 "markdown_content",
211 "markdown_categories",
212 "files",
213 "content_files",
214 ]
215 }
216}
217
218fn quote_identifier(identifier: &str) -> String {
219 let escaped = identifier.replace('"', "\"\"");
220 format!("\"{escaped}\"")
221}