1use crate::config;
2use crate::engine::EngineType;
3use crate::escape::EscapedIdentifier;
4use crate::store::pinner::latest::Latest;
5use crate::store::pinner::spawn::Spawn;
6use crate::store::pinner::Pinner;
7use crate::store::Store;
8use crate::variables::Variables;
9use minijinja::{Environment, Value};
10
11use crate::sql_formatter::SqlDialect;
12use base64::{engine::general_purpose::STANDARD, Engine as _};
13use uuid::Uuid;
14
15use anyhow::{Context, Result};
16use minijinja::context;
17use std::sync::Arc;
18
/// Map a database engine to the SQL dialect used for template escaping
/// and formatting.
///
/// Exhaustive match on purpose: adding a new `EngineType` variant forces
/// a compile error here until a dialect is chosen for it.
fn engine_to_dialect(engine: &EngineType) -> SqlDialect {
    match engine {
        EngineType::PostgresPSQL => SqlDialect::Postgres,
    }
}
33
34pub fn template_env(store: Store, engine: &EngineType) -> Result<Environment<'static>> {
35 let mut env = Environment::new();
36
37 let store = Arc::new(store);
38
39 let mj_store = MiniJinjaLoader {
40 store: Arc::clone(&store),
41 };
42 env.set_loader(move |name: &str| mj_store.load(name));
43 env.add_function("gen_uuid_v4", gen_uuid_v4);
44 env.add_function("gen_uuid_v5", gen_uuid_v5);
45 env.add_filter("escape_identifier", escape_identifier_filter);
46
47 let read_file_store = Arc::clone(&store);
48 env.add_filter(
49 "read_file",
50 move |path: &str| -> Result<Value, minijinja::Error> {
51 read_file_filter(path, &read_file_store)
52 },
53 );
54 env.add_filter("base64_encode", base64_encode_filter);
55 env.add_filter("to_string_lossy", to_string_lossy_filter);
56 env.add_filter("parse_json", parse_json_filter);
57 env.add_filter("parse_toml", parse_toml_filter);
58 env.add_filter("parse_yaml", parse_yaml_filter);
59
60 let read_json_store = Arc::clone(&store);
61 env.add_filter(
62 "read_json",
63 move |path: &str| -> Result<Value, minijinja::Error> {
64 let bytes = read_file_bytes(path, &read_json_store)?;
65 let s = string_from_bytes(&bytes)?;
66 parse_json_filter(&s)
67 },
68 );
69 let read_toml_store = Arc::clone(&store);
70 env.add_filter(
71 "read_toml",
72 move |path: &str| -> Result<Value, minijinja::Error> {
73 let bytes = read_file_bytes(path, &read_toml_store)?;
74 let s = string_from_bytes(&bytes)?;
75 parse_toml_filter(&s)
76 },
77 );
78 let read_yaml_store = Arc::clone(&store);
79 env.add_filter(
80 "read_yaml",
81 move |path: &str| -> Result<Value, minijinja::Error> {
82 let bytes = read_file_bytes(path, &read_yaml_store)?;
83 let s = string_from_bytes(&bytes)?;
84 parse_yaml_filter(&s)
85 },
86 );
87
88 let dialect = engine_to_dialect(engine);
90
91 env.set_auto_escape_callback(crate::sql_formatter::get_auto_escape_callback(dialect));
93
94 env.set_formatter(crate::sql_formatter::get_formatter(dialect));
96
97 Ok(env)
98}
99
/// Adapts the async `Store` to minijinja's synchronous template-loader API.
struct MiniJinjaLoader {
    // Shared with the filters registered in `template_env`.
    pub store: Arc<Store>,
}
103
impl MiniJinjaLoader {
    /// Resolve a template name to its contents through the async store,
    /// mapping store failures into `minijinja::Error`.
    ///
    /// NOTE(review): `block_in_place` requires a multi-thread tokio runtime
    /// (the async tests in this file use `flavor = "multi_thread"`); calling
    /// this from a current-thread runtime would panic — confirm call sites.
    pub fn load(&self, name: &str) -> std::result::Result<Option<String>, minijinja::Error> {
        // Bridge minijinja's sync loader callback to the async store call.
        let result = tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current()
                .block_on(async { self.store.load_component(name).await })
        });

        result.map_err(|e| {
            minijinja::Error::new(
                minijinja::ErrorKind::InvalidOperation,
                format!("Failed to load from object store: {}", e),
            )
        })
    }
}
119
120fn gen_uuid_v4() -> Result<String, minijinja::Error> {
121 Ok(Uuid::new_v4().to_string())
122}
123
124fn gen_uuid_v5(seed: &str) -> Result<String, minijinja::Error> {
125 Ok(Uuid::new_v5(&Uuid::NAMESPACE_DNS, seed.as_bytes()).to_string())
126}
127
128fn escape_identifier_filter(value: &Value) -> Result<Value, minijinja::Error> {
135 let s = value.to_string();
136 let escaped = EscapedIdentifier::new(&s);
137 Ok(Value::from_safe_string(escaped.to_string()))
139}
140
/// Read a file's raw bytes from the store, blocking on the async call.
///
/// NOTE(review): like `MiniJinjaLoader::load`, this relies on
/// `block_in_place` and so needs a multi-thread tokio runtime.
fn read_file_bytes(path: &str, store: &Arc<Store>) -> Result<Vec<u8>, minijinja::Error> {
    let bytes = tokio::task::block_in_place(|| {
        tokio::runtime::Handle::current().block_on(async { store.read_file_bytes(path).await })
    });

    bytes.map_err(|e| {
        minijinja::Error::new(
            minijinja::ErrorKind::InvalidOperation,
            format!("Failed to read file '{}': {}", path, e),
        )
    })
}
154
155fn string_from_bytes(bytes: &[u8]) -> Result<String, minijinja::Error> {
157 String::from_utf8(bytes.to_vec()).map_err(|e| {
158 minijinja::Error::new(
159 minijinja::ErrorKind::InvalidOperation,
160 format!("File is not valid UTF-8: {}", e),
161 )
162 })
163}
164
165fn read_file_filter(path: &str, store: &Arc<Store>) -> Result<Value, minijinja::Error> {
171 Ok(Value::from_bytes(read_file_bytes(path, store)?))
172}
173
174fn base64_encode_filter(value: &Value) -> Result<Value, minijinja::Error> {
180 use minijinja::value::ValueKind;
181 match value.kind() {
182 ValueKind::Bytes => {
183 let bytes = value.as_bytes().unwrap();
184 Ok(Value::from(STANDARD.encode(bytes)))
185 }
186 ValueKind::String => Ok(Value::from(STANDARD.encode(value.as_str().unwrap()))),
187 _ => Err(minijinja::Error::new(
188 minijinja::ErrorKind::InvalidOperation,
189 "base64_encode filter expects bytes or string input",
190 )),
191 }
192}
193
194fn to_string_lossy_filter(value: &Value) -> Result<Value, minijinja::Error> {
199 use minijinja::value::ValueKind;
200 match value.kind() {
201 ValueKind::Bytes => {
202 let bytes = value.as_bytes().unwrap();
203 Ok(Value::from(String::from_utf8_lossy(bytes).into_owned()))
204 }
205 ValueKind::String => Ok(value.clone()),
206 _ => Err(minijinja::Error::new(
207 minijinja::ErrorKind::InvalidOperation,
208 "to_string_lossy filter expects bytes or string input",
209 )),
210 }
211}
212
213fn parse_json_filter(value: &str) -> Result<Value, minijinja::Error> {
217 let vars = Variables::from_str("json", value).map_err(|e| {
218 minijinja::Error::new(
219 minijinja::ErrorKind::InvalidOperation,
220 format!("parse_json: {}", e),
221 )
222 })?;
223 Ok(Value::from_serialize(&vars))
224}
225
226fn parse_toml_filter(value: &str) -> Result<Value, minijinja::Error> {
230 let vars = Variables::from_str("toml", value).map_err(|e| {
231 minijinja::Error::new(
232 minijinja::ErrorKind::InvalidOperation,
233 format!("parse_toml: {}", e),
234 )
235 })?;
236 Ok(Value::from_serialize(&vars))
237}
238
239fn parse_yaml_filter(value: &str) -> Result<Value, minijinja::Error> {
243 let vars = Variables::from_str("yaml", value).map_err(|e| {
244 minijinja::Error::new(
245 minijinja::ErrorKind::InvalidOperation,
246 format!("parse_yaml: {}", e),
247 )
248 })?;
249 Ok(Value::from_serialize(&vars))
250}
251
/// A fully rendered migration held in memory.
pub struct Generation {
    // Rendered output text.
    pub content: String,
}
255
/// A migration prepared for rendering: the template text plus everything
/// needed to render it straight into a writer without buffering the whole
/// output in memory.
pub struct StreamingGeneration {
    store: Store,              // component/file source for template includes
    template_contents: String, // raw migration template text
    environment: String,       // exposed to the template as `env`
    variables: Variables,      // exposed to the template as `variables`
    engine: EngineType,        // selects the SQL dialect for escaping
}
265
impl StreamingGeneration {
    /// Render the migration template directly into `writer`.
    ///
    /// Consumes `self`: `store` moves into the environment, while
    /// `environment`/`variables` move into the render context. The template
    /// is registered under the fixed internal name "migration.sql".
    pub fn render_to_writer<W: std::io::Write + ?Sized>(self, writer: &mut W) -> Result<()> {
        let mut env = template_env(self.store, &self.engine)?;
        env.add_template("migration.sql", &self.template_contents)?;
        let tmpl = env.get_template("migration.sql")?;
        tmpl.render_to_write(
            context!(env => self.environment, variables => self.variables),
            writer,
        )?;
        Ok(())
    }

    /// Box this generation as the engine's writer callback, converting any
    /// render error into a `std::io::Error` (kind `Other`).
    pub fn into_writer_fn(self) -> crate::engine::WriterFn {
        Box::new(move |writer: &mut dyn std::io::Write| {
            self.render_to_writer(writer)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
        })
    }
}
288
289pub async fn generate_streaming(
292 cfg: &config::Config,
293 lock_file: Option<String>,
294 name: &str,
295 variables: Option<Variables>,
296) -> Result<StreamingGeneration> {
297 let pinner: Box<dyn Pinner> = if let Some(lock_file) = lock_file {
298 let lock = cfg
299 .load_lock_file(&lock_file)
300 .await
301 .context("could not load pinned files lock file")?;
302 let pinner = Spawn::new_with_root_hash(
303 cfg.pather().pinned_folder(),
304 cfg.pather().components_folder(),
305 &lock.pin,
306 &cfg.operator(),
307 )
308 .await
309 .context("could not get new root with hash")?;
310 Box::new(pinner)
311 } else {
312 let pinner = Latest::new(cfg.pather().spawn_folder_path())?;
313 Box::new(pinner)
314 };
315
316 let store = Store::new(pinner, cfg.operator().clone(), cfg.pather())
317 .context("could not create new store for generate")?;
318 let db_config = cfg
319 .db_config()
320 .context("could not get db config for generate")?;
321
322 generate_streaming_with_store(
323 name,
324 variables,
325 &db_config.environment,
326 &db_config.engine,
327 store,
328 )
329 .await
330}
331
332pub async fn generate_streaming_with_store(
334 name: &str,
335 variables: Option<Variables>,
336 environment: &str,
337 engine: &EngineType,
338 store: Store,
339) -> Result<StreamingGeneration> {
340 let contents = store
342 .load_migration(name)
343 .await
344 .context("generate_streaming_with_store could not read migration")?;
345
346 Ok(StreamingGeneration {
347 store,
348 template_contents: contents,
349 environment: environment.to_string(),
350 variables: variables.unwrap_or_default(),
351 engine: engine.clone(),
352 })
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql_formatter::{get_auto_escape_callback, get_formatter};
    use minijinja::{context, Environment, Value};

    /// Render a single value through a Postgres-dialect environment and
    /// return the produced SQL text.
    fn render_sql_value(value: Value) -> String {
        let mut env = Environment::new();
        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
        env.set_formatter(get_formatter(SqlDialect::Postgres));
        env.add_template("test.sql", "{{ value }}").unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        tmpl.render(context!(value => value)).unwrap()
    }

    #[test]
    fn test_engine_to_dialect_postgres_psql() {
        let dialect = engine_to_dialect(&EngineType::PostgresPSQL);
        assert_eq!(dialect, SqlDialect::Postgres);
    }

    // --- SQL auto-escaping of rendered values, per value kind ---

    #[test]
    fn test_sql_escape_string() {
        let result = render_sql_value(Value::from("hello"));
        assert_eq!(result, "'hello'");
    }

    #[test]
    fn test_sql_escape_string_injection_attempt() {
        // Embedded quotes must be doubled so the payload stays inside the literal.
        let result = render_sql_value(Value::from("'; DROP TABLE users; --"));
        assert_eq!(result, "'''; DROP TABLE users; --'");
    }

    #[test]
    fn test_sql_escape_integer() {
        let result = render_sql_value(Value::from(42));
        assert_eq!(result, "42");
    }

    #[test]
    fn test_sql_escape_bool() {
        let result = render_sql_value(Value::from(true));
        assert_eq!(result, "TRUE");
    }

    #[test]
    fn test_sql_escape_none() {
        let result = render_sql_value(Value::from(()));
        assert_eq!(result, "NULL");
    }

    #[test]
    fn test_sql_escape_seq() {
        let result = render_sql_value(Value::from(vec![1, 2, 3]));
        assert_eq!(result, "ARRAY[1, 2, 3]");
    }

    #[test]
    fn test_sql_escape_bytes() {
        let bytes = Value::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let result = render_sql_value(bytes);
        assert_eq!(result, "'\\xdeadbeef'::bytea");
    }

    #[test]
    fn test_sql_escape_for_non_sql_templates() {
        // Escaping applies even to a template named "test.txt", as asserted here.
        let mut env = Environment::new();
        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
        env.set_formatter(get_formatter(SqlDialect::Postgres));
        env.add_template("test.txt", "{{ value }}").unwrap();
        let tmpl = env.get_template("test.txt").unwrap();
        let result = tmpl.render(context!(value => "hello")).unwrap();
        assert_eq!(result, "'hello'");
    }

    #[test]
    fn test_sql_safe_filter_bypasses_escaping() {
        // `|safe` marks the value as pre-escaped; it renders verbatim.
        let mut env = Environment::new();
        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
        env.set_formatter(get_formatter(SqlDialect::Postgres));
        env.add_template("test.sql", "{{ value|safe }}").unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        let result = tmpl.render(context!(value => "raw SQL here")).unwrap();
        assert_eq!(result, "raw SQL here");
    }

    #[test]
    fn test_sql_escape_only_on_output_not_in_loops() {
        // Loop internals are untouched; only emitted values get quoted.
        let mut env = Environment::new();
        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
        env.set_formatter(get_formatter(SqlDialect::Postgres));

        let template =
            r#"{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}"#;
        env.add_template("test.sql", template).unwrap();
        let tmpl = env.get_template("test.sql").unwrap();

        let items = vec!["alice", "bob", "charlie"];
        let result = tmpl.render(context!(items => items)).unwrap();
        assert_eq!(result, "'alice', 'bob', 'charlie'");
    }

    // --- base64_encode filter ---

    #[test]
    fn test_base64_encode_filter() {
        let bytes = Value::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let result = base64_encode_filter(&bytes).unwrap();
        assert_eq!(result.to_string(), "3q2+7w==");
    }

    #[test]
    fn test_base64_encode_filter_text() {
        let bytes = Value::from_bytes(b"hello world".to_vec());
        let result = base64_encode_filter(&bytes).unwrap();
        assert_eq!(result.to_string(), "aGVsbG8gd29ybGQ=");
    }

    #[test]
    fn test_base64_encode_filter_string() {
        let value = Value::from("hello world");
        let result = base64_encode_filter(&value).unwrap();
        assert_eq!(result.to_string(), "aGVsbG8gd29ybGQ=");
    }

    #[test]
    fn test_base64_encode_filter_rejects_other_types() {
        let value = Value::from(42);
        let result = base64_encode_filter(&value);
        assert!(result.is_err());
    }

    // --- to_string_lossy filter ---

    #[test]
    fn test_to_string_lossy_filter_valid_utf8() {
        let bytes = Value::from_bytes(b"hello world".to_vec());
        let result = to_string_lossy_filter(&bytes).unwrap();
        assert_eq!(result.to_string(), "hello world");
    }

    #[test]
    fn test_to_string_lossy_filter_invalid_utf8() {
        // 0xFF is invalid UTF-8 and must become U+FFFD.
        let bytes = Value::from_bytes(vec![0x68, 0x65, 0x6C, 0xFF, 0x6F]);
        let result = to_string_lossy_filter(&bytes).unwrap();
        let s = result.to_string();
        assert!(s.contains("hel"));
        assert!(s.contains('\u{FFFD}'));
        assert!(s.contains('o'));
    }

    #[test]
    fn test_to_string_lossy_filter_passes_through_string() {
        let value = Value::from("already a string");
        let result = to_string_lossy_filter(&value).unwrap();
        assert_eq!(result.to_string(), "already a string");
    }

    #[test]
    fn test_to_string_lossy_filter_rejects_other_types() {
        let value = Value::from(42);
        let result = to_string_lossy_filter(&value);
        assert!(result.is_err());
    }

    // --- end-to-end store-backed filters ---
    // multi_thread flavor is required: the filters call block_in_place.

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_filter_with_store() {
        use crate::config::FolderPather;
        use crate::store::pinner::latest::Latest;
        use opendal::services::Memory;
        use opendal::Operator;

        let mem_service = Memory::default();
        let op = Operator::new(mem_service).unwrap().finish();
        op.write("components/test.txt", "file contents here")
            .await
            .unwrap();

        let pinner = Latest::new("").unwrap();
        let pather = FolderPather {
            spawn_folder: "".to_string(),
        };
        let store = Store::new(Box::new(pinner), op, pather).unwrap();

        let mut env = template_env(store, &EngineType::PostgresPSQL).unwrap();
        env.add_template(
            "test.sql",
            r#"{{ "test.txt"|read_file|to_string_lossy|safe }}"#,
        )
        .unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        let result = tmpl.render(context!()).unwrap();
        assert_eq!(result, "file contents here");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_with_base64_encode() {
        use crate::config::FolderPather;
        use crate::store::pinner::latest::Latest;
        use opendal::services::Memory;
        use opendal::Operator;

        let mem_service = Memory::default();
        let op = Operator::new(mem_service).unwrap().finish();
        op.write("components/binary.dat", vec![0xDE, 0xAD, 0xBE, 0xEF])
            .await
            .unwrap();

        let pinner = Latest::new("").unwrap();
        let pather = FolderPather {
            spawn_folder: "".to_string(),
        };
        let store = Store::new(Box::new(pinner), op, pather).unwrap();

        let mut env = template_env(store, &EngineType::PostgresPSQL).unwrap();
        env.add_template(
            "test.sql",
            r#"{{ "binary.dat"|read_file|base64_encode|safe }}"#,
        )
        .unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        let result = tmpl.render(context!()).unwrap();
        assert_eq!(result, "3q2+7w==");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_missing_file_returns_error() {
        use crate::config::FolderPather;
        use crate::store::pinner::latest::Latest;
        use opendal::services::Memory;
        use opendal::Operator;

        let mem_service = Memory::default();
        let op = Operator::new(mem_service).unwrap().finish();

        let pinner = Latest::new("").unwrap();
        let pather = FolderPather {
            spawn_folder: "".to_string(),
        };
        let store = Store::new(Box::new(pinner), op, pather).unwrap();

        let mut env = template_env(store, &EngineType::PostgresPSQL).unwrap();
        env.add_template(
            "test.sql",
            r#"{{ "nonexistent.txt"|read_file|to_string_lossy }}"#,
        )
        .unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        let result = tmpl.render(context!());
        assert!(result.is_err());
    }
}