1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime};
5
6use flow_component::{BoxFuture, Component, ComponentError, LocalScope};
7use futures::StreamExt;
8use once_cell::sync::Lazy;
9use regex::{Captures, Regex};
10use tracing::Span;
11use url::Url;
12use wick_config::config::components::{ComponentConfig, OperationConfig, SqlComponentConfig, SqlOperationDefinition};
13use wick_config::config::{ErrorBehavior, Metadata};
14use wick_config::Resolver;
15use wick_interface_types::{ComponentSignature, Field, OperationSignatures, Type};
16use wick_packet::{Invocation, Observer, Packet, PacketExt, PacketSender, PacketStream, RuntimeConfig};
17
18use crate::common::{Connection, DatabaseProvider};
19use crate::{common, Error};
20
/// The database dialect a component is configured for.
///
/// `Client::new` picks the variant from the resource URL's scheme and
/// `normalize_operations` uses it to decide whether positional placeholders
/// must be rewritten for the target dialect.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DbKind {
  /// Selected for `mssql` URLs.
  Mssql,
  /// Selected for `postgres` URLs.
  Postgres,
  /// Selected for `file` and `sqlite` URLs.
  Sqlite,
}
/// Handle around the backend-specific [`DatabaseProvider`].
///
/// Cloning only clones the inner `Arc`, so every clone shares the same provider.
#[derive(Clone)]
struct Client {
  /// The concrete provider chosen from the resource URL's scheme in `Client::new`.
  inner: Arc<dyn DatabaseProvider + Send + Sync>,
}
31
32impl Client {
33 async fn new(
34 url: &Url,
35 config: &mut SqlComponentConfig,
36 _metadata: Option<Metadata>,
37 _root_config: Option<RuntimeConfig>, resolver: &Resolver,
39 ) -> Result<Self, Error> {
40 let client: Arc<dyn DatabaseProvider + Send + Sync> = match url.scheme() {
41 "mssql" => {
42 normalize_operations(config.operations_mut(), DbKind::Mssql);
43 Arc::new(crate::mssql_tiberius::AzureSqlComponent::new(config.clone(), resolver).await?)
44 }
45 "postgres" => {
46 normalize_operations(config.operations_mut(), DbKind::Postgres);
47 Arc::new(crate::sqlx::SqlXComponent::new(config.clone(), resolver).await?)
48 }
49 "file" | "sqlite" => {
50 normalize_operations(config.operations_mut(), DbKind::Sqlite);
51 Arc::new(crate::sqlx::SqlXComponent::new(config.clone(), resolver).await?)
52 }
53 _ => return Err(Error::InvalidScheme(url.scheme().to_owned())),
54 };
55
56 Ok(Self { inner: client })
57 }
58
59 fn inner(&self) -> &Arc<dyn DatabaseProvider + Sync + Send> {
60 &self.inner
61 }
62}
63
64#[async_trait::async_trait]
65impl DatabaseProvider for Client {
66 fn get_statement<'a>(&'a self, id: &'a str) -> Option<&'a str> {
67 self.inner().get_statement(id)
68 }
69
70 async fn get_connection<'a, 'b>(&'a self) -> Result<Connection<'b>, Error>
71 where
72 'a: 'b,
73 {
74 self.inner().get_connection().await
75 }
76}
77
/// A component that exposes the SQL operations of a [`SqlComponentConfig`].
#[derive(Clone)]
#[must_use]
pub struct SqlComponent {
  /// Backend client selected from the resource URL's scheme.
  provider: Client,
  /// Signature generated from the configured operations at construction time.
  signature: Arc<ComponentSignature>,
  /// URL the configured resource resolved to.
  url: Url,
  /// Component configuration; its operations are normalized while the client is built.
  config: SqlComponentConfig,
  /// Runtime configuration passed to `SqlComponent::new`, if any.
  root_config: Option<RuntimeConfig>,
}
88
89impl std::fmt::Debug for SqlComponent {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 f.debug_struct("SqlComponent")
92 .field("signature", &self.signature)
93 .field("url", &self.url)
94 .field("config", &self.config)
95 .field("root_config", &self.root_config)
96 .finish()
97 }
98}
99
100impl SqlComponent {
101 pub async fn new(
103 mut config: SqlComponentConfig,
104 root_config: Option<RuntimeConfig>,
105 metadata: Option<Metadata>,
106 resolver: &Resolver,
107 ) -> Result<Self, ComponentError> {
108 validate(&config, resolver)?;
109 let sig = common::gen_signature(
110 "wick/component/sql",
111 config.operation_signatures(),
112 config.config(),
113 &metadata,
114 )?;
115
116 let url = common::convert_url_resource(resolver, config.resource())?;
117
118 validate(&config, resolver)?;
119 let provider = Client::new(&url, &mut config, metadata, root_config.clone(), resolver).await?;
120
121 Ok(Self {
122 provider,
123 signature: Arc::new(sig),
124 url,
125 root_config,
126 config,
127 })
128 }
129}
130
131impl Component for SqlComponent {
132 fn handle(
133 &self,
134 invocation: Invocation,
135 _data: Option<RuntimeConfig>, _callback: LocalScope,
137 ) -> BoxFuture<Result<PacketStream, ComponentError>> {
138 let client = self.provider.clone();
139 let opdef = self
140 .config
141 .get_operation(invocation.target().operation_id())
142 .ok_or_else(|| Error::MissingOperation(invocation.target().operation_id().to_owned()))
143 .cloned();
144
145 Box::pin(async move {
146 let opdef = opdef?;
147 let stmt = client.get_statement(opdef.name()).unwrap().to_owned();
148
149 let input_names: Vec<_> = opdef.inputs().iter().map(|i| i.name.clone()).collect();
150 let (invocation, stream) = invocation.split();
151 let input_streams = wick_packet::split_stream(stream, input_names);
152 let (tx, rx) = invocation.make_response();
153 tokio::spawn(async move {
154 let start = SystemTime::now();
155 let span = invocation.span.clone();
156 let Ok(mut connection) = client.get_connection().await else {
157 invocation.trace(|| {
158 error!("could not get connection to database");
159 });
160 return;
161 };
162 if let Err(e) = handle_call(&mut connection, opdef, input_streams, tx.clone(), &stmt, span).await {
163 invocation.trace(|| {
164 error!(error = %e, "error handling sql operation");
165 });
166 let _ = tx.error(wick_packet::Error::component_error(e.to_string()));
167 }
168 let _ = tx.send(Packet::done("output"));
169 let duration = SystemTime::now().duration_since(start).unwrap();
170 invocation.trace(|| {
171 debug!(?duration, target=%invocation.target,"mssql operation complete");
172 });
173 });
174
175 Ok(rx)
176 })
177 }
178
179 fn signature(&self) -> &ComponentSignature {
180 &self.signature
181 }
182}
183
184fn validate(config: &SqlComponentConfig, _resolver: &Resolver) -> Result<(), Error> {
185 let bad_ops: Vec<_> = config
186 .operations()
187 .iter()
188 .filter(|op| {
189 let outputs = op.outputs();
190 outputs.len() > 1 || outputs.len() == 1 && outputs[0] != Field::new("output", Type::Object)
191 })
192 .map(|op| op.name().to_owned())
193 .collect();
194
195 if !bad_ops.is_empty() {
196 return Err(Error::InvalidOutput(bad_ops));
197 }
198
199 Ok(())
200}
201
202async fn handle_call<'a, 'b, 'c>(
203 connection: &'a mut Connection<'c>,
204 opdef: SqlOperationDefinition,
205 input_streams: Vec<PacketStream>,
206 tx: PacketSender,
207 stmt: &'b str,
208 span: Span,
209) -> Result<(), Error>
210where
211 'b: 'a,
212{
213 let error_behavior = opdef.on_error();
214
215 connection.start(error_behavior).await?;
216
217 let result = handle_stream(connection, opdef, input_streams, tx, stmt, span.clone()).await;
218 if let Err(e) = result {
219 span.in_scope(|| error!(error = %e, "error in sql operation"));
220 let err = Error::OperationFailed(e.to_string());
221 connection.handle_error(e, error_behavior).await?;
222 return Err(err);
223 }
224 connection.finish().await?;
225
226 Ok(())
227}
228
229async fn handle_stream<'a, 'b, 'c>(
230 connection: &'a mut Connection<'c>,
231 opdef: SqlOperationDefinition,
232 mut input_streams: Vec<PacketStream>,
233 tx: PacketSender,
234 stmt: &'b str,
235 span: Span,
236) -> Result<(), Error>
237where
238 'b: 'a,
239{
240 span.in_scope(|| debug!(stmt = %stmt, "preparing query for stream"));
241 'outer: loop {
242 let mut incoming_packets = Vec::new();
243
244 for input in &mut input_streams {
245 let packet = input.next().await;
246
247 incoming_packets.push(packet);
248 }
249
250 let num_done = incoming_packets.iter().filter(|r| r.is_none()).count();
251 if num_done > 0 {
252 if num_done != opdef.inputs().len() {
253 return Err(Error::MissingInput);
254 }
255 break 'outer;
256 }
257 let incoming_packets = incoming_packets.into_iter().map(|r| r.unwrap()).collect::<Vec<_>>();
258
259 if let Some(Err(e)) = incoming_packets.iter().find(|r| r.is_err()) {
260 return Err(Error::ComponentError(e.clone()));
261 }
262 let fields = opdef.inputs();
263 let mut type_wrappers = Vec::new();
264
265 for packet in incoming_packets {
266 let packet = packet.unwrap();
267 if packet.is_done() {
268 break 'outer;
269 }
270 if packet.is_open_bracket() || packet.is_close_bracket() {
271 let _ = tx.send(packet.to_port("output"));
272 continue 'outer;
273 }
274 let ty = fields.iter().find(|f| f.name() == packet.port()).unwrap().ty().clone();
275 type_wrappers.push((ty, packet));
276 }
277
278 let start = SystemTime::now();
279 let result = match &opdef {
280 SqlOperationDefinition::Query(_) => {
281 query(connection, tx.clone(), opdef.clone(), type_wrappers, stmt, span.clone()).await
282 }
283 SqlOperationDefinition::Exec(_) => {
284 exec(connection, tx.clone(), opdef.clone(), type_wrappers, stmt, span.clone()).await
285 }
286 };
287 let duration = SystemTime::now().duration_since(start).unwrap();
288
289 span.in_scope(|| debug!(μs = duration.as_micros(), "executed query"));
290
291 if let Err(e) = result {
292 if opdef.on_error() == ErrorBehavior::Ignore {
293 let _ = tx.send(Packet::err("output", e.to_string()));
294 } else {
295 return Err(Error::ErrorInStream(e.to_string()));
296 }
297 };
298
299 if opdef.inputs().len() == 0 {
300 break 'outer;
301 }
302 }
303 Ok(())
304}
305
306async fn query<'a, 'b, 'c>(
307 client: &'a mut Connection<'c>,
308 tx: PacketSender,
309 def: SqlOperationDefinition,
310 args: Vec<(Type, Packet)>,
311 stmt: &'b str,
312 _span: Span,
313) -> Result<Duration, Error>
314where
315 'b: 'a,
316{
317 let start = SystemTime::now();
318
319 let bound_args = common::bind_args(def.arguments(), &args)?;
320
321 let mut rows = client.query(stmt, bound_args).await?;
322
323 while let Some(row) = rows.next().await {
324 let _ = match row {
325 Ok(row) => tx.send(Packet::encode("output", row)),
326 Err(e) => tx.send(Packet::err("output", e.to_string())),
327 };
328 }
329
330 let duration = SystemTime::now().duration_since(start).unwrap();
331
332 Ok(duration)
333}
334
335async fn exec<'a, 'b, 'c>(
336 connection: &'a mut Connection<'c>,
337 tx: PacketSender,
338 def: SqlOperationDefinition,
339 args: Vec<(Type, Packet)>,
340 stmt: &'b str,
341 _span: Span,
342) -> Result<Duration, Error>
343where
344 'b: 'a,
345{
346 let start = SystemTime::now();
347
348 let bound_args = common::bind_args(def.arguments(), &args)?;
349 let packet = match connection.exec(stmt.to_owned(), bound_args).await {
350 Ok(num) => Packet::encode("output", num),
351 Err(err) => Packet::err("output", err.to_string()),
352 };
353
354 let _ = tx.send(packet);
355
356 let duration = SystemTime::now().duration_since(start).unwrap();
357
358 Ok(duration)
359}
360
/// Matches positional placeholders such as `$1`, capturing the digits as `id`.
static POSITIONAL_ARGS: Lazy<Regex> = Lazy::new(|| Regex::new(r"\$(?<id>\d+)\b").unwrap());
/// Matches inline named placeholders such as `${name}`, capturing the name as `id`.
static WICK_ID_ARGS: Lazy<Regex> = Lazy::new(|| Regex::new(r"\$\{(?<id>\w+)\}").unwrap());
363
364fn normalize_operations(ops: &mut Vec<SqlOperationDefinition>, db: DbKind) {
365 for operations in ops {
366 match operations {
367 wick_config::config::components::SqlOperationDefinition::Query(ref mut op) => {
368 let (mut query, args) = normalize_inline_ids(op.query(), op.arguments().to_vec());
369 if db == DbKind::Mssql {
370 query = normalize_mssql_query(query);
371 }
372 let query = query.to_string();
373 op.set_query(query);
374 op.set_arguments(args);
375 }
376 wick_config::config::components::SqlOperationDefinition::Exec(ref mut op) => {
377 let (mut query, args) = normalize_inline_ids(op.exec(), op.arguments().to_vec());
378 if db == DbKind::Mssql {
379 query = normalize_mssql_query(query);
380 }
381 let query = query.to_string();
382 op.set_exec(query);
383 op.set_arguments(args);
384 }
385 };
386 }
387}
388
389fn normalize_inline_ids(orig_query: &str, mut orig_args: Vec<String>) -> (Cow<str>, Vec<String>) {
391 if orig_query.contains('$') {
392 let mut id_map: HashMap<String, usize> = orig_args
395 .iter()
396 .enumerate()
397 .map(|(i, id)| (id.clone(), i + 1))
398 .collect();
399
400 let captures = WICK_ID_ARGS.captures_iter(orig_query);
401 for id in captures {
402 let id = id.name("id").unwrap().as_str().to_owned();
403 if !id_map.contains_key(&id) {
404 id_map.insert(id.clone(), id_map.len() + 1);
405 orig_args.push(id.clone());
406 }
407 }
408
409 let normalized = WICK_ID_ARGS.replace_all(orig_query, |cap: &Captures| {
410 let id = cap.name("id").unwrap().as_str();
411 let id = id_map[id];
412 format!("${}", id)
413 });
414 debug!(%orig_query,%normalized, "sql:inline-replacement");
415 (normalized, orig_args)
416 } else {
417 (Cow::Borrowed(orig_query), orig_args)
418 }
419}
420
421fn normalize_mssql_query(original: Cow<str>) -> Cow<str> {
423 if original.contains('$') {
424 let normalized = POSITIONAL_ARGS.replace_all(&original, "@p${id}");
425 debug!(%original,%normalized, "sql:mssql:normalized query");
426 Cow::Owned(normalized.to_string())
427 } else {
428 original
429 }
430}
431
// Unit tests for the statement normalization helpers; these need no database.
#[cfg(test)]
mod test {
  use anyhow::Result;

  use super::*;

  // A positional `$1` placeholder must be rewritten to `@p1`.
  #[test]
  fn test_mssql_query_normalization() -> Result<()> {
    let query = "select id,name from users where id=$1;";
    let expected = "select id,name from users where id=@p1;";
    let actual = normalize_mssql_query(Cow::Borrowed(query));
    assert_eq!(actual, expected);

    Ok(())
  }

  // Each case is (query, initial arguments, expected query, expected arguments):
  // 1. an inline id with no prior arguments becomes `$1` and is appended;
  // 2. an inline id is numbered after an existing positional argument;
  // 3. an inline id naming an existing argument reuses that argument's position;
  // 4. a repeated inline id maps to a single position and is appended once.
  #[rstest::rstest]
  #[case("select id,name from users where id=${id};",[],"select id,name from users where id=$1;",["id"])]
  #[case("select id,name from users where email=$1, id=${id};",["email"],"select id,name from users where email=$1, id=$2;",["email","id"])]
  #[case("select id,name from users where email=$1, id=${id}, email=${email};",["email"],"select id,name from users where email=$1, id=$2, email=$1;",["email","id"])]
  #[case("select id,name from users where id=${id}, id2=${id}, id3=${id};",[],"select id,name from users where id=$1, id2=$1, id3=$1;",["id"])]
  fn test_inline_id_normalization<const K: usize, const U: usize>(
    #[case] orig_query: &str,
    #[case] orig_args: [&str; K],
    #[case] expected_query: &str,
    #[case] expected_args: [&str; U],
  ) -> Result<()> {
    let (actual, actual_args) =
      normalize_inline_ids(orig_query, orig_args.iter().copied().map(|s| s.to_owned()).collect());
    let expected_args = expected_args.iter().map(|s| s.to_owned()).collect::<Vec<_>>();
    assert_eq!(actual, expected_query);
    assert_eq!(actual_args, expected_args);

    Ok(())
  }
}
468
// Integration tests that run the same single-row query against each supported backend.
//
// They read their connection details from environment variables (`TEST_HOST`,
// `TEST_PASSWORD`, `MSSQL_PORT`, `POSTGRES_PORT`, `SQLITE_DB`) and panic when one is missing.
#[cfg(test)]
mod integration_test {
  use anyhow::Result;
  use flow_component::Component;
  use futures::StreamExt;
  use serde_json::json;
  use wick_config::config::components::{
    ComponentConfig,
    SqlComponentConfigBuilder,
    SqlOperationDefinition,
    SqlQueryOperationDefinitionBuilder,
  };
  use wick_config::config::ResourceDefinition;
  use wick_interface_types::{Field, Type};
  use wick_packet::{packet_stream, Invocation, Packet};

  use super::SqlComponent;

  /// Builds a component with a single `test` query operation whose `db` resource points at `url`.
  ///
  /// Shared by all backends so the operation under test is guaranteed to be identical.
  async fn init_component(url: String) -> Result<SqlComponent> {
    let mut config = SqlComponentConfigBuilder::default()
      .resource("db")
      .tls(false)
      .build()
      .unwrap();
    let op = SqlQueryOperationDefinitionBuilder::default()
      .name("test")
      .query("select id,name from users where id=$1;")
      .inputs([Field::new("input", Type::I32)])
      .outputs([Field::new("output", Type::Object)])
      .arguments(["input".to_owned()])
      .build()
      .unwrap();

    config.operations_mut().push(SqlOperationDefinition::Query(op));
    let mut app_config = wick_config::config::AppConfiguration::default();
    app_config.add_resource("db", ResourceDefinition::Url(url.try_into().unwrap()));

    let component = SqlComponent::new(config, None, None, &app_config.resolver()).await?;

    Ok(component)
  }

  /// Builds the test component against the MSSQL instance described by the environment.
  async fn init_mssql_component() -> Result<SqlComponent> {
    let docker_host = std::env::var("TEST_HOST").unwrap();
    let password = std::env::var("TEST_PASSWORD").unwrap();
    // `TEST_HOST` may carry a port; only the host part is used.
    let db_host = docker_host.split(':').next().unwrap();
    let port = std::env::var("MSSQL_PORT").unwrap();
    let user = "SA";
    let db_name = "wick_test";

    init_component(format!(
      "mssql://{}:{}@{}:{}/{}",
      user, password, db_host, port, db_name
    ))
    .await
  }

  #[test_logger::test(tokio::test)]
  async fn test_mssql_basic() -> Result<()> {
    let db = init_mssql_component().await?;
    let input = packet_stream!(("input", 1_i32));
    let inv = Invocation::test("mssql", "wick://__local__/test", input, None)?;
    let response = db.handle(inv, Default::default(), Default::default()).await.unwrap();
    let packets: Vec<_> = response.collect().await;

    assert_eq!(
      packets,
      vec![
        Ok(Packet::encode("output", json!({"id":1_i32, "name":"Test User"}))),
        Ok(Packet::done("output"))
      ]
    );
    Ok(())
  }

  /// Builds the test component against the Postgres instance described by the environment.
  async fn init_pg_component() -> Result<SqlComponent> {
    let docker_host = std::env::var("TEST_HOST").unwrap();
    // `TEST_HOST` may carry a port; only the host part is used.
    let db_host = docker_host.split(':').next().unwrap();
    let password = std::env::var("TEST_PASSWORD").unwrap();
    let port = std::env::var("POSTGRES_PORT").unwrap();
    let user = "postgres";
    let db_name = "wick_test";

    init_component(format!(
      "postgres://{}:{}@{}:{}/{}",
      user, password, db_host, port, db_name
    ))
    .await
  }

  #[test_logger::test(tokio::test)]
  async fn test_pg_basic() -> Result<()> {
    let pg = init_pg_component().await?;
    let input = packet_stream!(("input", 1_u32));
    let inv = Invocation::test("postgres", "wick://__local__/test", input, None)?;
    let response = pg.handle(inv, Default::default(), Default::default()).await.unwrap();
    let packets: Vec<_> = response.collect().await;

    assert_eq!(
      packets,
      vec![
        Ok(Packet::encode("output", json!({"id":1_i32, "name":"Test User"}))),
        Ok(Packet::done("output"))
      ]
    );
    Ok(())
  }

  /// Builds the test component against the SQLite file named by `SQLITE_DB`.
  async fn init_sqlite_component() -> Result<SqlComponent> {
    let db = std::env::var("SQLITE_DB").unwrap();

    init_component(format!("file://{}", db)).await
  }

  #[test_logger::test(tokio::test)]
  async fn test_sqlite_basic() -> Result<()> {
    let db = init_sqlite_component().await?;
    let input = packet_stream!(("input", 1_i32));
    let inv = Invocation::test("sqlite", "wick://__local__/test", input, None)?;
    let response = db.handle(inv, Default::default(), Default::default()).await.unwrap();
    let packets: Vec<_> = response.collect().await;

    assert_eq!(
      packets,
      vec![
        Ok(Packet::encode("output", json!({"id":1_i32, "name":"Test User"}))),
        Ok(Packet::done("output"))
      ]
    );
    Ok(())
  }
}