wick_sql/
component.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime};
5
6use flow_component::{BoxFuture, Component, ComponentError, LocalScope};
7use futures::StreamExt;
8use once_cell::sync::Lazy;
9use regex::{Captures, Regex};
10use tracing::Span;
11use url::Url;
12use wick_config::config::components::{ComponentConfig, OperationConfig, SqlComponentConfig, SqlOperationDefinition};
13use wick_config::config::{ErrorBehavior, Metadata};
14use wick_config::Resolver;
15use wick_interface_types::{ComponentSignature, Field, OperationSignatures, Type};
16use wick_packet::{Invocation, Observer, Packet, PacketExt, PacketSender, PacketStream, RuntimeConfig};
17
18use crate::common::{Connection, DatabaseProvider};
19use crate::{common, Error};
20
/// The database backend a SQL component is configured for.
///
/// The variant is picked from the scheme of the resource URL and decides how
/// operation queries are normalized before they reach the provider.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DbKind {
  /// Microsoft SQL Server (`mssql` URLs).
  Mssql,
  /// PostgreSQL (`postgres` URLs).
  Postgres,
  /// SQLite (`file` or `sqlite` URLs).
  Sqlite,
}
27#[derive(Clone)]
28struct Client {
29  inner: Arc<dyn DatabaseProvider + Send + Sync>,
30}
31
32impl Client {
33  async fn new(
34    url: &Url,
35    config: &mut SqlComponentConfig,
36    _metadata: Option<Metadata>,
37    _root_config: Option<RuntimeConfig>, // TODO use this
38    resolver: &Resolver,
39  ) -> Result<Self, Error> {
40    let client: Arc<dyn DatabaseProvider + Send + Sync> = match url.scheme() {
41      "mssql" => {
42        normalize_operations(config.operations_mut(), DbKind::Mssql);
43        Arc::new(crate::mssql_tiberius::AzureSqlComponent::new(config.clone(), resolver).await?)
44      }
45      "postgres" => {
46        normalize_operations(config.operations_mut(), DbKind::Postgres);
47        Arc::new(crate::sqlx::SqlXComponent::new(config.clone(), resolver).await?)
48      }
49      "file" | "sqlite" => {
50        normalize_operations(config.operations_mut(), DbKind::Sqlite);
51        Arc::new(crate::sqlx::SqlXComponent::new(config.clone(), resolver).await?)
52      }
53      _ => return Err(Error::InvalidScheme(url.scheme().to_owned())),
54    };
55
56    Ok(Self { inner: client })
57  }
58
59  fn inner(&self) -> &Arc<dyn DatabaseProvider + Sync + Send> {
60    &self.inner
61  }
62}
63
64#[async_trait::async_trait]
65impl DatabaseProvider for Client {
66  fn get_statement<'a>(&'a self, id: &'a str) -> Option<&'a str> {
67    self.inner().get_statement(id)
68  }
69
70  async fn get_connection<'a, 'b>(&'a self) -> Result<Connection<'b>, Error>
71  where
72    'a: 'b,
73  {
74    self.inner().get_connection().await
75  }
76}
77
78/// The Azure SQL Wick component.
79#[derive(Clone)]
80#[must_use]
81pub struct SqlComponent {
82  provider: Client,
83  signature: Arc<ComponentSignature>,
84  url: Url,
85  config: SqlComponentConfig,
86  root_config: Option<RuntimeConfig>,
87}
88
89impl std::fmt::Debug for SqlComponent {
90  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91    f.debug_struct("SqlComponent")
92      .field("signature", &self.signature)
93      .field("url", &self.url)
94      .field("config", &self.config)
95      .field("root_config", &self.root_config)
96      .finish()
97  }
98}
99
100impl SqlComponent {
101  /// Instantiate a new Azure SQL component.
102  pub async fn new(
103    mut config: SqlComponentConfig,
104    root_config: Option<RuntimeConfig>,
105    metadata: Option<Metadata>,
106    resolver: &Resolver,
107  ) -> Result<Self, ComponentError> {
108    validate(&config, resolver)?;
109    let sig = common::gen_signature(
110      "wick/component/sql",
111      config.operation_signatures(),
112      config.config(),
113      &metadata,
114    )?;
115
116    let url = common::convert_url_resource(resolver, config.resource())?;
117
118    validate(&config, resolver)?;
119    let provider = Client::new(&url, &mut config, metadata, root_config.clone(), resolver).await?;
120
121    Ok(Self {
122      provider,
123      signature: Arc::new(sig),
124      url,
125      root_config,
126      config,
127    })
128  }
129}
130
131impl Component for SqlComponent {
132  fn handle(
133    &self,
134    invocation: Invocation,
135    _data: Option<RuntimeConfig>, // TODO: this needs to be used
136    _callback: LocalScope,
137  ) -> BoxFuture<Result<PacketStream, ComponentError>> {
138    let client = self.provider.clone();
139    let opdef = self
140      .config
141      .get_operation(invocation.target().operation_id())
142      .ok_or_else(|| Error::MissingOperation(invocation.target().operation_id().to_owned()))
143      .cloned();
144
145    Box::pin(async move {
146      let opdef = opdef?;
147      let stmt = client.get_statement(opdef.name()).unwrap().to_owned();
148
149      let input_names: Vec<_> = opdef.inputs().iter().map(|i| i.name.clone()).collect();
150      let (invocation, stream) = invocation.split();
151      let input_streams = wick_packet::split_stream(stream, input_names);
152      let (tx, rx) = invocation.make_response();
153      tokio::spawn(async move {
154        let start = SystemTime::now();
155        let span = invocation.span.clone();
156        let Ok(mut connection) = client.get_connection().await else {
157          invocation.trace(|| {
158            error!("could not get connection to database");
159          });
160          return;
161        };
162        if let Err(e) = handle_call(&mut connection, opdef, input_streams, tx.clone(), &stmt, span).await {
163          invocation.trace(|| {
164            error!(error = %e, "error handling sql operation");
165          });
166          let _ = tx.error(wick_packet::Error::component_error(e.to_string()));
167        }
168        let _ = tx.send(Packet::done("output"));
169        let duration = SystemTime::now().duration_since(start).unwrap();
170        invocation.trace(|| {
171          debug!(?duration, target=%invocation.target,"mssql operation complete");
172        });
173      });
174
175      Ok(rx)
176    })
177  }
178
179  fn signature(&self) -> &ComponentSignature {
180    &self.signature
181  }
182}
183
184fn validate(config: &SqlComponentConfig, _resolver: &Resolver) -> Result<(), Error> {
185  let bad_ops: Vec<_> = config
186    .operations()
187    .iter()
188    .filter(|op| {
189      let outputs = op.outputs();
190      outputs.len() > 1 || outputs.len() == 1 && outputs[0] != Field::new("output", Type::Object)
191    })
192    .map(|op| op.name().to_owned())
193    .collect();
194
195  if !bad_ops.is_empty() {
196    return Err(Error::InvalidOutput(bad_ops));
197  }
198
199  Ok(())
200}
201
202async fn handle_call<'a, 'b, 'c>(
203  connection: &'a mut Connection<'c>,
204  opdef: SqlOperationDefinition,
205  input_streams: Vec<PacketStream>,
206  tx: PacketSender,
207  stmt: &'b str,
208  span: Span,
209) -> Result<(), Error>
210where
211  'b: 'a,
212{
213  let error_behavior = opdef.on_error();
214
215  connection.start(error_behavior).await?;
216
217  let result = handle_stream(connection, opdef, input_streams, tx, stmt, span.clone()).await;
218  if let Err(e) = result {
219    span.in_scope(|| error!(error = %e, "error in sql operation"));
220    let err = Error::OperationFailed(e.to_string());
221    connection.handle_error(e, error_behavior).await?;
222    return Err(err);
223  }
224  connection.finish().await?;
225
226  Ok(())
227}
228
229async fn handle_stream<'a, 'b, 'c>(
230  connection: &'a mut Connection<'c>,
231  opdef: SqlOperationDefinition,
232  mut input_streams: Vec<PacketStream>,
233  tx: PacketSender,
234  stmt: &'b str,
235  span: Span,
236) -> Result<(), Error>
237where
238  'b: 'a,
239{
240  span.in_scope(|| debug!(stmt = %stmt, "preparing query for stream"));
241  'outer: loop {
242    let mut incoming_packets = Vec::new();
243
244    for input in &mut input_streams {
245      let packet = input.next().await;
246
247      incoming_packets.push(packet);
248    }
249
250    let num_done = incoming_packets.iter().filter(|r| r.is_none()).count();
251    if num_done > 0 {
252      if num_done != opdef.inputs().len() {
253        return Err(Error::MissingInput);
254      }
255      break 'outer;
256    }
257    let incoming_packets = incoming_packets.into_iter().map(|r| r.unwrap()).collect::<Vec<_>>();
258
259    if let Some(Err(e)) = incoming_packets.iter().find(|r| r.is_err()) {
260      return Err(Error::ComponentError(e.clone()));
261    }
262    let fields = opdef.inputs();
263    let mut type_wrappers = Vec::new();
264
265    for packet in incoming_packets {
266      let packet = packet.unwrap();
267      if packet.is_done() {
268        break 'outer;
269      }
270      if packet.is_open_bracket() || packet.is_close_bracket() {
271        let _ = tx.send(packet.to_port("output"));
272        continue 'outer;
273      }
274      let ty = fields.iter().find(|f| f.name() == packet.port()).unwrap().ty().clone();
275      type_wrappers.push((ty, packet));
276    }
277
278    let start = SystemTime::now();
279    let result = match &opdef {
280      SqlOperationDefinition::Query(_) => {
281        query(connection, tx.clone(), opdef.clone(), type_wrappers, stmt, span.clone()).await
282      }
283      SqlOperationDefinition::Exec(_) => {
284        exec(connection, tx.clone(), opdef.clone(), type_wrappers, stmt, span.clone()).await
285      }
286    };
287    let duration = SystemTime::now().duration_since(start).unwrap();
288
289    span.in_scope(|| debug!(μs = duration.as_micros(), "executed query"));
290
291    if let Err(e) = result {
292      if opdef.on_error() == ErrorBehavior::Ignore {
293        let _ = tx.send(Packet::err("output", e.to_string()));
294      } else {
295        return Err(Error::ErrorInStream(e.to_string()));
296      }
297    };
298
299    if opdef.inputs().len() == 0 {
300      break 'outer;
301    }
302  }
303  Ok(())
304}
305
306async fn query<'a, 'b, 'c>(
307  client: &'a mut Connection<'c>,
308  tx: PacketSender,
309  def: SqlOperationDefinition,
310  args: Vec<(Type, Packet)>,
311  stmt: &'b str,
312  _span: Span,
313) -> Result<Duration, Error>
314where
315  'b: 'a,
316{
317  let start = SystemTime::now();
318
319  let bound_args = common::bind_args(def.arguments(), &args)?;
320
321  let mut rows = client.query(stmt, bound_args).await?;
322
323  while let Some(row) = rows.next().await {
324    let _ = match row {
325      Ok(row) => tx.send(Packet::encode("output", row)),
326      Err(e) => tx.send(Packet::err("output", e.to_string())),
327    };
328  }
329
330  let duration = SystemTime::now().duration_since(start).unwrap();
331
332  Ok(duration)
333}
334
335async fn exec<'a, 'b, 'c>(
336  connection: &'a mut Connection<'c>,
337  tx: PacketSender,
338  def: SqlOperationDefinition,
339  args: Vec<(Type, Packet)>,
340  stmt: &'b str,
341  _span: Span,
342) -> Result<Duration, Error>
343where
344  'b: 'a,
345{
346  let start = SystemTime::now();
347
348  let bound_args = common::bind_args(def.arguments(), &args)?;
349  let packet = match connection.exec(stmt.to_owned(), bound_args).await {
350    Ok(num) => Packet::encode("output", num),
351    Err(err) => Packet::err("output", err.to_string()),
352  };
353
354  let _ = tx.send(packet);
355
356  let duration = SystemTime::now().duration_since(start).unwrap();
357
358  Ok(duration)
359}
360
361static POSITIONAL_ARGS: Lazy<Regex> = Lazy::new(|| Regex::new(r"\$(?<id>\d+)\b").unwrap());
362static WICK_ID_ARGS: Lazy<Regex> = Lazy::new(|| Regex::new(r"\$\{(?<id>\w+)\}").unwrap());
363
364fn normalize_operations(ops: &mut Vec<SqlOperationDefinition>, db: DbKind) {
365  for operations in ops {
366    match operations {
367      wick_config::config::components::SqlOperationDefinition::Query(ref mut op) => {
368        let (mut query, args) = normalize_inline_ids(op.query(), op.arguments().to_vec());
369        if db == DbKind::Mssql {
370          query = normalize_mssql_query(query);
371        }
372        let query = query.to_string();
373        op.set_query(query);
374        op.set_arguments(args);
375      }
376      wick_config::config::components::SqlOperationDefinition::Exec(ref mut op) => {
377        let (mut query, args) = normalize_inline_ids(op.exec(), op.arguments().to_vec());
378        if db == DbKind::Mssql {
379          query = normalize_mssql_query(query);
380        }
381        let query = query.to_string();
382        op.set_exec(query);
383        op.set_arguments(args);
384      }
385    };
386  }
387}
388
389// This translates `${id}` to positional `$1` arguments.
390fn normalize_inline_ids(orig_query: &str, mut orig_args: Vec<String>) -> (Cow<str>, Vec<String>) {
391  if orig_query.contains('$') {
392    // replace all instances of ${id} with @p1, @p2, etc and append the id to the args
393
394    let mut id_map: HashMap<String, usize> = orig_args
395      .iter()
396      .enumerate()
397      .map(|(i, id)| (id.clone(), i + 1))
398      .collect();
399
400    let captures = WICK_ID_ARGS.captures_iter(orig_query);
401    for id in captures {
402      let id = id.name("id").unwrap().as_str().to_owned();
403      if !id_map.contains_key(&id) {
404        id_map.insert(id.clone(), id_map.len() + 1);
405        orig_args.push(id.clone());
406      }
407    }
408
409    let normalized = WICK_ID_ARGS.replace_all(orig_query, |cap: &Captures| {
410      let id = cap.name("id").unwrap().as_str();
411      let id = id_map[id];
412      format!("${}", id)
413    });
414    debug!(%orig_query,%normalized, "sql:inline-replacement");
415    (normalized, orig_args)
416  } else {
417    (Cow::Borrowed(orig_query), orig_args)
418  }
419}
420
421// This translates `$1..$n` to `@p1..@pn` to be compatible with Tiberius.
422fn normalize_mssql_query(original: Cow<str>) -> Cow<str> {
423  if original.contains('$') {
424    let normalized = POSITIONAL_ARGS.replace_all(&original, "@p${id}");
425    debug!(%original,%normalized, "sql:mssql:normalized query");
426    Cow::Owned(normalized.to_string())
427  } else {
428    original
429  }
430}
431
#[cfg(test)]
mod test {
  use anyhow::Result;

  use super::*;

  /// A positional `$1` placeholder is rewritten to `@p1`.
  #[test]
  fn test_mssql_query_normalization() -> Result<()> {
    let normalized = normalize_mssql_query(Cow::Borrowed("select id,name from users where id=$1;"));
    assert_eq!(normalized, "select id,name from users where id=@p1;");

    Ok(())
  }

  /// Inline `${id}` placeholders become positional ones and unknown names
  /// are appended to the argument list.
  #[rstest::rstest]
  #[case("select id,name from users where id=${id};",[],"select id,name from users where id=$1;",["id"])]
  #[case("select id,name from users where email=$1, id=${id};",["email"],"select id,name from users where email=$1, id=$2;",["email","id"])]
  #[case("select id,name from users where email=$1, id=${id}, email=${email};",["email"],"select id,name from users where email=$1, id=$2, email=$1;",["email","id"])]
  #[case("select id,name from users where id=${id}, id2=${id}, id3=${id};",[],"select id,name from users where id=$1, id2=$1, id3=$1;",["id"])]
  fn test_inline_id_normalization<const K: usize, const U: usize>(
    #[case] orig_query: &str,
    #[case] orig_args: [&str; K],
    #[case] expected_query: &str,
    #[case] expected_args: [&str; U],
  ) -> Result<()> {
    let initial_args: Vec<String> = orig_args.iter().map(|arg| String::from(*arg)).collect();
    let (actual, actual_args) = normalize_inline_ids(orig_query, initial_args);
    let expected_args: Vec<&str> = expected_args.iter().copied().collect();

    assert_eq!(actual, expected_query);
    assert_eq!(actual_args, expected_args);

    Ok(())
  }
}
468
#[cfg(test)]
mod integration_test {
  use anyhow::Result;
  use flow_component::Component;
  use futures::StreamExt;
  use serde_json::json;
  use wick_config::config::components::{
    ComponentConfig,
    SqlComponentConfigBuilder,
    SqlOperationDefinition,
    SqlQueryOperationDefinitionBuilder,
  };
  use wick_config::config::ResourceDefinition;
  use wick_interface_types::{Field, Type};
  use wick_packet::{packet_stream, Invocation, Packet};

  use super::SqlComponent;

  /// Builds a component with a single `test` query operation whose `db`
  /// resource points at `url`.
  async fn init_component(url: String) -> Result<SqlComponent> {
    let mut config = SqlComponentConfigBuilder::default()
      .resource("db")
      .tls(false)
      .build()
      .unwrap();
    let op = SqlQueryOperationDefinitionBuilder::default()
      .name("test")
      .query("select id,name from users where id=$1;")
      .inputs([Field::new("input", Type::I32)])
      .outputs([Field::new("output", Type::Object)])
      .arguments(["input".to_owned()])
      .build()
      .unwrap();

    config.operations_mut().push(SqlOperationDefinition::Query(op));
    let mut app_config = wick_config::config::AppConfiguration::default();
    app_config.add_resource("db", ResourceDefinition::Url(url.try_into().unwrap()));

    let component = SqlComponent::new(config, None, None, &app_config.resolver()).await?;

    Ok(component)
  }

  /// Component for the SQL Server instance described by `TEST_HOST`,
  /// `TEST_PASSWORD` and `MSSQL_PORT`.
  async fn init_mssql_component() -> Result<SqlComponent> {
    let docker_host = std::env::var("TEST_HOST").unwrap();
    let password = std::env::var("TEST_PASSWORD").unwrap();
    let db_host = docker_host.split(':').next().unwrap();
    let port = std::env::var("MSSQL_PORT").unwrap();
    let user = "SA";
    let db_name = "wick_test";

    init_component(format!("mssql://{}:{}@{}:{}/{}", user, password, db_host, port, db_name)).await
  }

  #[test_logger::test(tokio::test)]
  async fn test_mssql_basic() -> Result<()> {
    let component = init_mssql_component().await?;
    let input = packet_stream!(("input", 1_i32));
    let inv = Invocation::test("mssql", "wick://__local__/test", input, None)?;
    let response = component
      .handle(inv, Default::default(), Default::default())
      .await
      .unwrap();
    let packets: Vec<_> = response.collect().await;

    assert_eq!(
      packets,
      vec![
        Ok(Packet::encode("output", json!({"id":1_i32, "name":"Test User"}))),
        Ok(Packet::done("output"))
      ]
    );
    Ok(())
  }

  /// Component for the PostgreSQL instance described by `TEST_HOST`,
  /// `TEST_PASSWORD` and `POSTGRES_PORT`.
  async fn init_pg_component() -> Result<SqlComponent> {
    let docker_host = std::env::var("TEST_HOST").unwrap();
    let db_host = docker_host.split(':').next().unwrap();
    let password = std::env::var("TEST_PASSWORD").unwrap();
    let port = std::env::var("POSTGRES_PORT").unwrap();
    let user = "postgres";
    let db_name = "wick_test";

    init_component(format!("postgres://{}:{}@{}:{}/{}", user, password, db_host, port, db_name)).await
  }

  #[test_logger::test(tokio::test)]
  async fn test_pg_basic() -> Result<()> {
    let component = init_pg_component().await?;
    let input = packet_stream!(("input", 1_u32));
    let inv = Invocation::test("postgres", "wick://__local__/test", input, None)?;
    let response = component
      .handle(inv, Default::default(), Default::default())
      .await
      .unwrap();
    let packets: Vec<_> = response.collect().await;

    assert_eq!(
      packets,
      vec![
        Ok(Packet::encode("output", json!({"id":1_i32, "name":"Test User"}))),
        Ok(Packet::done("output"))
      ]
    );
    Ok(())
  }

  /// Component for the SQLite database file named by `SQLITE_DB`.
  async fn init_sqlite_component() -> Result<SqlComponent> {
    let db = std::env::var("SQLITE_DB").unwrap();

    init_component(format!("file://{}", db)).await
  }

  #[test_logger::test(tokio::test)]
  async fn test_sqlite_basic() -> Result<()> {
    let component = init_sqlite_component().await?;
    let input = packet_stream!(("input", 1_i32));
    let inv = Invocation::test("sqlite", "wick://__local__/test", input, None)?;
    let response = component
      .handle(inv, Default::default(), Default::default())
      .await
      .unwrap();
    let packets: Vec<_> = response.collect().await;

    assert_eq!(
      packets,
      vec![
        Ok(Packet::encode("output", json!({"id":1_i32, "name":"Test User"}))),
        Ok(Packet::done("output"))
      ]
    );
    Ok(())
  }
}