1use crate::config;
2use crate::template;
3use console::{style, Style};
4
/// Starter contents written to a new test's `test.sql` by `Tester::create_test`.
static BASE_TEST: &str = concat!("-- Test file\n", "SELECT 1;\n");
8
9use similar::{ChangeTag, TextDiff};
10use std::fmt;
11use std::str;
12use std::sync::{Arc, Mutex};
13
14use anyhow::{Context, Result};
15
16pub struct Tester {
17 config: config::Config,
18 script_path: String,
19}
20
/// Result of running a test and comparing its output with the expectation file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestOutcome {
    /// Rendered diff between expected and actual output, or `None` when they match.
    pub diff: Option<String>,
}

impl TestOutcome {
    /// Returns `true` when the output matched the expectation (no diff).
    pub fn passed(&self) -> bool {
        self.diff.is_none()
    }
}
25
26impl Tester {
27 pub fn new(config: &config::Config, script_path: &str) -> Self {
28 Tester {
29 config: config.clone(),
30 script_path: script_path.to_string(),
31 }
32 }
33
34 pub fn test_folder(&self) -> String {
35 let mut s = self.config.pather().tests_folder();
36 s.push('/');
37 s.push_str(&self.script_path);
38 s
39 }
40
41 pub fn test_file_path(&self) -> String {
42 let mut s = self.test_folder();
43 s.push_str("/test.sql");
44 s
45 }
46
47 pub fn expected_file_path(&self) -> String {
48 format!("{}/expected", self.test_folder())
49 }
50
51 pub async fn generate(&self, variables: Option<crate::variables::Variables>) -> Result<String> {
54 let lock_file = None;
55
56 let gen = template::generate_streaming(
57 &self.config,
58 lock_file,
59 &self.test_file_path(),
60 variables,
61 )
62 .await?;
63
64 let mut buffer = Vec::new();
65 gen.render_to_writer(&mut buffer)
66 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
67 let content = String::from_utf8(buffer)?;
68
69 Ok(content)
70 }
71
72 pub async fn run(&self, variables: Option<crate::variables::Variables>) -> Result<String> {
74 let content = self.generate(variables.clone()).await?;
75
76 let engine = self.config.new_engine().await?;
77
78 let stdout_buf = Arc::new(Mutex::new(Vec::new()));
80 let stdout_buf_clone = stdout_buf.clone();
81
82 engine
83 .execute_with_writer(
84 Box::new(move |writer| {
85 writer.write_all(content.as_bytes())?;
86 Ok(())
87 }),
88 Some(Box::new(SharedBufWriter(stdout_buf_clone))),
89 )
90 .await
91 .context("failed to write content to test db")?;
92
93 let buf = stdout_buf.lock().unwrap();
94 let generated = String::from_utf8_lossy(&buf).to_string();
95
96 Ok(generated)
97 }
98
99 pub async fn run_compare(
100 &self,
101 variables: Option<crate::variables::Variables>,
102 ) -> Result<TestOutcome> {
103 let generated = self.run(variables).await?;
104 let expected_bytes = self
105 .config
106 .operator()
107 .read(&self.expected_file_path())
108 .await
109 .context("unable to read expectations file")?
110 .to_bytes();
111 let expected = String::from_utf8(expected_bytes.to_vec())
112 .context("expected file is not valid UTF-8")?;
113
114 let outcome = match self.compare(&generated, &expected) {
115 Ok(()) => TestOutcome { diff: None },
116 Err(differences) => TestOutcome {
117 diff: Some(differences.to_string()),
118 },
119 };
120
121 Ok(outcome)
122 }
123
124 pub async fn save_expected(
125 &self,
126 variables: Option<crate::variables::Variables>,
127 ) -> Result<()> {
128 let content = self.run(variables).await?;
129 self.config
130 .operator()
131 .write(&self.expected_file_path(), content)
132 .await
133 .context("unable to write expectation file")?;
134
135 Ok(())
136 }
137
138 pub async fn create_test(&self) -> Result<String> {
140 let script_path = self.test_file_path();
141 println!("creating test at {}", &script_path);
142 self.config
143 .operator()
144 .write(&script_path, BASE_TEST)
145 .await?;
146
147 Ok(self.script_path.clone())
148 }
149
150 pub fn compare(&self, generated: &str, expected: &str) -> std::result::Result<(), String> {
151 let diff = TextDiff::from_lines(expected, generated);
152
153 let mut diff_display = String::new();
154
155 for (idx, group) in diff.grouped_ops(3).iter().enumerate() {
156 if idx > 0 {
157 diff_display.push_str(&format!("{:-^1$}", "-", 80));
158 }
159 for op in group {
160 for change in diff.iter_inline_changes(op) {
161 let (sign, s) = match change.tag() {
162 ChangeTag::Delete => ("-", Style::new().red()),
163 ChangeTag::Insert => ("+", Style::new().green()),
164 ChangeTag::Equal => (" ", Style::new().dim()),
165 };
166 diff_display.push_str(&format!(
167 "{}{} |{}",
168 style(Line(change.old_index())).dim(),
169 style(Line(change.new_index())).dim(),
170 s.apply_to(sign).bold(),
171 ));
172 for (emphasized, value) in change.iter_strings_lossy() {
173 if emphasized {
174 diff_display.push_str(&format!(
175 "{}",
176 s.apply_to(value).underlined().on_black()
177 ));
178 } else {
179 diff_display.push_str(&format!("{}", s.apply_to(value)));
180 }
181 }
182 if change.missing_newline() {
183 diff_display.push('\n');
184 }
185 }
186 }
187 }
188
189 if !diff_display.is_empty() {
190 return Err(diff_display);
191 }
192
193 Ok(())
194 }
195}
196
/// Gutter cell for one side of a diff line.
///
/// Renders the 1-based line number left-aligned in four columns. Renders four
/// blanks when the line does not exist on that side, so old/new gutters stay
/// aligned. Numbers wider than four digits widen the cell rather than being
/// truncated.
struct Line(Option<usize>);

impl fmt::Display for Line {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.0 {
            // Same width as the numbered case below, so columns line up.
            None => write!(f, "{:4}", ""),
            // `similar` indices are 0-based; people read 1-based line numbers.
            Some(idx) => write!(f, "{:<4}", idx + 1),
        }
    }
}
207
/// Output sink that appends every write to a mutex-guarded byte buffer, so
/// whoever holds another `Arc` handle to the same buffer can read it afterwards.
struct SharedBufWriter(std::sync::Arc<std::sync::Mutex<Vec<u8>>>);
210
211impl tokio::io::AsyncWrite for SharedBufWriter {
212 fn poll_write(
213 self: std::pin::Pin<&mut Self>,
214 _cx: &mut std::task::Context<'_>,
215 buf: &[u8],
216 ) -> std::task::Poll<std::io::Result<usize>> {
217 self.0.lock().unwrap().extend_from_slice(buf);
218 std::task::Poll::Ready(Ok(buf.len()))
219 }
220
221 fn poll_flush(
222 self: std::pin::Pin<&mut Self>,
223 _cx: &mut std::task::Context<'_>,
224 ) -> std::task::Poll<std::io::Result<()>> {
225 std::task::Poll::Ready(Ok(()))
226 }
227
228 fn poll_shutdown(
229 self: std::pin::Pin<&mut Self>,
230 _cx: &mut std::task::Context<'_>,
231 ) -> std::task::Poll<std::io::Result<()>> {
232 std::task::Poll::Ready(Ok(()))
233 }
234}