Skip to main content

sql_fun_sqlast/sem/base_context/
builtin_casts.rs

1use std::{
2    collections::{HashMap, HashSet},
3    path::Path,
4};
5
6use crate::{
7    TrieMap,
8    sem::{AnalysisError, CastContext, CastDefinition, CastInfoRead, FullName, TypeReference},
9};
10
/// Namespace-like marker type for the builtin cast table loader.
///
/// It holds no data; all work is done by its associated functions.
pub struct PgBuiltinCasts {}
12
13#[expect(dead_code)]
14#[derive(Debug, serde::Deserialize)]
15pub struct CastInfo {
16    pub func: String,
17    pub method: String,
18    pub context: String,
19}
20
21// alias: source_type -> target_type -> CastInfo
22pub type CastMap = HashMap<String, HashMap<String, CastInfo>>;
23
24impl PgBuiltinCasts {
25    pub(super) fn load<P: AsRef<Path>>(
26        path: P,
27    ) -> Result<TrieMap<TrieMap<CastDefinition>>, AnalysisError> {
28        let casts: CastMap = super::deserialize_from_path(path)?;
29
30        let mut result = TrieMap::new();
31        for (source, m) in casts {
32            let mut group = TrieMap::new();
33            for (target, ci) in m {
34                let cc = match ci.context.as_str() {
35                    "i" => CastContext::Implicit,
36                    "a" => CastContext::Assignment,
37                    "e" => CastContext::Explicit,
38                    e => panic!("unexpected cast context {e}"),
39                };
40                let def = CastDefinition::new(cc);
41                group.insert(&target, def);
42            }
43            result.insert(&source, group);
44        }
45        Ok(result)
46    }
47}
48
49impl CastInfoRead for super::BaseContext {
50    fn get_explicit_cast(
51        &self,
52        source_type: &TypeReference,
53        target_type: &TypeReference,
54    ) -> Option<super::CastDefinition> {
55        // TODO: normalize / fix search path
56        let source_local_name = source_type.full_name().local_name().clone();
57        let target_local_name = target_type.full_name().local_name().clone();
58        if source_local_name == target_local_name {
59            return Some(CastDefinition::new(CastContext::NoConversion));
60        }
61
62        let from_source = self.casts.get(&source_local_name)?;
63        let cast_def = from_source.get(&target_local_name)?;
64        if cast_def.use_in_explicit() {
65            Some(cast_def.clone())
66        } else {
67            None
68        }
69    }
70
71    fn get_implicit_cast(
72        &self,
73        source_type: &TypeReference,
74        target_type: &TypeReference,
75    ) -> Option<super::CastDefinition> {
76        // TODO: normalize / fix search path
77        let source_local_name = source_type.full_name().local_name().clone();
78        let target_local_name = target_type.full_name().local_name().clone();
79        if source_local_name == target_local_name {
80            return Some(CastDefinition::new(CastContext::NoConversion));
81        }
82
83        let from_source = self.casts.get(&source_local_name)?;
84        let cast_def = from_source.get(&target_local_name)?;
85
86        if cast_def.use_in_implicit() {
87            Some(cast_def.clone())
88        } else {
89            None
90        }
91    }
92
93    fn get_implicit_castable(&self, source_type: &TypeReference) -> HashSet<TypeReference> {
94        use std::str::FromStr;
95
96        let source_local_name = source_type.full_name().local_name().clone();
97        let Some(from_source) = &self.casts.get(&source_local_name) else {
98            return HashSet::default();
99        };
100        let mut results = HashSet::new();
101        results.insert(source_type.clone()); // source_type -> source_type : NoConversion is always ok
102
103        for (k, _) in from_source.iter() {
104            results.insert(TypeReference::from_full_name(
105                FullName::from_str(k.as_str()).expect("FullName from_str for builtin type"),
106            ));
107        }
108        results
109    }
110}