synaptic_postgres/
cache.rs1use async_trait::async_trait;
2use sqlx::PgPool;
3use synaptic_core::{validate_table_name, ChatResponse, SynapticError};
4
/// Configuration for a Postgres-backed LLM response cache.
///
/// Build one with `PgCacheConfig::new` and optionally `with_ttl`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PgCacheConfig {
    /// Name of the table holding cache entries. It may be schema-qualified
    /// (e.g. `public.llm_cache`). It is interpolated into SQL, so the cache
    /// checks it with `validate_table_name` before each statement.
    pub table_name: String,
    /// Optional time-to-live in seconds. When set, lookups ignore rows written
    /// `ttl` or more seconds ago. `None` means entries never expire.
    pub ttl: Option<u64>,
}
14
15impl PgCacheConfig {
16 pub fn new(table_name: impl Into<String>) -> Self {
18 Self {
19 table_name: table_name.into(),
20 ttl: None,
21 }
22 }
23
24 pub fn with_ttl(mut self, seconds: u64) -> Self {
26 self.ttl = Some(seconds);
27 self
28 }
29}
30
31pub struct PgCache {
56 pool: PgPool,
57 config: PgCacheConfig,
58}
59
60impl PgCache {
61 pub fn new(pool: PgPool, config: PgCacheConfig) -> Self {
63 Self { pool, config }
64 }
65
66 pub async fn initialize(&self) -> Result<(), SynapticError> {
70 validate_table_name(&self.config.table_name)?;
71
72 let create_table = format!(
73 r#"CREATE TABLE IF NOT EXISTS {table} (
74 key TEXT PRIMARY KEY,
75 value TEXT NOT NULL,
76 created_at BIGINT NOT NULL DEFAULT (EXTRACT(EPOCH FROM now())::BIGINT)
77 )"#,
78 table = self.config.table_name,
79 );
80
81 sqlx::query(&create_table)
82 .execute(&self.pool)
83 .await
84 .map_err(|e| SynapticError::Cache(format!("failed to create table: {e}")))?;
85
86 Ok(())
87 }
88
89 pub fn pool(&self) -> &PgPool {
91 &self.pool
92 }
93
94 pub fn config(&self) -> &PgCacheConfig {
96 &self.config
97 }
98}
99
100#[async_trait]
101impl synaptic_core::LlmCache for PgCache {
102 async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError> {
103 validate_table_name(&self.config.table_name)?;
104
105 let json_str: Option<String> = if let Some(ttl) = self.config.ttl {
106 let sql = format!(
107 "SELECT value FROM {table} WHERE key = $1 AND created_at + $2 > EXTRACT(EPOCH FROM now())::BIGINT",
108 table = self.config.table_name,
109 );
110 sqlx::query_scalar(&sql)
111 .bind(key)
112 .bind(ttl as i64)
113 .fetch_optional(&self.pool)
114 .await
115 .map_err(|e| SynapticError::Cache(format!("query error: {e}")))?
116 } else {
117 let sql = format!(
118 "SELECT value FROM {table} WHERE key = $1",
119 table = self.config.table_name,
120 );
121 sqlx::query_scalar(&sql)
122 .bind(key)
123 .fetch_optional(&self.pool)
124 .await
125 .map_err(|e| SynapticError::Cache(format!("query error: {e}")))?
126 };
127
128 match json_str {
129 Some(s) => {
130 let response: ChatResponse = serde_json::from_str(&s)
131 .map_err(|e| SynapticError::Cache(format!("JSON deserialize error: {e}")))?;
132 Ok(Some(response))
133 }
134 None => Ok(None),
135 }
136 }
137
138 async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError> {
139 validate_table_name(&self.config.table_name)?;
140
141 let value = serde_json::to_string(response)
142 .map_err(|e| SynapticError::Cache(format!("JSON serialize error: {e}")))?;
143
144 let sql = format!(
145 r#"INSERT INTO {table} (key, value, created_at)
146 VALUES ($1, $2, EXTRACT(EPOCH FROM now())::BIGINT)
147 ON CONFLICT (key) DO UPDATE
148 SET value = EXCLUDED.value,
149 created_at = EXCLUDED.created_at"#,
150 table = self.config.table_name,
151 );
152
153 sqlx::query(&sql)
154 .bind(key)
155 .bind(&value)
156 .execute(&self.pool)
157 .await
158 .map_err(|e| SynapticError::Cache(format!("insert error: {e}")))?;
159
160 Ok(())
161 }
162
163 async fn clear(&self) -> Result<(), SynapticError> {
164 validate_table_name(&self.config.table_name)?;
165
166 let sql = format!("DELETE FROM {table}", table = self.config.table_name);
167
168 sqlx::query(&sql)
169 .execute(&self.pool)
170 .await
171 .map_err(|e| SynapticError::Cache(format!("delete error: {e}")))?;
172
173 Ok(())
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    // Builder defaults: a fresh config has no TTL.
    #[test]
    fn config_construction() {
        let cfg = PgCacheConfig::new("my_cache");
        assert_eq!(cfg.table_name, "my_cache");
        assert_eq!(cfg.ttl, None);
    }

    // `with_ttl` sets the TTL and keeps the table name.
    #[test]
    fn config_with_ttl() {
        let cfg = PgCacheConfig::new("my_cache").with_ttl(3600);
        assert_eq!(cfg.table_name, "my_cache");
        assert_eq!(cfg.ttl, Some(3600));
    }

    // Plain and schema-qualified identifiers must pass validation.
    #[test]
    fn validate_table_name_accepts_valid_names() {
        let accepted = ["llm_cache", "my_cache", "public.llm_cache", "schema1.cache2"];
        for name in accepted {
            assert!(validate_table_name(name).is_ok(), "should accept {name:?}");
        }
    }

    // Statement separators, comments, quotes and the empty string must be rejected.
    #[test]
    fn validate_table_name_rejects_sql_injection() {
        let rejected = ["cache; DROP TABLE users", "cache--comment", "cache'malicious", ""];
        for name in rejected {
            assert!(validate_table_name(name).is_err(), "should reject {name:?}");
        }
    }
}