Skip to main content

sql_fun_sqlast/sem/
schema_file_context.rs

1mod cast_info;
2mod func_info;
3mod sequence_info;
4mod table_info;
5mod type_info;
6mod view_info;
7
8use std::path::{Path, PathBuf};
9
10use sql_fun_core::SqlFunArgs;
11
12use self::{
13    func_info::FunctionDefinitionCollection, sequence_info::SequenceInfoCollection,
14    table_info::TableDefinitionCollection, type_info::TypeDefinitionCollection,
15    view_info::ViewInformationCollection,
16};
17
18use crate::{
19    StringSpan, TrieMap,
20    sem::{
21        AnalysisError, AnalysisProblem, BaseParseContext, CastDefinition, Comment, CreateExtension,
22        CreateIndex, CreateSchema, ImplicitChange, OperatorDefinition, OperatorInfoRead,
23        OperatorName, ParseContext, SchemaName, SemScalarExpr, SourceAccessService, VariableSet,
24        parse_context::{IndexInfoRead, SearchPathRead},
25    },
26    syn::ObjectType,
27};
28
29/// schema file based context
30#[derive(Debug)]
31pub struct SchemaFileContext<TBaseContext>
32where
33    TBaseContext: BaseParseContext + std::fmt::Debug,
34{
35    base: TBaseContext,
36    types: TypeDefinitionCollection,
37    tables: TableDefinitionCollection,
38    functions: FunctionDefinitionCollection,
39    extensions: TrieMap<CreateExtension>,
40    variables: TrieMap<SemScalarExpr>,
41    schemas: TrieMap<CreateSchema>,
42    sequences: SequenceInfoCollection,
43    comments: TrieMap<(ObjectType, String)>,
44    views: ViewInformationCollection,
45    casts: TrieMap<TrieMap<CastDefinition>>,
46    binary_operators: TrieMap<OperatorDefinition>,
47    left_unary_operators: TrieMap<OperatorDefinition>,
48    right_unary_operators: TrieMap<OperatorDefinition>,
49    indexes: TrieMap<CreateIndex>,
50    statement_span: StringSpan,
51    sql_source_path: PathBuf,
52    source_access_service: Option<Box<dyn SourceAccessService>>,
53    reports: Vec<String>,
54}
55
56impl<TBaseContext> SchemaFileContext<TBaseContext>
57where
58    TBaseContext: BaseParseContext + std::fmt::Debug,
59{
60    /// create context for loading schema file
61    pub fn new(sql_source_path: &Path, base: TBaseContext) -> Self {
62        Self {
63            base,
64            sql_source_path: sql_source_path.to_path_buf(),
65            types: TypeDefinitionCollection::default(),
66            tables: TableDefinitionCollection::default(),
67            functions: FunctionDefinitionCollection::default(),
68            extensions: TrieMap::default(),
69            variables: TrieMap::default(),
70            schemas: TrieMap::default(),
71            sequences: SequenceInfoCollection::default(),
72            comments: TrieMap::default(),
73            views: ViewInformationCollection::default(),
74            casts: TrieMap::default(),
75            binary_operators: TrieMap::default(),
76            left_unary_operators: TrieMap::default(),
77            right_unary_operators: TrieMap::default(),
78            statement_span: StringSpan::default(),
79            indexes: TrieMap::default(),
80            source_access_service: None,
81            reports: Default::default(),
82        }
83    }
84
85    /// apply index from constraint
86    pub fn apply_implicit_change(
87        self,
88        implicit_change: ImplicitChange,
89    ) -> Result<Self, AnalysisError> {
90        match implicit_change {
91            ImplicitChange::CreateIndex(create_index) => self.apply_create_index(&create_index),
92        }
93    }
94}
95
#[cfg(test)]
mod test_schema_file_context {
    use std::path::Path;

    use crate::{
        sem::{BaseParseContext, SchemaFileContext},
        test_helpers::{TestParseContext, test_context},
    };

    /// A freshly constructed context reports the source path it was created with.
    #[rstest::rstest]
    fn test_new(test_context: TestParseContext) {
        let path = Path::new("hoge.sql");
        // `path` is already a `&Path`, so it is passed as-is rather than re-borrowed.
        let context = SchemaFileContext::new(path, test_context);

        assert_eq!(context.get_source_file_path(), Some(path));
    }
}
111
112impl<TBaseContext> IndexInfoRead for SchemaFileContext<TBaseContext>
113where
114    TBaseContext: BaseParseContext + std::fmt::Debug,
115{
116    fn get_index(&self, index_name: &str) -> Option<&CreateIndex> {
117        if let Some(index) = self.indexes.get(index_name) {
118            Some(index)
119        } else {
120            self.base.get_index(index_name)
121        }
122    }
123}
124
125impl<TBaseContext> OperatorInfoRead for SchemaFileContext<TBaseContext>
126where
127    TBaseContext: BaseParseContext + std::fmt::Debug,
128{
129    fn get_operator_op(
130        &self,
131        name: &OperatorName,
132        has_left: bool,
133        has_right: bool,
134    ) -> Option<&OperatorDefinition> {
135        if !name.is_valid() {
136            return None;
137        }
138
139        let name_str = &name.to_name_string();
140        let Some(ope) = (match (has_left, has_right) {
141            (true, true) => self.binary_operators.get(name_str),
142            (true, false) => self.left_unary_operators.get(name_str),
143            (false, true) => self.right_unary_operators.get(name_str),
144            _ => panic!("accessing no-param operator"),
145        }) else {
146            return self.base.get_operator_op(name, has_left, has_right);
147        };
148        Some(ope)
149    }
150}
151
152impl<TBaseContext> SearchPathRead for SchemaFileContext<TBaseContext>
153where
154    TBaseContext: BaseParseContext + std::fmt::Debug,
155{
156    fn set_search_path(&mut self, search_path: Vec<SchemaName>) {
157        self.base.set_search_path(search_path);
158    }
159
160    fn get_default_schema(&self) -> Option<SchemaName> {
161        self.base.get_default_schema()
162    }
163
164    fn get_search_path(&self) -> &[SchemaName] {
165        self.base.get_search_path()
166    }
167}
168
169impl<TBaseContext> BaseParseContext for SchemaFileContext<TBaseContext>
170where
171    TBaseContext: BaseParseContext + std::fmt::Debug,
172{
173    fn get_args(&self) -> &SqlFunArgs {
174        self.base.get_args()
175    }
176
177    fn get_source_file_path(&self) -> Option<&Path> {
178        Some(self.sql_source_path.as_path())
179    }
180
181    fn report_problem(&mut self, problem: AnalysisProblem) -> Result<(), AnalysisError> {
182        if let Some(source_provider) = &self.source_access_service {
183            let source = source_provider.get_source_span(
184                self.get_args(),
185                self.sql_source_path.as_path(),
186                self.statement_span,
187            )?;
188            let theme = self.base.get_args().highlighter_theme()?;
189
190            let report = problem.render_report(&theme, source)?;
191            eprintln!("{report}");
192            self.reports.push(report)
193        }
194
195        self.base.report_problem(problem)
196    }
197
198    fn extend(&mut self, extension: Box<dyn BaseParseContext>) {
199        self.base.extend(extension);
200    }
201
202    fn set_statement_span(&mut self, statement_span: StringSpan) {
203        self.statement_span = statement_span;
204    }
205
206    fn set_source_file(&mut self, source_access_service: Box<dyn super::SourceAccessService>) {
207        self.source_access_service = Some(source_access_service);
208    }
209}
210
211impl<TBaseContext> ParseContext for SchemaFileContext<TBaseContext>
212where
213    TBaseContext: BaseParseContext + std::fmt::Debug,
214{
215    fn apply_create_extension(
216        mut self,
217        create_extension: &CreateExtension,
218    ) -> Result<Self, AnalysisError> {
219        let name = create_extension.name();
220        self.extensions.insert(name, create_extension.clone());
221        Ok(self)
222    }
223
224    fn apply_variable_set(mut self, variable_set: &VariableSet) -> Result<Self, AnalysisError> {
225        let name = variable_set.name();
226        self.variables.insert(name, variable_set.value().clone());
227        Ok(self)
228    }
229
230    fn apply_create_schema(mut self, create_schema: &CreateSchema) -> Result<Self, AnalysisError> {
231        let name = create_schema.name();
232        self.schemas.insert(name, create_schema.clone());
233        Ok(self)
234    }
235
236    fn apply_comment(mut self, comment: &Comment) -> Result<Self, AnalysisError> {
237        let target = comment.target();
238        self.comments.insert(
239            target,
240            (*comment.object_type(), String::from(comment.comment())),
241        );
242        Ok(self)
243    }
244
245    fn apply_create_index(mut self, create_index: &CreateIndex) -> Result<Self, AnalysisError> {
246        let name = create_index.name();
247        self.indexes.insert(name, create_index.clone());
248        Ok(self)
249    }
250}