1mod cast_info;
2mod func_info;
3mod sequence_info;
4mod table_info;
5mod type_info;
6mod view_info;
7
8use std::path::{Path, PathBuf};
9
10use sql_fun_core::SqlFunArgs;
11
12use self::{
13 func_info::FunctionDefinitionCollection, sequence_info::SequenceInfoCollection,
14 table_info::TableDefinitionCollection, type_info::TypeDefinitionCollection,
15 view_info::ViewInformationCollection,
16};
17
18use crate::{
19 StringSpan, TrieMap,
20 sem::{
21 AnalysisError, AnalysisProblem, BaseParseContext, CastDefinition, Comment, CreateExtension,
22 CreateIndex, CreateSchema, ImplicitChange, OperatorDefinition, OperatorInfoRead,
23 OperatorName, ParseContext, SchemaName, SemScalarExpr, SourceAccessService, VariableSet,
24 parse_context::{IndexInfoRead, SearchPathRead},
25 },
26 syn::ObjectType,
27};
28
/// Parse context for one SQL source file, layered on top of a base context.
///
/// Definitions recorded through the `apply_*` handlers are stored in the
/// tables below. The index and operator lookups implemented for this type
/// consult the local tables first and fall back to `base` on a miss.
#[derive(Debug)]
pub struct SchemaFileContext<TBaseContext>
where
    TBaseContext: BaseParseContext + std::fmt::Debug,
{
    /// Wrapped context: receives delegated calls and serves as lookup fallback.
    base: TBaseContext,
    /// Type definitions. Not accessed in this part of the file —
    /// presumably maintained by the `type_info` submodule (TODO confirm).
    types: TypeDefinitionCollection,
    /// Table definitions. Presumably maintained by `table_info` (TODO confirm).
    tables: TableDefinitionCollection,
    /// Function definitions. Presumably maintained by `func_info` (TODO confirm).
    functions: FunctionDefinitionCollection,
    /// `CreateExtension` entries keyed by the extension's `name()`.
    extensions: TrieMap<CreateExtension>,
    /// Variable values keyed by the `VariableSet`'s `name()`.
    variables: TrieMap<SemScalarExpr>,
    /// `CreateSchema` entries keyed by the schema's `name()`.
    schemas: TrieMap<CreateSchema>,
    /// Sequence information. Presumably maintained by `sequence_info` (TODO confirm).
    sequences: SequenceInfoCollection,
    /// Comments keyed by the comment's `target()`; the value holds the object
    /// type and the comment text.
    comments: TrieMap<(ObjectType, String)>,
    /// View information. Presumably maintained by `view_info` (TODO confirm).
    views: ViewInformationCollection,
    /// Cast definitions in a two-level map. Presumably maintained by
    /// `cast_info`; the meaning of each key level is not visible here (TODO confirm).
    casts: TrieMap<TrieMap<CastDefinition>>,
    /// Operators consulted when a lookup has both a left and a right operand.
    binary_operators: TrieMap<OperatorDefinition>,
    /// Operators consulted when a lookup has only a left operand.
    left_unary_operators: TrieMap<OperatorDefinition>,
    /// Operators consulted when a lookup has only a right operand.
    right_unary_operators: TrieMap<OperatorDefinition>,
    /// `CreateIndex` entries keyed by the index's `name()`.
    indexes: TrieMap<CreateIndex>,
    /// Span last supplied through `set_statement_span`; used when rendering
    /// problem reports.
    statement_span: StringSpan,
    /// Path of the SQL file this context was created for.
    sql_source_path: PathBuf,
    /// Source provider installed through `set_source_file`; `None` until then,
    /// in which case problem reports are not rendered locally.
    source_access_service: Option<Box<dyn SourceAccessService>>,
    /// Rendered problem reports, in the order they were produced.
    reports: Vec<String>,
}
55
56impl<TBaseContext> SchemaFileContext<TBaseContext>
57where
58 TBaseContext: BaseParseContext + std::fmt::Debug,
59{
60 pub fn new(sql_source_path: &Path, base: TBaseContext) -> Self {
62 Self {
63 base,
64 sql_source_path: sql_source_path.to_path_buf(),
65 types: TypeDefinitionCollection::default(),
66 tables: TableDefinitionCollection::default(),
67 functions: FunctionDefinitionCollection::default(),
68 extensions: TrieMap::default(),
69 variables: TrieMap::default(),
70 schemas: TrieMap::default(),
71 sequences: SequenceInfoCollection::default(),
72 comments: TrieMap::default(),
73 views: ViewInformationCollection::default(),
74 casts: TrieMap::default(),
75 binary_operators: TrieMap::default(),
76 left_unary_operators: TrieMap::default(),
77 right_unary_operators: TrieMap::default(),
78 statement_span: StringSpan::default(),
79 indexes: TrieMap::default(),
80 source_access_service: None,
81 reports: Default::default(),
82 }
83 }
84
85 pub fn apply_implicit_change(
87 self,
88 implicit_change: ImplicitChange,
89 ) -> Result<Self, AnalysisError> {
90 match implicit_change {
91 ImplicitChange::CreateIndex(create_index) => self.apply_create_index(&create_index),
92 }
93 }
94}
95
#[cfg(test)]
mod test_schema_file_context {
    use std::path::Path;

    use crate::{
        sem::{BaseParseContext, SchemaFileContext},
        test_helpers::{TestParseContext, test_context},
    };

    /// A freshly constructed context must report the path it was built with.
    #[rstest::rstest]
    fn test_new(test_context: TestParseContext) {
        let path = Path::new("hoge.sql");
        // `path` is already a `&Path`, so it is passed as-is (no extra borrow).
        let context = SchemaFileContext::new(path, test_context);
        assert_eq!(context.get_source_file_path(), Some(path));
    }
}
111
112impl<TBaseContext> IndexInfoRead for SchemaFileContext<TBaseContext>
113where
114 TBaseContext: BaseParseContext + std::fmt::Debug,
115{
116 fn get_index(&self, index_name: &str) -> Option<&CreateIndex> {
117 if let Some(index) = self.indexes.get(index_name) {
118 Some(index)
119 } else {
120 self.base.get_index(index_name)
121 }
122 }
123}
124
125impl<TBaseContext> OperatorInfoRead for SchemaFileContext<TBaseContext>
126where
127 TBaseContext: BaseParseContext + std::fmt::Debug,
128{
129 fn get_operator_op(
130 &self,
131 name: &OperatorName,
132 has_left: bool,
133 has_right: bool,
134 ) -> Option<&OperatorDefinition> {
135 if !name.is_valid() {
136 return None;
137 }
138
139 let name_str = &name.to_name_string();
140 let Some(ope) = (match (has_left, has_right) {
141 (true, true) => self.binary_operators.get(name_str),
142 (true, false) => self.left_unary_operators.get(name_str),
143 (false, true) => self.right_unary_operators.get(name_str),
144 _ => panic!("accessing no-param operator"),
145 }) else {
146 return self.base.get_operator_op(name, has_left, has_right);
147 };
148 Some(ope)
149 }
150}
151
// Search-path handling is not tracked by this context itself: every method
// below forwards directly to the wrapped base context.
impl<TBaseContext> SearchPathRead for SchemaFileContext<TBaseContext>
where
    TBaseContext: BaseParseContext + std::fmt::Debug,
{
    /// Forwards the new search path to the base context.
    fn set_search_path(&mut self, search_path: Vec<SchemaName>) {
        self.base.set_search_path(search_path);
    }

    /// Returns the base context's default schema, if it reports one.
    fn get_default_schema(&self) -> Option<SchemaName> {
        self.base.get_default_schema()
    }

    /// Returns the search path held by the base context.
    fn get_search_path(&self) -> &[SchemaName] {
        self.base.get_search_path()
    }
}
168
169impl<TBaseContext> BaseParseContext for SchemaFileContext<TBaseContext>
170where
171 TBaseContext: BaseParseContext + std::fmt::Debug,
172{
173 fn get_args(&self) -> &SqlFunArgs {
174 self.base.get_args()
175 }
176
177 fn get_source_file_path(&self) -> Option<&Path> {
178 Some(self.sql_source_path.as_path())
179 }
180
181 fn report_problem(&mut self, problem: AnalysisProblem) -> Result<(), AnalysisError> {
182 if let Some(source_provider) = &self.source_access_service {
183 let source = source_provider.get_source_span(
184 self.get_args(),
185 self.sql_source_path.as_path(),
186 self.statement_span,
187 )?;
188 let theme = self.base.get_args().highlighter_theme()?;
189
190 let report = problem.render_report(&theme, source)?;
191 eprintln!("{report}");
192 self.reports.push(report)
193 }
194
195 self.base.report_problem(problem)
196 }
197
198 fn extend(&mut self, extension: Box<dyn BaseParseContext>) {
199 self.base.extend(extension);
200 }
201
202 fn set_statement_span(&mut self, statement_span: StringSpan) {
203 self.statement_span = statement_span;
204 }
205
206 fn set_source_file(&mut self, source_access_service: Box<dyn super::SourceAccessService>) {
207 self.source_access_service = Some(source_access_service);
208 }
209}
210
211impl<TBaseContext> ParseContext for SchemaFileContext<TBaseContext>
212where
213 TBaseContext: BaseParseContext + std::fmt::Debug,
214{
215 fn apply_create_extension(
216 mut self,
217 create_extension: &CreateExtension,
218 ) -> Result<Self, AnalysisError> {
219 let name = create_extension.name();
220 self.extensions.insert(name, create_extension.clone());
221 Ok(self)
222 }
223
224 fn apply_variable_set(mut self, variable_set: &VariableSet) -> Result<Self, AnalysisError> {
225 let name = variable_set.name();
226 self.variables.insert(name, variable_set.value().clone());
227 Ok(self)
228 }
229
230 fn apply_create_schema(mut self, create_schema: &CreateSchema) -> Result<Self, AnalysisError> {
231 let name = create_schema.name();
232 self.schemas.insert(name, create_schema.clone());
233 Ok(self)
234 }
235
236 fn apply_comment(mut self, comment: &Comment) -> Result<Self, AnalysisError> {
237 let target = comment.target();
238 self.comments.insert(
239 target,
240 (*comment.object_type(), String::from(comment.comment())),
241 );
242 Ok(self)
243 }
244
245 fn apply_create_index(mut self, create_index: &CreateIndex) -> Result<Self, AnalysisError> {
246 let name = create_index.name();
247 self.indexes.insert(name, create_index.clone());
248 Ok(self)
249 }
250}