Skip to main content

spawn_db/
template.rs

1use crate::config;
2use crate::engine::EngineType;
3use crate::escape::EscapedIdentifier;
4use crate::store::pinner::latest::Latest;
5use crate::store::pinner::spawn::Spawn;
6use crate::store::pinner::Pinner;
7use crate::store::Store;
8use crate::variables::Variables;
9use minijinja::{Environment, Value};
10
11use crate::sql_formatter::SqlDialect;
12use base64::{engine::general_purpose::STANDARD, Engine as _};
13use uuid::Uuid;
14
15use anyhow::{Context, Result};
16use minijinja::context;
17use std::sync::Arc;
18
19/// Maps an EngineType to the appropriate SQL dialect for formatting.
20///
21/// Multiple engine types may share the same dialect. For example,
22/// both a psql CLI engine and a native PostgreSQL driver would use
23/// the Postgres dialect.
24fn engine_to_dialect(engine: &EngineType) -> SqlDialect {
25    match engine {
26        EngineType::PostgresPSQL => SqlDialect::Postgres,
27        // Future engines:
28        // EngineType::PostgresNative => SqlDialect::Postgres,
29        // EngineType::MySQL => SqlDialect::MySQL,
30        // EngineType::SqlServer => SqlDialect::SqlServer,
31    }
32}
33
34pub fn template_env(store: Store, engine: &EngineType) -> Result<Environment<'static>> {
35    let mut env = Environment::new();
36
37    let store = Arc::new(store);
38
39    let mj_store = MiniJinjaLoader {
40        store: Arc::clone(&store),
41    };
42    env.set_loader(move |name: &str| mj_store.load(name));
43    env.add_function("gen_uuid_v4", gen_uuid_v4);
44    env.add_function("gen_uuid_v5", gen_uuid_v5);
45    env.add_filter("escape_identifier", escape_identifier_filter);
46
47    let read_file_store = Arc::clone(&store);
48    env.add_filter(
49        "read_file",
50        move |path: &str| -> Result<Value, minijinja::Error> {
51            read_file_filter(path, &read_file_store)
52        },
53    );
54    env.add_filter("base64_encode", base64_encode_filter);
55    env.add_filter("to_string_lossy", to_string_lossy_filter);
56    env.add_filter("parse_json", parse_json_filter);
57    env.add_filter("parse_toml", parse_toml_filter);
58    env.add_filter("parse_yaml", parse_yaml_filter);
59
60    let read_json_store = Arc::clone(&store);
61    env.add_filter(
62        "read_json",
63        move |path: &str| -> Result<Value, minijinja::Error> {
64            let bytes = read_file_bytes(path, &read_json_store)?;
65            let s = string_from_bytes(&bytes)?;
66            parse_json_filter(&s)
67        },
68    );
69    let read_toml_store = Arc::clone(&store);
70    env.add_filter(
71        "read_toml",
72        move |path: &str| -> Result<Value, minijinja::Error> {
73            let bytes = read_file_bytes(path, &read_toml_store)?;
74            let s = string_from_bytes(&bytes)?;
75            parse_toml_filter(&s)
76        },
77    );
78    let read_yaml_store = Arc::clone(&store);
79    env.add_filter(
80        "read_yaml",
81        move |path: &str| -> Result<Value, minijinja::Error> {
82            let bytes = read_file_bytes(path, &read_yaml_store)?;
83            let s = string_from_bytes(&bytes)?;
84            parse_yaml_filter(&s)
85        },
86    );
87
88    // Get the appropriate dialect for this engine
89    let dialect = engine_to_dialect(engine);
90
91    // Enable SQL auto-escaping for .sql files using the dialect-specific callback
92    env.set_auto_escape_callback(crate::sql_formatter::get_auto_escape_callback(dialect));
93
94    // Set custom formatter that handles SQL escaping based on the dialect
95    env.set_formatter(crate::sql_formatter::get_formatter(dialect));
96
97    Ok(env)
98}
99
100struct MiniJinjaLoader {
101    pub store: Arc<Store>,
102}
103
104impl MiniJinjaLoader {
105    pub fn load(&self, name: &str) -> std::result::Result<Option<String>, minijinja::Error> {
106        let result = tokio::task::block_in_place(|| {
107            tokio::runtime::Handle::current()
108                .block_on(async { self.store.load_component(name).await })
109        });
110
111        result.map_err(|e| {
112            minijinja::Error::new(
113                minijinja::ErrorKind::InvalidOperation,
114                format!("Failed to load from object store: {}", e),
115            )
116        })
117    }
118}
119
120fn gen_uuid_v4() -> Result<String, minijinja::Error> {
121    Ok(Uuid::new_v4().to_string())
122}
123
124fn gen_uuid_v5(seed: &str) -> Result<String, minijinja::Error> {
125    Ok(Uuid::new_v5(&Uuid::NAMESPACE_DNS, seed.as_bytes()).to_string())
126}
127
128/// Filter to escape a value as a PostgreSQL identifier (e.g., database name, table name).
129///
130/// This wraps the value in double quotes and escapes any embedded double quotes,
131/// making it safe to use in SQL statements where an identifier is expected.
132///
133/// Usage in templates: `{{ dbname|escape_identifier }}`
134fn escape_identifier_filter(value: &Value) -> Result<Value, minijinja::Error> {
135    let s = value.to_string();
136    let escaped = EscapedIdentifier::new(&s);
137    // Return as a safe string so it won't be further escaped by the SQL formatter
138    Ok(Value::from_safe_string(escaped.to_string()))
139}
140
141/// Reads raw bytes from a file in the components folder via the Store.
142fn read_file_bytes(path: &str, store: &Arc<Store>) -> Result<Vec<u8>, minijinja::Error> {
143    let bytes = tokio::task::block_in_place(|| {
144        tokio::runtime::Handle::current().block_on(async { store.read_file_bytes(path).await })
145    });
146
147    bytes.map_err(|e| {
148        minijinja::Error::new(
149            minijinja::ErrorKind::InvalidOperation,
150            format!("Failed to read file '{}': {}", path, e),
151        )
152    })
153}
154
155/// Converts raw bytes to a UTF-8 string, returning an error on invalid UTF-8.
156fn string_from_bytes(bytes: &[u8]) -> Result<String, minijinja::Error> {
157    String::from_utf8(bytes.to_vec()).map_err(|e| {
158        minijinja::Error::new(
159            minijinja::ErrorKind::InvalidOperation,
160            format!("File is not valid UTF-8: {}", e),
161        )
162    })
163}
164
165/// Filter to read a file from the components folder and return its contents as raw bytes.
166///
167/// Returns a bytes Value that can be further processed with `base64_encode` or `to_string_lossy`.
168///
169/// Usage in templates: `{{ "path/to/file"|read_file|to_string_lossy }}`
170fn read_file_filter(path: &str, store: &Arc<Store>) -> Result<Value, minijinja::Error> {
171    Ok(Value::from_bytes(read_file_bytes(path, store)?))
172}
173
174/// Filter to encode a value as a base64 string.
175///
176/// Accepts both bytes (e.g. from `read_file`) and strings.
177///
178/// Usage in templates: `{{ "path/to/file"|read_file|base64_encode }}`
179fn base64_encode_filter(value: &Value) -> Result<Value, minijinja::Error> {
180    use minijinja::value::ValueKind;
181    match value.kind() {
182        ValueKind::Bytes => {
183            let bytes = value.as_bytes().unwrap();
184            Ok(Value::from(STANDARD.encode(bytes)))
185        }
186        ValueKind::String => Ok(Value::from(STANDARD.encode(value.as_str().unwrap()))),
187        _ => Err(minijinja::Error::new(
188            minijinja::ErrorKind::InvalidOperation,
189            "base64_encode filter expects bytes or string input",
190        )),
191    }
192}
193
194/// Filter to convert bytes to a string, replacing invalid UTF-8 sequences.
195/// If the value is already a string, it is returned as-is.
196///
197/// Usage in templates: `{{ "path/to/file.txt"|read_file|to_string_lossy }}`
198fn to_string_lossy_filter(value: &Value) -> Result<Value, minijinja::Error> {
199    use minijinja::value::ValueKind;
200    match value.kind() {
201        ValueKind::Bytes => {
202            let bytes = value.as_bytes().unwrap();
203            Ok(Value::from(String::from_utf8_lossy(bytes).into_owned()))
204        }
205        ValueKind::String => Ok(value.clone()),
206        _ => Err(minijinja::Error::new(
207            minijinja::ErrorKind::InvalidOperation,
208            "to_string_lossy filter expects bytes or string input",
209        )),
210    }
211}
212
213/// Filter to parse a JSON string into a template value.
214///
215/// Usage in templates: `{{ "data.json"|read_file|to_string_lossy|parse_json }}`
216fn parse_json_filter(value: &str) -> Result<Value, minijinja::Error> {
217    let vars = Variables::from_str("json", value).map_err(|e| {
218        minijinja::Error::new(
219            minijinja::ErrorKind::InvalidOperation,
220            format!("parse_json: {}", e),
221        )
222    })?;
223    Ok(Value::from_serialize(&vars))
224}
225
226/// Filter to parse a TOML string into a template value.
227///
228/// Usage in templates: `{{ "config.toml"|read_file|to_string_lossy|parse_toml }}`
229fn parse_toml_filter(value: &str) -> Result<Value, minijinja::Error> {
230    let vars = Variables::from_str("toml", value).map_err(|e| {
231        minijinja::Error::new(
232            minijinja::ErrorKind::InvalidOperation,
233            format!("parse_toml: {}", e),
234        )
235    })?;
236    Ok(Value::from_serialize(&vars))
237}
238
239/// Filter to parse a YAML string into a template value.
240///
241/// Usage in templates: `{{ "data.yaml"|read_file|to_string_lossy|parse_yaml }}`
242fn parse_yaml_filter(value: &str) -> Result<Value, minijinja::Error> {
243    let vars = Variables::from_str("yaml", value).map_err(|e| {
244        minijinja::Error::new(
245            minijinja::ErrorKind::InvalidOperation,
246            format!("parse_yaml: {}", e),
247        )
248    })?;
249    Ok(Value::from_serialize(&vars))
250}
251
/// Generation output held entirely in memory.
///
/// NOTE(review): nothing in this file constructs this type; `content` is
/// presumably the fully rendered SQL text — confirm against callers.
#[derive(Debug, Clone)]
pub struct Generation {
    /// The generated text.
    pub content: String,
}
255
256/// Holds all the data needed to render a template to a writer.
257/// This struct is Send and can be moved into a WriterFn closure.
258pub struct StreamingGeneration {
259    store: Store,
260    template_contents: String,
261    environment: String,
262    variables: Variables,
263    engine: EngineType,
264}
265
266impl StreamingGeneration {
267    /// Render the template to the provided writer.
268    /// This creates the minijinja environment and renders in one step.
269    pub fn render_to_writer<W: std::io::Write + ?Sized>(self, writer: &mut W) -> Result<()> {
270        let mut env = template_env(self.store, &self.engine)?;
271        env.add_template("migration.sql", &self.template_contents)?;
272        let tmpl = env.get_template("migration.sql")?;
273        tmpl.render_to_write(
274            context!(env => self.environment, variables => self.variables),
275            writer,
276        )?;
277        Ok(())
278    }
279
280    /// Convert this streaming generation into a WriterFn that can be passed to migration_apply.
281    pub fn into_writer_fn(self) -> crate::engine::WriterFn {
282        Box::new(move |writer: &mut dyn std::io::Write| {
283            self.render_to_writer(writer)
284                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
285        })
286    }
287}
288
289/// Generate a streaming migration that can be rendered directly to a writer.
290/// This avoids materializing the entire SQL in memory.
291pub async fn generate_streaming(
292    cfg: &config::Config,
293    lock_file: Option<String>,
294    name: &str,
295    variables: Option<Variables>,
296) -> Result<StreamingGeneration> {
297    let pinner: Box<dyn Pinner> = if let Some(lock_file) = lock_file {
298        let lock = cfg
299            .load_lock_file(&lock_file)
300            .await
301            .context("could not load pinned files lock file")?;
302        let pinner = Spawn::new_with_root_hash(
303            cfg.pather().pinned_folder(),
304            cfg.pather().components_folder(),
305            &lock.pin,
306            &cfg.operator(),
307        )
308        .await
309        .context("could not get new root with hash")?;
310        Box::new(pinner)
311    } else {
312        let pinner = Latest::new(cfg.pather().spawn_folder_path())?;
313        Box::new(pinner)
314    };
315
316    let store = Store::new(pinner, cfg.operator().clone(), cfg.pather())
317        .context("could not create new store for generate")?;
318    let db_config = cfg
319        .db_config()
320        .context("could not get db config for generate")?;
321
322    generate_streaming_with_store(
323        name,
324        variables,
325        &db_config.environment,
326        &db_config.engine,
327        store,
328    )
329    .await
330}
331
332/// Generate a streaming migration with an existing store.
333pub async fn generate_streaming_with_store(
334    name: &str,
335    variables: Option<Variables>,
336    environment: &str,
337    engine: &EngineType,
338    store: Store,
339) -> Result<StreamingGeneration> {
340    // Read contents from our object store first:
341    let contents = store
342        .load_migration(name)
343        .await
344        .context("generate_streaming_with_store could not read migration")?;
345
346    Ok(StreamingGeneration {
347        store,
348        template_contents: contents,
349        environment: environment.to_string(),
350        variables: variables.unwrap_or_default(),
351        engine: engine.clone(),
352    })
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::FolderPather;
    use crate::sql_formatter::{get_auto_escape_callback, get_formatter};
    use crate::store::pinner::latest::Latest;
    use crate::store::pinner::snapshot;
    use crate::store::pinner::spawn::Spawn;
    use minijinja::{context, Environment, Value};
    use opendal::services::Memory;
    use opendal::Operator;

    /// Environment with only the Postgres auto-escape callback and formatter installed.
    fn postgres_env() -> Environment<'static> {
        let mut env = Environment::new();
        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
        env.set_formatter(get_formatter(SqlDialect::Postgres));
        env
    }

    /// Registers `source` as `name` in a Postgres environment and renders it with `ctx`.
    fn render_postgres(name: &'static str, source: &'static str, ctx: Value) -> String {
        let mut env = postgres_env();
        env.add_template(name, source).unwrap();
        let tmpl = env.get_template(name).unwrap();
        tmpl.render(ctx).unwrap()
    }

    /// Helper to test SQL formatting of a value by rendering it in a .sql template
    fn render_sql_value(value: Value) -> String {
        render_postgres("test.sql", "{{ value }}", context!(value => value))
    }

    /// Fresh in-memory opendal operator.
    fn memory_operator() -> Operator {
        let mem_service = Memory::default();
        Operator::new(mem_service).unwrap().finish()
    }

    /// Pather rooted at the empty spawn folder.
    fn root_pather() -> FolderPather {
        FolderPather {
            spawn_folder: "".to_string(),
        }
    }

    /// Store over `op` that resolves components with the `Latest` pinner.
    fn latest_store(op: Operator) -> Store {
        let pinner = Latest::new("").unwrap();
        Store::new(Box::new(pinner), op, root_pather()).unwrap()
    }

    /// Renders `source` as `test.sql` in the full template environment backed by `store`.
    fn render_from_store(
        store: Store,
        source: &'static str,
    ) -> std::result::Result<String, minijinja::Error> {
        let mut env = template_env(store, &EngineType::PostgresPSQL).unwrap();
        env.add_template("test.sql", source).unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        tmpl.render(context!())
    }

    #[test]
    fn test_engine_to_dialect_postgres_psql() {
        assert_eq!(
            engine_to_dialect(&EngineType::PostgresPSQL),
            SqlDialect::Postgres
        );
    }

    // Basic escaping tests - verify the integration with spawn-sql-format works
    // More comprehensive tests are in the spawn-sql-format crate itself

    #[test]
    fn test_sql_escape_string() {
        assert_eq!(render_sql_value(Value::from("hello")), "'hello'");
    }

    #[test]
    fn test_sql_escape_string_injection_attempt() {
        let rendered = render_sql_value(Value::from("'; DROP TABLE users; --"));
        assert_eq!(rendered, "'''; DROP TABLE users; --'");
    }

    #[test]
    fn test_sql_escape_integer() {
        assert_eq!(render_sql_value(Value::from(42)), "42");
    }

    #[test]
    fn test_sql_escape_bool() {
        assert_eq!(render_sql_value(Value::from(true)), "TRUE");
    }

    #[test]
    fn test_sql_escape_none() {
        assert_eq!(render_sql_value(Value::from(())), "NULL");
    }

    #[test]
    fn test_sql_escape_seq() {
        let rendered = render_sql_value(Value::from(vec![1, 2, 3]));
        assert_eq!(rendered, "ARRAY[1, 2, 3]");
    }

    #[test]
    fn test_sql_escape_bytes() {
        let payload = Value::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(render_sql_value(payload), "'\\xdeadbeef'::bytea");
    }

    #[test]
    fn test_sql_escape_for_non_sql_templates() {
        // Use .txt extension - should still trigger SQL escaping
        let rendered = render_postgres("test.txt", "{{ value }}", context!(value => "hello"));
        // SQL escaping applies to all files
        assert_eq!(rendered, "'hello'");
    }

    #[test]
    fn test_sql_safe_filter_bypasses_escaping() {
        // Using |safe filter should bypass escaping
        let rendered = render_postgres(
            "test.sql",
            "{{ value|safe }}",
            context!(value => "raw SQL here"),
        );
        // Should be output as-is without quotes
        assert_eq!(rendered, "raw SQL here");
    }

    #[test]
    fn test_sql_escape_only_on_output_not_in_loops() {
        let items = vec!["alice", "bob", "charlie"];
        let rendered = render_postgres(
            "test.sql",
            r#"{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}"#,
            context!(items => items),
        );
        assert_eq!(rendered, "'alice', 'bob', 'charlie'");
    }

    #[test]
    fn test_base64_encode_filter() {
        let input = Value::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let encoded = base64_encode_filter(&input).unwrap();
        assert_eq!(encoded.to_string(), "3q2+7w==");
    }

    #[test]
    fn test_base64_encode_filter_text() {
        let input = Value::from_bytes(b"hello world".to_vec());
        let encoded = base64_encode_filter(&input).unwrap();
        assert_eq!(encoded.to_string(), "aGVsbG8gd29ybGQ=");
    }

    #[test]
    fn test_base64_encode_filter_string() {
        let input = Value::from("hello world");
        let encoded = base64_encode_filter(&input).unwrap();
        assert_eq!(encoded.to_string(), "aGVsbG8gd29ybGQ=");
    }

    #[test]
    fn test_base64_encode_filter_rejects_other_types() {
        assert!(base64_encode_filter(&Value::from(42)).is_err());
    }

    #[test]
    fn test_to_string_lossy_filter_valid_utf8() {
        let input = Value::from_bytes(b"hello world".to_vec());
        let text = to_string_lossy_filter(&input).unwrap();
        assert_eq!(text.to_string(), "hello world");
    }

    #[test]
    fn test_to_string_lossy_filter_invalid_utf8() {
        let input = Value::from_bytes(vec![0x68, 0x65, 0x6C, 0xFF, 0x6F]);
        let text = to_string_lossy_filter(&input).unwrap().to_string();
        assert!(text.contains("hel"));
        assert!(text.contains('\u{FFFD}'));
        assert!(text.contains('o'));
    }

    #[test]
    fn test_to_string_lossy_filter_passes_through_string() {
        let input = Value::from("already a string");
        let text = to_string_lossy_filter(&input).unwrap();
        assert_eq!(text.to_string(), "already a string");
    }

    #[test]
    fn test_to_string_lossy_filter_rejects_other_types() {
        assert!(to_string_lossy_filter(&Value::from(42)).is_err());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_filter_with_store() {
        // Set up an in-memory operator with a test file in the components folder
        let op = memory_operator();
        op.write("components/test.txt", "file contents here")
            .await
            .unwrap();

        let store = latest_store(op);
        let rendered =
            render_from_store(store, r#"{{ "test.txt"|read_file|to_string_lossy|safe }}"#)
                .unwrap();
        assert_eq!(rendered, "file contents here");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_with_base64_encode() {
        let op = memory_operator();
        op.write("components/binary.dat", vec![0xDE, 0xAD, 0xBE, 0xEF])
            .await
            .unwrap();

        let store = latest_store(op);
        let rendered =
            render_from_store(store, r#"{{ "binary.dat"|read_file|base64_encode|safe }}"#)
                .unwrap();
        assert_eq!(rendered, "3q2+7w==");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_missing_file_returns_error() {
        let store = latest_store(memory_operator());
        let outcome = render_from_store(
            store,
            r#"{{ "nonexistent.txt"|read_file|to_string_lossy }}"#,
        );
        assert!(outcome.is_err());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_filter_uses_pinned_store() {
        let op = memory_operator();

        // Write a file into the components folder and snapshot it into the pinned store
        op.write("components/test.txt", "pinned content")
            .await
            .unwrap();
        let root_hash = snapshot(&op, "pinned/", "components/").await.unwrap();

        // Delete the original file so it only exists in the pinned CAS store
        op.delete("components/test.txt").await.unwrap();

        // Create a Spawn pinner using the snapshot hash
        let pinner = Spawn::new_with_root_hash(
            "pinned/".to_string(),
            "components/".to_string(),
            &root_hash,
            &op,
        )
        .await
        .unwrap();

        let store = Store::new(Box::new(pinner), op, root_pather()).unwrap();
        let rendered =
            render_from_store(store, r#"{{ "test.txt"|read_file|to_string_lossy|safe }}"#)
                .unwrap();
        assert_eq!(rendered, "pinned content");
    }
}