1use ahash::AHashMap;
2use lsp_server::{Connection, Message, Request, RequestId, Response};
3use lsp_types::notification::{
4 DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument, DidSaveTextDocument,
5 Notification, PublishDiagnostics,
6};
7use lsp_types::request::{Formatting, Request as _};
8use lsp_types::{
9 Diagnostic, DiagnosticSeverity, DidChangeTextDocumentParams, DidCloseTextDocumentParams,
10 DidOpenTextDocumentParams, DidSaveTextDocumentParams, DocumentFormattingParams,
11 InitializeParams, InitializeResult, NumberOrString, OneOf, Position, PublishDiagnosticsParams,
12 Registration, ServerCapabilities, TextDocumentIdentifier, TextDocumentItem,
13 TextDocumentSyncCapability, TextDocumentSyncKind, Uri, VersionedTextDocumentIdentifier,
14};
15use serde_json::Value;
16use sqruff_lib::core::config::FluffConfig;
17use sqruff_lib::core::linter::core::Linter;
18use wasm_bindgen::prelude::*;
19
20#[cfg(not(target_arch = "wasm32"))]
21fn load_config() -> FluffConfig {
22 FluffConfig::from_root(None, false, None).unwrap_or_default()
23}
24
/// Returns the default linter configuration for wasm builds.
///
/// Nothing is read from disk here; in the browser the configuration is instead replaced
/// through the JS-facing `updateConfig` entry point.
#[cfg(target_arch = "wasm32")]
fn load_config() -> FluffConfig {
    Default::default()
}
29
30fn server_initialize_result() -> InitializeResult {
31 InitializeResult {
32 capabilities: ServerCapabilities {
33 text_document_sync: TextDocumentSyncCapability::Kind(TextDocumentSyncKind::FULL).into(),
34 document_formatting_provider: OneOf::Left(true).into(),
35 ..Default::default()
36 },
37 server_info: None,
38 }
39}
40
41pub struct LanguageServer {
42 linter: Linter,
43 send_diagnostics_callback: Box<dyn Fn(PublishDiagnosticsParams)>,
44 documents: AHashMap<Uri, String>,
45}
46
47#[wasm_bindgen]
48pub struct Wasm(LanguageServer);
49
50#[wasm_bindgen]
51impl Wasm {
52 #[wasm_bindgen(constructor)]
53 pub fn new(send_diagnostics_callback: js_sys::Function) -> Self {
54 console_error_panic_hook::set_once();
55
56 let send_diagnostics_callback = Box::leak(Box::new(send_diagnostics_callback));
57
58 Self(LanguageServer::new(|diagnostics| {
59 let diagnostics = serde_wasm_bindgen::to_value(&diagnostics).unwrap();
60 send_diagnostics_callback
61 .call1(&JsValue::null(), &diagnostics)
62 .unwrap();
63 }))
64 }
65
66 #[wasm_bindgen(js_name = saveRegistrationOptions)]
67 pub fn save_registration_options() -> JsValue {
68 serde_wasm_bindgen::to_value(&save_registration_options()).unwrap()
69 }
70
71 #[wasm_bindgen(js_name = updateConfig)]
72 pub fn update_config(&mut self, source: &str) {
73 *self.0.linter.config_mut() = FluffConfig::from_source(source, None);
74 self.0.recheck_files();
75 }
76
77 #[wasm_bindgen(js_name = onInitialize)]
78 pub fn on_initialize(&self) -> JsValue {
79 serde_wasm_bindgen::to_value(&server_initialize_result()).unwrap()
80 }
81
82 #[wasm_bindgen(js_name = onNotification)]
83 pub fn on_notification(&mut self, method: &str, params: JsValue) {
84 self.0
85 .on_notification(method, serde_wasm_bindgen::from_value(params).unwrap())
86 }
87
88 #[wasm_bindgen]
89 pub fn format(&mut self, uri: JsValue) -> JsValue {
90 let uri = serde_wasm_bindgen::from_value(uri).unwrap();
91 let edits = self.0.format(uri);
92 serde_wasm_bindgen::to_value(&edits).unwrap()
93 }
94
95 #[wasm_bindgen(js_name = formatSource)]
96 pub fn format_source(&mut self, source: &str) -> String {
97 self.0.format_source(source)
98 }
99}
100
101impl LanguageServer {
102 pub fn new(send_diagnostics_callback: impl Fn(PublishDiagnosticsParams) + 'static) -> Self {
103 Self {
104 linter: Linter::new(load_config(), None, None, false),
105 send_diagnostics_callback: Box::new(send_diagnostics_callback),
106 documents: AHashMap::new(),
107 }
108 }
109
110 fn on_request(&mut self, id: RequestId, method: &str, params: Value) -> Option<Response> {
111 match method {
112 Formatting::METHOD => {
113 let DocumentFormattingParams {
114 text_document: TextDocumentIdentifier { uri },
115 ..
116 } = serde_json::from_value(params).unwrap();
117
118 let edits = self.format(uri);
119 Some(Response::new_ok(id, edits))
120 }
121 _ => None,
122 }
123 }
124
125 fn format(&mut self, uri: Uri) -> Vec<lsp_types::TextEdit> {
126 let text = self.documents.get(&uri).cloned().unwrap();
127 let new_text = self.format_source(&text);
128 self.documents.insert(uri.clone(), new_text.clone());
129 Self::build_edits(new_text)
130 }
131
132 fn format_source(&mut self, source: &str) -> String {
133 let tree = self.linter.lint_string(source, None, true);
134 tree.fix_string()
135 }
136
137 fn build_edits(new_text: String) -> Vec<lsp_types::TextEdit> {
138 let start_position = Position {
139 line: 0,
140 character: 0,
141 };
142 let end_position = Position {
143 line: new_text.lines().count() as u32,
144 character: new_text.chars().count() as u32,
145 };
146
147 vec![lsp_types::TextEdit {
148 range: lsp_types::Range::new(start_position, end_position),
149 new_text,
150 }]
151 }
152
153 pub fn on_notification(&mut self, method: &str, params: Value) {
154 match method {
155 DidOpenTextDocument::METHOD => {
156 let params: DidOpenTextDocumentParams = serde_json::from_value(params).unwrap();
157 let TextDocumentItem {
158 uri,
159 language_id: _,
160 version: _,
161 text,
162 } = params.text_document;
163
164 self.check_file(uri.clone(), &text);
165 self.documents.insert(uri, text);
166 }
167 DidChangeTextDocument::METHOD => {
168 let params: DidChangeTextDocumentParams = serde_json::from_value(params).unwrap();
169
170 let content = params.content_changes[0].text.clone();
171 let VersionedTextDocumentIdentifier { uri, version: _ } = params.text_document;
172
173 self.check_file(uri.clone(), &content);
174 self.documents.insert(uri, content);
175 }
176 DidCloseTextDocument::METHOD => {
177 let params: DidCloseTextDocumentParams = serde_json::from_value(params).unwrap();
178 self.documents.remove(¶ms.text_document.uri);
179 }
180 DidSaveTextDocument::METHOD => {
181 let params: DidSaveTextDocumentParams = serde_json::from_value(params).unwrap();
182 let uri = params.text_document.uri.as_str();
183
184 if uri.ends_with(".sqlfluff") || uri.ends_with(".sqruff") {
185 *self.linter.config_mut() = load_config();
186
187 self.recheck_files();
188 }
189 }
190 _ => {}
191 }
192 }
193
194 fn recheck_files(&mut self) {
195 for (uri, text) in self.documents.iter() {
196 self.check_file(uri.clone(), text);
197 }
198 }
199
200 fn check_file(&self, uri: Uri, text: &str) {
201 let result = self.linter.lint_string(text, None, false);
202
203 let diagnostics = result
204 .into_violations()
205 .into_iter()
206 .map(|violation| {
207 let range = {
208 let pos = Position::new(
209 (violation.line_no as u32).saturating_sub(1),
210 (violation.line_pos as u32).saturating_sub(1),
211 );
212 lsp_types::Range::new(pos, pos)
213 };
214
215 let code = violation
216 .rule
217 .map(|rule| NumberOrString::String(rule.code.to_string()));
218
219 Diagnostic::new(
220 range,
221 DiagnosticSeverity::WARNING.into(),
222 code,
223 Some("sqruff".to_string()),
224 violation.description,
225 None,
226 None,
227 )
228 })
229 .collect();
230
231 let diagnostics = PublishDiagnosticsParams::new(uri.clone(), diagnostics, None);
232 (self.send_diagnostics_callback)(diagnostics);
233 }
234}
235
236pub fn run() {
237 let (connection, io_threads) = Connection::stdio();
238 let (id, params) = connection.initialize_start().unwrap();
239
240 let init_param: InitializeParams = serde_json::from_value(params).unwrap();
241 let initialize_result = serde_json::to_value(server_initialize_result()).unwrap();
242 connection.initialize_finish(id, initialize_result).unwrap();
243
244 main_loop(connection, init_param);
245
246 io_threads.join().unwrap();
247}
248
249fn main_loop(connection: Connection, _init_param: InitializeParams) {
250 let sender = connection.sender.clone();
251 let mut lsp = LanguageServer::new(move |diagnostics| {
252 let notification = new_notification::<PublishDiagnostics>(diagnostics);
253 sender.send(Message::Notification(notification)).unwrap();
254 });
255
256 let params = save_registration_options();
257 connection
258 .sender
259 .send(Message::Request(Request::new(
260 "textDocument-didSave".to_owned().into(),
261 "client/registerCapability".to_owned(),
262 params,
263 )))
264 .unwrap();
265
266 for message in &connection.receiver {
267 match message {
268 Message::Request(request) => {
269 if connection.handle_shutdown(&request).unwrap() {
270 return;
271 }
272
273 if let Some(response) = lsp.on_request(request.id, &request.method, request.params)
274 {
275 connection.sender.send(Message::Response(response)).unwrap();
276 }
277 }
278 Message::Response(_) => {}
279 Message::Notification(notification) => {
280 lsp.on_notification(¬ification.method, notification.params);
281 }
282 }
283 }
284}
285
286pub fn save_registration_options() -> lsp_types::RegistrationParams {
287 let save_registration_options = lsp_types::TextDocumentSaveRegistrationOptions {
288 include_text: false.into(),
289 text_document_registration_options: lsp_types::TextDocumentRegistrationOptions {
290 document_selector: Some(vec![
291 lsp_types::DocumentFilter {
292 language: None,
293 scheme: None,
294 pattern: Some("**/.sqlfluff".into()),
295 },
296 lsp_types::DocumentFilter {
297 language: None,
298 scheme: None,
299 pattern: Some("**/.sqruff".into()),
300 },
301 ]),
302 },
303 };
304
305 lsp_types::RegistrationParams {
306 registrations: vec![Registration {
307 id: "textDocument/didSave".into(),
308 method: "textDocument/didSave".into(),
309 register_options: serde_json::to_value(save_registration_options)
310 .unwrap()
311 .into(),
312 }],
313 }
314}
315
316fn new_notification<T>(params: T::Params) -> lsp_server::Notification
317where
318 T: Notification,
319{
320 lsp_server::Notification {
321 method: T::METHOD.to_owned(),
322 params: serde_json::to_value(¶ms).unwrap(),
323 }
324}