Skip to main content

systemprompt_database/admin/
introspection.rs

1//! Schema introspection service.
2//!
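//! Part of the documented sqlx allowlist — the queries here use the runtime
//! `sqlx::query` / `sqlx::query_scalar` APIs rather than the compile-time
//! checked macros because the table name is supplied at runtime as a
//! [`SafeIdentifier`]. Only `DatabaseAdminService::count_rows` interpolates
//! that name into the SQL text (after quoting it); every other query binds it
//! as a parameter.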
6
7use std::sync::Arc;
8
9use sqlx::Row;
10use sqlx::postgres::PgPool;
11
12use crate::admin::identifier::SafeIdentifier;
13use crate::error::{DatabaseResult, RepositoryError};
14use crate::models::{ColumnInfo, DatabaseInfo, IndexInfo, TableInfo};
15
16#[derive(Debug)]
17pub struct DatabaseAdminService {
18    pool: Arc<PgPool>,
19}
20
21impl DatabaseAdminService {
22    pub const fn new(pool: Arc<PgPool>) -> Self {
23        Self { pool }
24    }
25
26    pub async fn list_tables(&self) -> DatabaseResult<Vec<TableInfo>> {
27        let rows = sqlx::query(
28            r"
29            SELECT
30                t.table_name as name,
31                COALESCE(s.n_live_tup, 0) as row_count,
32                COALESCE(pg_total_relation_size(quote_ident(t.table_name)::regclass), 0) as size_bytes
33            FROM information_schema.tables t
34            LEFT JOIN pg_stat_user_tables s ON t.table_name = s.relname
35            WHERE t.table_schema = 'public'
36            ORDER BY t.table_name
37            ",
38        )
39        .fetch_all(&*self.pool)
40        .await?;
41
42        let tables = rows
43            .iter()
44            .map(|row| {
45                let name: String = row.get("name");
46                let row_count: i64 = row.get("row_count");
47                let size_bytes: i64 = row.get("size_bytes");
48                TableInfo {
49                    name,
50                    row_count,
51                    size_bytes,
52                    columns: vec![],
53                }
54            })
55            .collect();
56
57        Ok(tables)
58    }
59
60    pub async fn describe_table(
61        &self,
62        table_name: &SafeIdentifier,
63    ) -> DatabaseResult<(Vec<ColumnInfo>, i64)> {
64        let rows = sqlx::query(
65            "SELECT column_name, data_type, is_nullable, column_default FROM \
66             information_schema.columns WHERE table_name = $1 ORDER BY ordinal_position",
67        )
68        .bind(table_name.as_str())
69        .fetch_all(&*self.pool)
70        .await?;
71
72        if rows.is_empty() {
73            return Err(RepositoryError::not_found(format!(
74                "Table '{table_name}' not found"
75            )));
76        }
77
78        let pk_rows = sqlx::query(
79            r"
80            SELECT a.attname as column_name
81            FROM pg_index i
82            JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
83            WHERE i.indrelid = $1::regclass AND i.indisprimary
84            ",
85        )
86        .bind(table_name.as_str())
87        .fetch_all(&*self.pool)
88        .await?;
89
90        let pk_columns: Vec<String> = pk_rows
91            .iter()
92            .map(|row| row.get::<String, _>("column_name"))
93            .collect();
94
95        let columns = rows
96            .iter()
97            .map(|row| {
98                let name: String = row.get("column_name");
99                let data_type: String = row.get("data_type");
100                let nullable_str: String = row.get("is_nullable");
101                let nullable = nullable_str.to_uppercase() == "YES";
102                let default: Option<String> = row.get("column_default");
103                let primary_key = pk_columns.contains(&name);
104
105                ColumnInfo {
106                    name,
107                    data_type,
108                    nullable,
109                    primary_key,
110                    default,
111                }
112            })
113            .collect();
114
115        let row_count = self.count_rows(table_name).await?;
116
117        Ok((columns, row_count))
118    }
119
120    pub async fn get_table_indexes(
121        &self,
122        table_name: &SafeIdentifier,
123    ) -> DatabaseResult<Vec<IndexInfo>> {
124        let rows = sqlx::query(
125            r"
126            SELECT
127                i.relname as index_name,
128                ix.indisunique as is_unique,
129                array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns
130            FROM pg_class t
131            JOIN pg_index ix ON t.oid = ix.indrelid
132            JOIN pg_class i ON i.oid = ix.indexrelid
133            JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
134            WHERE t.relname = $1 AND t.relkind = 'r'
135            GROUP BY i.relname, ix.indisunique
136            ORDER BY i.relname
137            ",
138        )
139        .bind(table_name.as_str())
140        .fetch_all(&*self.pool)
141        .await?;
142
143        let indexes = rows
144            .iter()
145            .map(|row| {
146                let name: String = row.get("index_name");
147                let unique: bool = row.get("is_unique");
148                let columns: Vec<String> = row.get("columns");
149                IndexInfo {
150                    name,
151                    columns,
152                    unique,
153                }
154            })
155            .collect();
156
157        Ok(indexes)
158    }
159
160    pub async fn count_rows(&self, table_name: &SafeIdentifier) -> DatabaseResult<i64> {
161        let quoted_table = quote_identifier(table_name.as_str());
162        let count_query = format!("SELECT COUNT(*) as count FROM {quoted_table}");
163        let row_count: i64 = sqlx::query_scalar(&count_query)
164            .fetch_one(&*self.pool)
165            .await?;
166
167        Ok(row_count)
168    }
169
170    pub async fn get_database_info(&self) -> DatabaseResult<DatabaseInfo> {
171        let version: String = sqlx::query_scalar("SELECT version()")
172            .fetch_one(&*self.pool)
173            .await?;
174
175        let size: i64 = sqlx::query_scalar("SELECT pg_database_size(current_database())")
176            .fetch_one(&*self.pool)
177            .await?;
178
179        let size = u64::try_from(size).map_err(|_| {
180            RepositoryError::internal(format!("pg_database_size returned negative value: {size}"))
181        })?;
182
183        let tables = self.list_tables().await?;
184
185        Ok(DatabaseInfo {
186            path: "PostgreSQL".to_string(),
187            size,
188            version,
189            tables,
190        })
191    }
192
193    pub fn get_expected_tables() -> Vec<&'static str> {
194        vec![
195            "users",
196            "user_sessions",
197            "user_contexts",
198            "agent_tasks",
199            "agent_skills",
200            "task_messages",
201            "task_artifacts",
202            "task_execution_steps",
203            "artifact_parts",
204            "message_parts",
205            "ai_requests",
206            "ai_request_messages",
207            "ai_request_tool_calls",
208            "mcp_tool_executions",
209            "logs",
210            "analytics_events",
211            "oauth_clients",
212            "oauth_auth_codes",
213            "oauth_refresh_tokens",
214            "scheduled_jobs",
215            "services",
216            "markdown_content",
217            "markdown_categories",
218            "files",
219            "content_files",
220        ]
221    }
222}
223
/// Wraps `identifier` in double quotes so it can be spliced into SQL as a
/// PostgreSQL quoted identifier.
///
/// Any embedded `"` is doubled, which is PostgreSQL's escape for quotes inside
/// a quoted identifier. The result is therefore always exactly one identifier
/// token.
fn quote_identifier(identifier: &str) -> String {
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            // Double the quote so it is read as a literal character.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}