uri_register/
postgres.rs

// Copyright TELICENT LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use crate::cache::{create_cache, Cache, CacheStrategy};
16use crate::error::{ConfigurationError, Result};
17use crate::service::UriService;
18use async_trait::async_trait;
19use deadpool_postgres::{ManagerConfig, Pool, RecyclingMethod, Runtime};
20use rustls::RootCertStore;
21use std::sync::Arc;
22use tokio_postgres::{Config, NoTls};
23use tokio_postgres_rustls::MakeRustlsConnect;
24use tracing::{debug, info, instrument, trace};
25use url::Url;
26
27/// PostgreSQL-based URI register implementation with configurable caching
28///
29/// This implementation uses a PostgreSQL table to store URI-to-ID mappings
30/// with an in-memory cache (W-TinyLFU by default, or LRU) to reduce database round-trips.
31/// It's designed for high concurrency with connection pooling and batch operations.
32///
33/// ## Prerequisites
34///
35/// The database schema must be initialized before using this service.
36/// See `schema.sql` for the DDL statements.
37///
38/// ## URI Validation
39///
40/// All URIs are validated before registration to ensure they conform to RFC 3986.
41/// Invalid URIs will return an error.
42///
43/// ## Performance
44///
45/// With default logged tables on typical hardware:
46/// - Batch insert: ~10K-50K URIs/sec
47/// - Batch lookup (cached): ~100K-1M+ URIs/sec (no DB round-trip)
48/// - Batch lookup (uncached): ~100K-200K URIs/sec
49/// - Query overhead: ~2-10ms per query (2 round-trips)
50///
51/// The cache (W-TinyLFU or LRU) significantly improves performance for repeated URI lookups.
52/// Cache strategy and size are configurable when creating the register instance.
53///
54/// For faster writes at the cost of durability, the table can be configured
55/// as UNLOGGED (see `schema.sql` for options).
56pub struct PostgresUriRegister {
57    pool: Pool,
58    /// Cache for URI-to-ID mappings (W-TinyLFU or LRU)
59    cache: Arc<dyn Cache>,
60    /// Name of the database table to use
61    table_name: String,
62}
63
64impl PostgresUriRegister {
65    /// Create a new PostgreSQL URI register service with configurable cache
66    ///
67    /// # Arguments
68    ///
69    /// * `database_url` - PostgreSQL connection string (e.g., "postgres://user:password@host:port/database")
70    /// * `table_name` - Name of the database table to use (must be a valid SQL identifier, default: "uri_register")
71    /// * `max_connections` - Maximum number of connections in the pool (recommended: 10-50)
72    /// * `cache_size` - Number of URI-to-ID mappings to cache in memory (recommended: 1,000-100,000)
73    /// * `cache_strategy` - Cache strategy to use (Moka/W-TinyLFU is default and recommended for most workloads)
74    ///
75    /// # Prerequisites
76    ///
77    /// The database schema must be initialized before using this service.
78    /// See the `schema.sql` file and README.md for setup instructions.
79    ///
80    /// # Example
81    ///
82    /// ```rust,no_run
83    /// use uri_register::PostgresUriRegister;
84    ///
85    /// #[tokio::main]
86    /// async fn main() -> uri_register::Result<()> {
87    ///     let register = PostgresUriRegister::new(
88    ///         "postgres://localhost/mydb",
89    ///         "uri_register",  // table name
90    ///         20,              // max connections
91    ///         10_000           // cache size (defaults to Moka/W-TinyLFU)
92    ///     ).await?;
93    ///     Ok(())
94    /// }
95    /// ```
96    pub async fn new(
97        database_url: &str,
98        table_name: &str,
99        max_connections: u32,
100        cache_size: usize,
101    ) -> Result<Self> {
102        Self::new_with_cache_strategy(
103            database_url,
104            table_name,
105            max_connections,
106            cache_size,
107            None, // Default to Moka
108            None, // Default to no TLS
109        )
110        .await
111    }
112
113    /// Create a new PostgreSQL URI register with a specific cache strategy and TLS
114    ///
115    /// This is identical to `new()` but allows specifying a cache strategy and TLS option.
116    /// Most users should use `new()` which defaults to the recommended Moka (W-TinyLFU) cache and no TLS.
117    ///
118    /// # Arguments
119    ///
120    /// * `cache_strategy` - Optional cache strategy (None = Moka default, or specify CacheStrategy::Lru)
121    /// * `use_tls` - Optional TLS flag (None/false = no TLS, true = TLS with webpki root certificates)
122    ///
123    /// # Example
124    ///
125    /// ```rust,no_run
126    /// use uri_register::{CacheStrategy, PostgresUriRegister};
127    ///
128    /// #[tokio::main]
129    /// async fn main() -> uri_register::Result<()> {
130    ///     // Use LRU instead of default Moka, with TLS enabled
131    ///     let register = PostgresUriRegister::new_with_cache_strategy(
132    ///         "postgres://localhost/mydb",
133    ///         "uri_register",
134    ///         20,
135    ///         10_000,
136    ///         Some(CacheStrategy::Lru),
137    ///         Some(true)  // Enable TLS
138    ///     ).await?;
139    ///     Ok(())
140    /// }
141    /// ```
142    pub async fn new_with_cache_strategy(
143        database_url: &str,
144        table_name: &str,
145        max_connections: u32,
146        cache_size: usize,
147        cache_strategy: Option<CacheStrategy>,
148        use_tls: Option<bool>,
149    ) -> Result<Self> {
150        // Validate inputs
151        if cache_size == 0 {
152            return Err(ConfigurationError::InvalidCacheSize(cache_size).into());
153        }
154
155        if max_connections == 0 {
156            return Err(ConfigurationError::InvalidMaxConnections(max_connections).into());
157        }
158
159        // Validate table name as SQL identifier
160        Self::validate_table_name(table_name)?;
161
162        // Parse the database URL into tokio-postgres Config
163        let pg_config: Config = database_url.parse().map_err(|e| {
164            ConfigurationError::InvalidBackoff(format!("Failed to parse database URL: {}", e))
165        })?;
166
167        // Create deadpool configuration
168        let mut cfg = deadpool_postgres::Config::new();
169        cfg.dbname = pg_config.get_dbname().map(|s| s.to_string());
170        cfg.host = pg_config.get_hosts().first().map(|h| match h {
171            tokio_postgres::config::Host::Tcp(s) => s.to_string(),
172            #[cfg(unix)]
173            tokio_postgres::config::Host::Unix(p) => p.to_str().unwrap_or_default().to_string(),
174        });
175        cfg.port = pg_config.get_ports().first().copied();
176        cfg.user = pg_config.get_user().map(|s| s.to_string());
177        cfg.password = pg_config
178            .get_password()
179            .map(|p| std::str::from_utf8(p).unwrap_or_default().to_string());
180        cfg.manager = Some(ManagerConfig {
181            recycling_method: RecyclingMethod::Fast,
182        });
183        cfg.pool = Some(deadpool_postgres::PoolConfig {
184            max_size: max_connections as usize,
185            timeouts: deadpool_postgres::Timeouts {
186                wait: Some(std::time::Duration::from_secs(10)),
187                create: Some(std::time::Duration::from_secs(10)),
188                recycle: Some(std::time::Duration::from_secs(10)),
189            },
190            ..Default::default()
191        });
192
193        // Create the pool with or without TLS based on use_tls parameter
194        let pool = if use_tls.unwrap_or(false) {
195            // Configure TLS with webpki root certificates
196            let mut root_store = RootCertStore::empty();
197            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
198
199            let tls_config = rustls::ClientConfig::builder()
200                .with_root_certificates(root_store)
201                .with_no_client_auth();
202
203            let tls = MakeRustlsConnect::new(tls_config);
204
205            cfg.create_pool(Some(Runtime::Tokio1), tls).map_err(|e| {
206                ConfigurationError::InvalidBackoff(format!(
207                    "Failed to create connection pool with TLS: {}",
208                    e
209                ))
210            })?
211        } else {
212            // No TLS
213            cfg.create_pool(Some(Runtime::Tokio1), NoTls).map_err(|e| {
214                ConfigurationError::InvalidBackoff(format!(
215                    "Failed to create connection pool: {}",
216                    e
217                ))
218            })?
219        };
220
221        let cache = create_cache(cache_strategy.unwrap_or_default(), cache_size);
222
223        info!(
224            table = table_name,
225            max_connections,
226            cache_size,
227            tls = use_tls.unwrap_or(false),
228            "URI register connected"
229        );
230
231        Ok(Self {
232            pool,
233            cache,
234            table_name: table_name.to_string(),
235        })
236    }
237
238    /// Validate that a table name is a valid SQL identifier
239    ///
240    /// Prevents SQL injection by ensuring the table name only contains
241    /// alphanumeric characters and underscores, and doesn't start with a digit.
242    fn validate_table_name(name: &str) -> Result<()> {
243        if name.is_empty() {
244            return Err(ConfigurationError::InvalidTableName(
245                "table name cannot be empty".to_string(),
246            )
247            .into());
248        }
249
250        if name.len() > 63 {
251            return Err(ConfigurationError::InvalidTableName(format!(
252                "table name too long (max 63 characters): '{}'",
253                name
254            ))
255            .into());
256        }
257
258        // First character must be a letter or underscore
259        let first_char = name.chars().next().unwrap();
260        if !first_char.is_ascii_alphabetic() && first_char != '_' {
261            return Err(ConfigurationError::InvalidTableName(format!(
262                "table name must start with a letter or underscore: '{}'",
263                name
264            ))
265            .into());
266        }
267
268        // All characters must be alphanumeric or underscore
269        if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
270            return Err(ConfigurationError::InvalidTableName(format!(
271                "table name can only contain letters, numbers, and underscores: '{}'",
272                name
273            ))
274            .into());
275        }
276
277        Ok(())
278    }
279
280    /// Get statistics about the URI register
281    ///
282    /// Returns the total number of URIs and the storage size.
283    ///
284    /// # Example
285    ///
286    /// ```rust,no_run
287    /// use uri_register::PostgresUriRegister;
288    ///
289    /// #[tokio::main]
290    /// async fn main() -> uri_register::Result<()> {
291    ///     let register = PostgresUriRegister::new(
292    ///         "postgres://localhost/mydb",
293    ///         "uri_register",
294    ///         20,
295    ///         10_000
296    ///     ).await?;
297    ///     let stats = register.stats().await?;
298    ///     println!("Total URIs: {}", stats.total_uris);
299    ///     println!("Size: {} bytes", stats.size_bytes);
300    ///     Ok(())
301    /// }
302    /// ```
303    pub async fn stats(&self) -> Result<RegisterStats> {
304        // Build query with validated table name (safe from SQL injection)
305        let query = format!(
306            r#"
307            SELECT
308                COUNT(*)::bigint as count,
309                pg_total_relation_size('{}')::bigint as size_bytes
310            FROM {}
311            "#,
312            self.table_name, self.table_name
313        );
314
315        // Execute with retry logic
316        let client = self.pool.get().await.map_err(|e| {
317            crate::error::Error::Database(format!("Failed to get database connection: {}", e))
318        })?;
319
320        let rows = client
321            .query(&query, &[])
322            .await
323            .map_err(|e| crate::error::Error::Database(e.to_string()))?;
324
325        let row = rows.into_iter().next().ok_or_else(|| {
326            crate::error::Error::Database("No rows returned from stats query".to_string())
327        })?;
328
329        // Get cache statistics
330        let cache_stats = self.cache.stats();
331
332        // Get connection pool statistics
333        let status = self.pool.status();
334        let pool_stats = PoolStats {
335            connections_active: (status.size - status.available) as u32,
336            connections_idle: status.available as u32,
337            connections_max: status.max_size as u32,
338        };
339
340        Ok(RegisterStats {
341            total_uris: row.get::<_, i64>("count") as u64,
342            size_bytes: row.get::<_, i64>("size_bytes") as u64,
343            cache: cache_stats,
344            pool: pool_stats,
345        })
346    }
347
348    /// Clone the register instance (shares pool and cache)
349    ///
350    /// This is a shallow clone that shares both the connection pool and cache.
351    /// Both pool and cache use Arc internally, so this clone is cheap and shares
352    /// the underlying resources.
353    ///
354    /// This method is primarily used for Python bindings where we need to move
355    /// data into async closures.
356    #[cfg(feature = "python")]
357    pub(crate) fn clone_inner(&self) -> Self {
358        PostgresUriRegister {
359            pool: self.pool.clone(),
360            cache: self.cache.clone(), // Clone the Arc, shares the same cache
361            table_name: self.table_name.clone(),
362        }
363    }
364
365    /// Validate that a string is a valid URI according to RFC 3986
366    fn validate_uri(uri: &str) -> Result<()> {
367        Url::parse(uri).map_err(|e| {
368            crate::error::Error::InvalidUri(format!("Invalid URI '{}': {}", uri, e))
369        })?;
370        Ok(())
371    }
372}
373
374#[async_trait]
375impl UriService for PostgresUriRegister {
376    #[instrument(skip(self), fields(table = %self.table_name))]
377    async fn register_uri(&self, uri: &str) -> Result<u64> {
378        // Validate URI first
379        Self::validate_uri(uri)?;
380
381        // Check cache first
382        if let Some(id) = self.cache.get(uri) {
383            trace!(id, "cache hit");
384            return Ok(id);
385        }
386        trace!("cache miss, querying database");
387
388        // Insert and return ID (ON CONFLICT handles race conditions and existing URIs)
389        // Build query with validated table name (safe from SQL injection)
390        let query = format!(
391            r#"
392            INSERT INTO {} (uri)
393            VALUES ($1)
394            ON CONFLICT (uri_hash) DO UPDATE SET uri = EXCLUDED.uri
395            RETURNING id
396            "#,
397            self.table_name
398        );
399
400        // Execute with retry logic
401        let client = self.pool.get().await.map_err(|e| {
402            crate::error::Error::Database(format!("Failed to get database connection: {}", e))
403        })?;
404
405        let rows = client
406            .query(&query, &[&uri])
407            .await
408            .map_err(|e| crate::error::Error::Database(e.to_string()))?;
409
410        let result = rows.into_iter().next().ok_or_else(|| {
411            crate::error::Error::Database("No rows returned from register_uri query".to_string())
412        })?;
413
414        let id = result.get::<_, i64>("id") as u64;
415
416        // Update cache
417        self.cache.put(uri.to_string(), id);
418
419        Ok(id)
420    }
421
422    #[instrument(skip(self, uris), fields(table = %self.table_name, batch_size = uris.len()))]
423    async fn register_uri_batch(&self, uris: &[String]) -> Result<Vec<u64>> {
424        if uris.is_empty() {
425            trace!("empty batch, returning early");
426            return Ok(Vec::new());
427        }
428
429        // Validate all URIs first
430        for uri in uris {
431            Self::validate_uri(uri)?;
432        }
433
434        // CORRECTNESS GUARANTEE: Order preservation
435        // We maintain strict correspondence between input URIs and output IDs
436        // by tracking the original index of each URI and using URI strings
437        // (not SQL result order) to map IDs back to their positions.
438
439        let mut result_ids = vec![None; uris.len()];
440        let mut uncached_indices = Vec::new();
441        let mut uncached_uris_dedup = Vec::new();
442        let mut seen_uncached = std::collections::HashMap::new();
443
444        // Step 1: Check cache for all URIs
445        for (idx, uri) in uris.iter().enumerate() {
446            if let Some(id) = self.cache.get(uri) {
447                result_ids[idx] = Some(id);
448            } else {
449                uncached_indices.push(idx);
450                // Deduplicate uncached URIs for DB query
451                if !seen_uncached.contains_key(uri) {
452                    seen_uncached.insert(uri.clone(), uncached_uris_dedup.len());
453                    uncached_uris_dedup.push(uri.clone());
454                }
455            }
456        }
457
458        // If everything was cached, return early
459        if uncached_uris_dedup.is_empty() {
460            debug!(cached = uris.len(), "all URIs found in cache");
461            return Ok(result_ids.into_iter().map(|id| id.unwrap()).collect());
462        }
463
464        let cached_count = uris.len() - uncached_indices.len();
465        debug!(
466            cached = cached_count,
467            uncached = uncached_uris_dedup.len(),
468            "cache lookup complete, querying database"
469        );
470
471        // Step 2: Register deduplicated uncached URIs in batch
472        // IMPORTANT: SQL may return results in ANY order (not guaranteed to match input order)
473        // We use "RETURNING id, uri" to get BOTH values together, then map by URI string
474        // Build query with validated table name (safe from SQL injection)
475        let query = format!(
476            r#"
477            INSERT INTO {} (uri)
478            SELECT unnest($1::text[])
479            ON CONFLICT (uri_hash) DO UPDATE SET uri = EXCLUDED.uri
480            RETURNING id, uri
481            "#,
482            self.table_name
483        );
484
485        // Execute with retry logic
486        let client = self.pool.get().await.map_err(|e| {
487            crate::error::Error::Database(format!("Failed to get database connection: {}", e))
488        })?;
489
490        let rows = client
491            .query(&query, &[&uncached_uris_dedup])
492            .await
493            .map_err(|e| crate::error::Error::Database(e.to_string()))?;
494
495        // Build a map of URI -> ID from database results
496        // This allows us to look up IDs by URI string (order-independent)
497        let mut uri_to_id = std::collections::HashMap::new();
498        for row in rows {
499            let uri: String = row.get("uri");
500            let id: i64 = row.get("id");
501            uri_to_id.insert(uri, id as u64);
502        }
503
504        // Step 3: Fill in the result vector and update cache
505        // CORRECTNESS: We use the saved indices and look up by URI string,
506        // guaranteeing that result_ids[i] corresponds to uris[i]
507        for idx in uncached_indices {
508            let uri = &uris[idx]; // Get URI from original position
509            if let Some(&id) = uri_to_id.get(uri) {
510                // Look up ID by URI string
511                result_ids[idx] = Some(id); // Store at original index
512                self.cache.put(uri.clone(), id);
513            }
514        }
515
516        // Convert Option<u64> to u64 (all should be Some at this point)
517        Ok(result_ids
518            .into_iter()
519            .map(|id| id.expect("All URIs should have IDs"))
520            .collect())
521    }
522
523    #[instrument(skip(self, uris), fields(table = %self.table_name, batch_size = uris.len()))]
524    async fn register_uri_batch_hashmap(
525        &self,
526        uris: &[String],
527    ) -> Result<std::collections::HashMap<String, u64>> {
528        if uris.is_empty() {
529            trace!("empty batch, returning early");
530            return Ok(std::collections::HashMap::new());
531        }
532
533        // Validate all URIs first
534        for uri in uris {
535            Self::validate_uri(uri)?;
536        }
537
538        // CORRECTNESS GUARANTEE: URI-to-ID mapping accuracy
539        // Each URI in the result HashMap is guaranteed to map to its correct ID
540        // because SQL returns both 'id' and 'uri' together in each row (RETURNING id, uri).
541        // We never rely on positional correspondence, eliminating ordering errors.
542
543        let mut result = std::collections::HashMap::new();
544        let mut uncached_uris = Vec::new();
545
546        // Step 1: Deduplicate input and check cache
547        let unique_uris: std::collections::HashSet<_> = uris.iter().collect();
548
549        for uri in unique_uris {
550            if let Some(id) = self.cache.get(uri) {
551                result.insert(uri.clone(), id);
552            } else {
553                uncached_uris.push(uri.clone());
554            }
555        }
556
557        // If everything was cached, return early
558        if uncached_uris.is_empty() {
559            debug!(cached = result.len(), "all URIs found in cache");
560            return Ok(result);
561        }
562
563        debug!(
564            cached = result.len(),
565            uncached = uncached_uris.len(),
566            "cache lookup complete, querying database"
567        );
568
569        // Step 2: Register uncached URIs in batch
570        // Build query with validated table name (safe from SQL injection)
571        let query = format!(
572            r#"
573            INSERT INTO {} (uri)
574            SELECT unnest($1::text[])
575            ON CONFLICT (uri_hash) DO UPDATE SET uri = EXCLUDED.uri
576            RETURNING id, uri
577            "#,
578            self.table_name
579        );
580
581        // Execute with retry logic
582        let client = self.pool.get().await.map_err(|e| {
583            crate::error::Error::Database(format!("Failed to get database connection: {}", e))
584        })?;
585
586        let rows = client
587            .query(&query, &[&uncached_uris])
588            .await
589            .map_err(|e| crate::error::Error::Database(e.to_string()))?;
590
591        // Step 3: Add database results to result map and update cache
592        // CORRECTNESS: Each row contains both URI and ID from same DB row,
593        // guaranteeing correct mapping (no opportunity for misalignment)
594        for row in rows {
595            let uri: String = row.get("uri");
596            let id: i64 = row.get("id");
597            let id_u64 = id as u64;
598
599            result.insert(uri.clone(), id_u64); // URI and ID are from same row
600            self.cache.put(uri, id_u64);
601        }
602
603        Ok(result)
604    }
605}
606
607/// Statistics about the URI register for observability and OpenTelemetry
608#[derive(Debug, Clone)]
609pub struct RegisterStats {
610    /// Total number of URIs in the register
611    pub total_uris: u64,
612    /// Total storage size in bytes (includes indexes)
613    pub size_bytes: u64,
614    /// Cache performance metrics
615    pub cache: crate::cache::CacheStats,
616    /// Connection pool metrics
617    pub pool: PoolStats,
618}
619
/// Connection pool statistics for observability
///
/// A small plain-data snapshot. It is `Copy` and comparable so callers can keep,
/// diff, or assert on successive readings without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStats {
    /// Number of connections currently checked out of the pool
    pub connections_active: u32,
    /// Number of open connections sitting idle in the pool
    pub connections_idle: u32,
    /// Maximum number of connections the pool is allowed to hold
    pub connections_max: u32,
}