systemprompt_users/repository/banned_ip/
queries.rs1use anyhow::Result;
2
3use super::BannedIpRepository;
4use super::types::{BanDuration, BanIpParams, BanIpWithMetadataParams, BannedIp};
5
6impl BannedIpRepository {
7 pub async fn is_banned(&self, ip_address: &str) -> Result<bool> {
8 let result = sqlx::query_scalar!(
9 r#"
10 SELECT EXISTS(
11 SELECT 1 FROM banned_ips
12 WHERE ip_address = $1
13 AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)
14 ) as "exists!"
15 "#,
16 ip_address
17 )
18 .fetch_one(&*self.pool)
19 .await?;
20
21 Ok(result)
22 }
23
24 pub async fn find_ban(&self, ip_address: &str) -> Result<Option<BannedIp>> {
25 let row = sqlx::query_as!(
26 BannedIp,
27 r#"
28 SELECT
29 ip_address,
30 reason,
31 banned_at,
32 expires_at,
33 ban_count,
34 last_offense_path,
35 last_user_agent,
36 is_permanent,
37 source_fingerprint,
38 ban_source,
39 associated_session_ids
40 FROM banned_ips
41 WHERE ip_address = $1
42 AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)
43 "#,
44 ip_address
45 )
46 .fetch_optional(&*self.pool)
47 .await?;
48
49 Ok(row)
50 }
51
52 pub async fn ban_ip(&self, params: BanIpParams<'_>) -> Result<()> {
53 let expires_at = params.duration.to_expiry();
54 let is_permanent = matches!(params.duration, BanDuration::Permanent);
55
56 sqlx::query!(
57 r#"
58 INSERT INTO banned_ips (
59 ip_address, reason, expires_at, is_permanent,
60 source_fingerprint, ban_source
61 )
62 VALUES ($1, $2, $3, $4, $5, $6)
63 ON CONFLICT (ip_address) DO UPDATE SET
64 reason = $2,
65 expires_at = CASE
66 WHEN banned_ips.is_permanent THEN banned_ips.expires_at
67 ELSE COALESCE($3, banned_ips.expires_at)
68 END,
69 ban_count = banned_ips.ban_count + 1,
70 is_permanent = banned_ips.is_permanent OR $4,
71 source_fingerprint = COALESCE($5, banned_ips.source_fingerprint),
72 ban_source = $6
73 "#,
74 params.ip_address,
75 params.reason,
76 expires_at,
77 is_permanent,
78 params.source_fingerprint,
79 params.ban_source
80 )
81 .execute(&*self.write_pool)
82 .await?;
83
84 Ok(())
85 }
86
87 pub async fn ban_ip_with_metadata(&self, params: BanIpWithMetadataParams<'_>) -> Result<()> {
88 let expires_at = params.duration.to_expiry();
89 let is_permanent = matches!(params.duration, BanDuration::Permanent);
90 let session_ids: Option<Vec<String>> = params.session_id.map(|s| vec![s.to_string()]);
91
92 sqlx::query!(
93 r#"
94 INSERT INTO banned_ips (
95 ip_address, reason, expires_at, is_permanent,
96 source_fingerprint, ban_source, last_offense_path,
97 last_user_agent, associated_session_ids
98 )
99 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
100 ON CONFLICT (ip_address) DO UPDATE SET
101 reason = $2,
102 expires_at = CASE
103 WHEN banned_ips.is_permanent THEN banned_ips.expires_at
104 ELSE COALESCE($3, banned_ips.expires_at)
105 END,
106 ban_count = banned_ips.ban_count + 1,
107 is_permanent = banned_ips.is_permanent OR $4,
108 source_fingerprint = COALESCE($5, banned_ips.source_fingerprint),
109 ban_source = $6,
110 last_offense_path = COALESCE($7, banned_ips.last_offense_path),
111 last_user_agent = COALESCE($8, banned_ips.last_user_agent),
112 associated_session_ids = CASE
113 WHEN $9::TEXT[] IS NOT NULL
114 THEN array_cat(COALESCE(banned_ips.associated_session_ids, '{}'::TEXT[]), $9)
115 ELSE banned_ips.associated_session_ids
116 END
117 "#,
118 params.ip_address,
119 params.reason,
120 expires_at,
121 is_permanent,
122 params.source_fingerprint,
123 params.ban_source,
124 params.offense_path,
125 params.user_agent,
126 session_ids.as_deref()
127 )
128 .execute(&*self.write_pool)
129 .await?;
130
131 Ok(())
132 }
133
134 pub async fn unban_ip(&self, ip_address: &str) -> Result<bool> {
135 let result = sqlx::query!(
136 r#"
137 DELETE FROM banned_ips
138 WHERE ip_address = $1
139 "#,
140 ip_address
141 )
142 .execute(&*self.write_pool)
143 .await?;
144
145 Ok(result.rows_affected() > 0)
146 }
147
148 pub async fn cleanup_expired(&self) -> Result<u64> {
149 let result = sqlx::query!(
150 r#"
151 DELETE FROM banned_ips
152 WHERE expires_at IS NOT NULL
153 AND expires_at < CURRENT_TIMESTAMP
154 AND NOT is_permanent
155 "#
156 )
157 .execute(&*self.write_pool)
158 .await?;
159
160 Ok(result.rows_affected())
161 }
162}