soar_core/database/packages/
query.rs

1use std::sync::{Arc, Mutex};
2
3use rusqlite::{Connection, ToSql};
4
5use crate::{
6    database::models::{FromRow, InstalledPackage},
7    error::SoarError,
8    SoarResult,
9};
10
11use super::{FilterCondition, LogicalOp, PaginatedResponse, QueryFilter, SortDirection};
12
/// Fluent builder for package queries executed over a shared SQLite connection.
///
/// The builder accumulates filters, sort order, pagination settings and an
/// optional column projection; the `load*` methods render them into SQL.
13#[derive(Debug, Clone)]
14pub struct PackageQueryBuilder {
    /// Shared connection handle; the `load*` methods lock it before querying.
15    db: Arc<Mutex<Connection>>,
    /// Filter predicates, rendered into the `WHERE` clause in insertion order.
16    filters: Vec<QueryFilter>,
    /// `(field, direction)` pairs rendered into `ORDER BY` in insertion order.
17    sort_fields: Vec<(String, SortDirection)>,
    /// Page size; when `None`, no `LIMIT`/`OFFSET` is appended.
18    limit: Option<u32>,
    /// Schema names to query; when `None`, `get_shards` reads them from
    /// `PRAGMA database_list`.
19    shards: Option<Vec<String>>,
    /// Page number used to derive the offset (`new` initialises it to 1).
20    page: u32,
    /// Columns to select; when empty, `build_query` uses its built-in list.
21    select_columns: Vec<String>,
22}
23
24impl PackageQueryBuilder {
25    pub fn new(db: Arc<Mutex<Connection>>) -> Self {
26        Self {
27            db,
28            filters: Vec::new(),
29            sort_fields: Vec::new(),
30            limit: None,
31            shards: None,
32            page: 1,
33            select_columns: Vec::new(),
34        }
35    }
36
37    pub fn select(mut self, columns: &[&str]) -> Self {
38        self.select_columns
39            .extend(columns.iter().map(|&col| col.to_string()));
40        self
41    }
42
43    pub fn clear_filters(mut self) -> Self {
44        self.filters = Vec::new();
45        self
46    }
47
48    pub fn where_and(mut self, field: &str, condition: FilterCondition) -> Self {
49        self.filters.push(QueryFilter {
50            field: field.to_string(),
51            condition,
52            logical_op: Some(LogicalOp::And),
53        });
54        self
55    }
56
57    pub fn where_or(mut self, field: &str, condition: FilterCondition) -> Self {
58        self.filters.push(QueryFilter {
59            field: field.to_string(),
60            condition,
61            logical_op: Some(LogicalOp::Or),
62        });
63        self
64    }
65
66    pub fn json_where_or(
67        mut self,
68        field: &str,
69        json_field: &str,
70        condition: FilterCondition,
71    ) -> Self {
72        let select_clause = format!("SELECT 1 FROM json_each({field})");
73        let extract_value = format!("json_extract(value, '$.{json_field}')");
74        let where_clause = self.build_subquery_where_clause(&extract_value, condition);
75
76        let query = format!("EXISTS ({select_clause} WHERE {where_clause})");
77
78        self.filters.push(QueryFilter {
79            field: query,
80            condition: FilterCondition::None,
81            logical_op: Some(LogicalOp::Or),
82        });
83        self
84    }
85
86    pub fn json_where_and(
87        mut self,
88        field: &str,
89        json_field: &str,
90        condition: FilterCondition,
91    ) -> Self {
92        let select_clause = format!("SELECT 1 FROM json_each({field})");
93        let extract_value = format!("json_extract(value, '$.{json_field}')");
94        let where_clause = self.build_subquery_where_clause(&extract_value, condition);
95
96        let query = format!("EXISTS ({select_clause} WHERE {where_clause})");
97
98        self.filters.push(QueryFilter {
99            field: query,
100            condition: FilterCondition::None,
101            logical_op: Some(LogicalOp::And),
102        });
103        self
104    }
105
106    pub fn database(mut self, db: Arc<Mutex<Connection>>) -> Self {
107        self.db = db;
108        self
109    }
110
111    pub fn sort_by(mut self, field: &str, direction: SortDirection) -> Self {
112        self.sort_fields.push((field.to_string(), direction));
113        self
114    }
115
116    pub fn limit(mut self, limit: u32) -> Self {
117        self.limit = Some(limit);
118        self
119    }
120
121    pub fn clear_limit(mut self) -> Self {
122        self.limit = None;
123        self
124    }
125
126    pub fn page(mut self, page: u32) -> Self {
127        self.page = page;
128        self
129    }
130
131    pub fn shards(mut self, shards: Vec<String>) -> Self {
132        self.shards = Some(shards);
133        self
134    }
135
136    pub fn load<T: FromRow>(&self) -> SoarResult<PaginatedResponse<T>> {
137        let conn = self.db.lock().map_err(|_| SoarError::PoisonError)?;
138        let shards = self.get_shards(&conn)?;
139
140        let (query, params) = self.build_query(&shards)?;
141        let mut stmt = conn.prepare(&query)?;
142
143        let params_ref: Vec<&dyn rusqlite::ToSql> = params
144            .iter()
145            .map(|p| p.as_ref() as &dyn rusqlite::ToSql)
146            .collect();
147
148        let items = stmt
149            .query_map(params_ref.as_slice(), T::from_row)?
150            .filter_map(|r| match r {
151                Ok(pkg) => Some(pkg),
152                Err(err) => {
153                    eprintln!("Package map error: {err:#?}");
154                    None
155                }
156            })
157            .collect();
158
159        let (count_query, count_params) = self.build_count_query(&shards);
160        let mut count_stmt = conn.prepare(&count_query)?;
161        let count_params_ref: Vec<&dyn rusqlite::ToSql> = count_params
162            .iter()
163            .map(|p| p.as_ref() as &dyn rusqlite::ToSql)
164            .collect();
165        let total: u64 = count_stmt.query_row(count_params_ref.as_slice(), |row| row.get(0))?;
166
167        let page = self.page;
168        let limit = self.limit;
169
170        let has_next = limit.map_or_else(|| false, |v| (self.page as u64 * v as u64) < total);
171
172        Ok(PaginatedResponse {
173            items,
174            page,
175            limit,
176            total,
177            has_next,
178        })
179    }
180
181    fn get_shards(&self, conn: &Connection) -> SoarResult<Vec<String>> {
182        let shards = self.shards.clone().unwrap_or_else(|| {
183            let mut stmt = conn.prepare("PRAGMA database_list").unwrap();
184            stmt.query_map([], |row| row.get::<_, String>(1))
185                .unwrap()
186                .filter_map(Result::ok)
187                .collect()
188        });
189        Ok(shards)
190    }
191
192    fn build_query(
193        &self,
194        shards: &[String],
195    ) -> SoarResult<(String, Vec<Box<dyn rusqlite::ToSql>>)> {
196        let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
197
198        let shard_queries: Vec<String> = shards
199            .iter()
200            .map(|shard| {
201                let cols = if self.select_columns.is_empty() {
202                    vec![
203                        "p.id",
204                        "disabled",
205                        "json(disabled_reason) AS disabled_reason",
206                        "rank",
207                        "pkg",
208                        "pkg_id",
209                        "pkg_name",
210                        "pkg_family",
211                        "pkg_type",
212                        "pkg_webpage",
213                        "app_id",
214                        "description",
215                        "version",
216                        "version_upstream",
217                        "json(licenses) AS licenses",
218                        "download_url",
219                        "size",
220                        "ghcr_pkg",
221                        "ghcr_size",
222                        "json(ghcr_files) AS ghcr_files",
223                        "ghcr_blob",
224                        "ghcr_url",
225                        "bsum",
226                        "shasum",
227                        "icon",
228                        "desktop",
229                        "appstream",
230                        "json(homepages) AS homepages",
231                        "json(notes) AS notes",
232                        "json(source_urls) AS source_urls",
233                        "json(tags) AS tags",
234                        "json(categories) AS categories",
235                        "build_id",
236                        "build_date",
237                        "build_action",
238                        "build_script",
239                        "build_log",
240                        "json(provides) AS provides",
241                        "json(snapshots) AS snapshots",
242                        "json(repology) AS repology",
243                        "json(replaces) AS replaces",
244                        "download_count",
245                        "download_count_week",
246                        "download_count_month",
247                        "bundle",
248                        "bundle_type",
249                        "soar_syms",
250                        "deprecated",
251                        "desktop_integration",
252                        "external",
253                        "installable",
254                        "portable",
255                        "trusted",
256                        "version_latest",
257                        "version_outdated",
258                    ]
259                    .join(",")
260                } else {
261                    self.select_columns.join(",")
262                };
263                let select_clause = format!(
264                    "SELECT
265                        {cols}, r.name AS repo_name,
266                        json_group_array(
267                            json_object(
268                                'name', m.name,
269                                'contact', m.contact
270                            )
271                        ) FILTER (WHERE m.id IS NOT NULL) as maintainers
272                     FROM
273                         {shard}.packages p
274                         JOIN {shard}.repository r
275                         LEFT JOIN {shard}.package_maintainers pm ON p.id = pm.package_id
276                         LEFT JOIN {shard}.maintainers m ON m.id = pm.maintainer_id
277                    ",
278                );
279
280                let where_clause = self.build_where_clause(&mut params);
281
282                let mut query = format!("{select_clause} {where_clause}");
283                query.push_str(" GROUP BY p.id, repo_name");
284                query
285            })
286            .collect();
287
288        let combined_query = shard_queries.join("\nUNION ALL\n");
289        let mut final_query = format!("WITH results AS ({combined_query}) SELECT * FROM results");
290
291        if !self.sort_fields.is_empty() {
292            let sort_clauses: Vec<String> = self
293                .sort_fields
294                .iter()
295                .map(|(field, direction)| {
296                    format!(
297                        "{} {}",
298                        field,
299                        match direction {
300                            SortDirection::Asc => "ASC",
301                            SortDirection::Desc => "DESC",
302                        }
303                    )
304                })
305                .collect();
306            final_query.push_str(" ORDER BY ");
307            final_query.push_str(&sort_clauses.join(", "));
308        }
309
310        if let Some(limit) = self.limit {
311            final_query.push_str(" LIMIT ?");
312            params.push(Box::new(limit));
313
314            let offset = self.limit.map(|limit| (self.page - 1) * limit);
315            if let Some(offset) = offset {
316                final_query.push_str(" OFFSET ?");
317                params.push(Box::new(offset));
318            }
319        }
320
321        Ok((final_query, params))
322    }
323
324    fn build_count_query(&self, shards: &[String]) -> (String, Vec<Box<dyn rusqlite::ToSql>>) {
325        let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
326
327        let shard_queries: Vec<String> = shards
328            .iter()
329            .map(|shard| {
330                let select_clause = format!(
331                    "SELECT COUNT(1) as cnt, r.name as repo_name FROM {shard}.packages p JOIN {shard}.repository r",
332                );
333
334                let where_clause = self.build_where_clause(&mut params);
335                format!("{select_clause} {where_clause}")
336            })
337            .collect();
338
339        let query = format!(
340            "SELECT SUM(cnt) FROM ({})",
341            shard_queries.join("\nUNION ALL\n")
342        );
343
344        (query, params)
345    }
346
347    pub fn load_installed(&self) -> SoarResult<PaginatedResponse<InstalledPackage>> {
348        let conn = self.db.lock().map_err(|_| SoarError::PoisonError)?;
349        let (query, params) = self.build_installed_query()?;
350        let mut stmt = conn.prepare(&query)?;
351
352        let params_ref: Vec<&dyn rusqlite::ToSql> = params
353            .iter()
354            .map(|p| p.as_ref() as &dyn rusqlite::ToSql)
355            .collect();
356        let items = stmt
357            .query_map(params_ref.as_slice(), InstalledPackage::from_row)?
358            .filter_map(|r| match r {
359                Ok(pkg) => Some(pkg),
360                Err(err) => {
361                    eprintln!("Installed package map error: {err:#?}");
362                    None
363                }
364            })
365            .collect();
366
367        let (count_query, count_params) = {
368            let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
369            let select_clause = "SELECT COUNT(1) FROM packages p";
370            let where_clause = self.build_where_clause(&mut params);
371            let query = format!("{select_clause} {where_clause}");
372            (query, params)
373        };
374        let mut count_stmt = conn.prepare(&count_query)?;
375        let count_params_ref: Vec<&dyn rusqlite::ToSql> = count_params
376            .iter()
377            .map(|p| p.as_ref() as &dyn rusqlite::ToSql)
378            .collect();
379        let total: u64 = count_stmt.query_row(count_params_ref.as_slice(), |row| row.get(0))?;
380
381        let page = self.page;
382        let limit = self.limit;
383
384        let has_next = limit.map_or_else(|| false, |v| (self.page as u64 * v as u64) < total);
385
386        Ok(PaginatedResponse {
387            items,
388            page,
389            limit,
390            total,
391            has_next,
392        })
393    }
394
395    fn build_installed_query(&self) -> SoarResult<(String, Vec<Box<dyn rusqlite::ToSql>>)> {
396        let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
397        let select_clause = "SELECT p.*, pp.* FROM packages p
398            LEFT JOIN portable_package pp
399            ON pp.package_id = p.id";
400        let where_clause = self.build_where_clause(&mut params);
401        let mut query = format!("{select_clause} {where_clause}");
402
403        if !self.sort_fields.is_empty() {
404            let sort_clauses: Vec<String> = self
405                .sort_fields
406                .iter()
407                .map(|(field, direction)| {
408                    format!(
409                        "{} {}",
410                        field,
411                        match direction {
412                            SortDirection::Asc => "ASC",
413                            SortDirection::Desc => "DESC",
414                        }
415                    )
416                })
417                .collect();
418            query.push_str(" ORDER BY ");
419            query.push_str(&sort_clauses.join(", "));
420        }
421
422        if let Some(limit) = self.limit {
423            query.push_str(" LIMIT ?");
424            params.push(Box::new(limit));
425
426            let offset = self.limit.map(|limit| (self.page - 1) * limit);
427            if let Some(offset) = offset {
428                query.push_str(" OFFSET ?");
429                params.push(Box::new(offset));
430            }
431        }
432
433        Ok((query, params))
434    }
435
436    fn build_where_clause(&self, params: &mut Vec<Box<dyn ToSql>>) -> String {
437        if self.filters.is_empty() {
438            return String::new();
439        }
440
441        let conditions: Vec<String> = self
442            .filters
443            .iter()
444            .enumerate()
445            .map(|(idx, filter)| {
446                let condition = match &filter.condition {
447                    FilterCondition::Eq(val) => {
448                        params.push(Box::new(val.clone()));
449                        format!("{} = ?", filter.field)
450                    }
451                    FilterCondition::Ne(val) => {
452                        params.push(Box::new(val.clone()));
453                        format!("{} != ?", filter.field)
454                    }
455                    FilterCondition::Gt(val) => {
456                        params.push(Box::new(val.clone()));
457                        format!("{} > ?", filter.field)
458                    }
459                    FilterCondition::Gte(val) => {
460                        params.push(Box::new(val.clone()));
461                        format!("{} >= ?", filter.field)
462                    }
463                    FilterCondition::Lt(val) => {
464                        params.push(Box::new(val.clone()));
465                        format!("{} < ?", filter.field)
466                    }
467                    FilterCondition::Lte(val) => {
468                        params.push(Box::new(val.clone()));
469                        format!("{} <= ?", filter.field)
470                    }
471                    FilterCondition::Like(val) => {
472                        params.push(Box::new(format!("%{val}%")));
473                        format!("{} LIKE ?", filter.field)
474                    }
475                    FilterCondition::ILike(val) => {
476                        params.push(Box::new(format!("%{val}%")));
477                        format!("LOWER({}) LIKE LOWER(?)", filter.field)
478                    }
479                    FilterCondition::In(vals) => {
480                        let placeholders = vec!["?"; vals.len()].join(", ");
481                        for val in vals {
482                            params.push(Box::new(val.clone()));
483                        }
484                        format!("{} IN ({})", filter.field, placeholders)
485                    }
486                    FilterCondition::NotIn(vals) => {
487                        let placeholders = vec!["?"; vals.len()].join(", ");
488                        for val in vals {
489                            params.push(Box::new(val.clone()));
490                        }
491                        format!("{} NOT IN ({})", filter.field, placeholders)
492                    }
493                    FilterCondition::Between(start, end) => {
494                        params.push(Box::new(start.clone()));
495                        params.push(Box::new(end.clone()));
496                        format!("{} BETWEEN ? AND ?", filter.field)
497                    }
498                    FilterCondition::IsNull => {
499                        format!("{} IS NULL", filter.field)
500                    }
501                    FilterCondition::IsNotNull => {
502                        format!("{} IS NOT NULL", filter.field)
503                    }
504                    FilterCondition::None => filter.field.to_string(),
505                };
506
507                if idx > 0 {
508                    match filter.logical_op {
509                        Some(LogicalOp::And) => format!("AND {condition}"),
510                        Some(LogicalOp::Or) => format!("OR {condition}"),
511                        None => condition,
512                    }
513                } else {
514                    condition
515                }
516            })
517            .collect();
518        format!("WHERE {}", conditions.join(" "))
519    }
520
521    fn build_subquery_where_clause(&self, value: &str, condition: FilterCondition) -> String {
522        match condition {
523            FilterCondition::Eq(val) => {
524                format!("{value} = '{val}'")
525            }
526            FilterCondition::Ne(val) => {
527                format!("{value} != '{val}'")
528            }
529            FilterCondition::Gt(val) => {
530                format!("{value} > '{val}'")
531            }
532            FilterCondition::Gte(val) => {
533                format!("{value} >= '{val}'")
534            }
535            FilterCondition::Lt(val) => {
536                format!("{value} < '{val}'")
537            }
538            FilterCondition::Lte(val) => {
539                format!("{value} <= '{val}'")
540            }
541            FilterCondition::Like(val) => {
542                format!("{value} LIKE '%{val}%'")
543            }
544            FilterCondition::ILike(val) => {
545                format!("LOWER({value}) LIKE LOWER('%{val}%')")
546            }
547            FilterCondition::In(vals) => {
548                format!(
549                    "{} IN ({})",
550                    value,
551                    vals.iter()
552                        .map(|v| format!("'{v}'"))
553                        .collect::<Vec<String>>()
554                        .join(",")
555                )
556            }
557            FilterCondition::NotIn(vals) => {
558                format!(
559                    "{} NOT IN ({})",
560                    value,
561                    vals.iter()
562                        .map(|v| format!("'{v}'"))
563                        .collect::<Vec<String>>()
564                        .join(",")
565                )
566            }
567            FilterCondition::Between(start, end) => {
568                format!("{value} BETWEEN '{start}' AND '{end}'")
569            }
570            FilterCondition::IsNull => {
571                format!("{value} IS NULL")
572            }
573            FilterCondition::IsNotNull => {
574                format!("{value} IS NOT NULL")
575            }
576            FilterCondition::None => String::new(),
577        }
578    }
579}