1use crate::output_stream::Next;
2use crate::output_stream::consumer::Sink;
3use crate::output_stream::event::Chunk;
4use crate::output_stream::line::adapter::AsyncLineSink;
5use crate::output_stream::visitor::AsyncStreamVisitor;
6use std::borrow::Cow;
7use std::io;
8use tokio::io::AsyncWriteExt;
9use typed_builder::TypedBuilder;
10
/// Controls how a line sink emits each line it writes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LineWriteMode {
    /// Write the mapped line bytes exactly as produced, without adding any delimiter.
    AsIs,

    /// Write the mapped line bytes followed by a single `\n`.
    ///
    /// The `\n` is only written when writing the line bytes themselves succeeded.
    AppendLf,
}
26
/// Decision returned by a sink write error handler after a write into the sink failed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SinkWriteErrorAction {
    /// Abort consumption; the failure is surfaced as the consumer's error result.
    Stop,

    /// Skip the failed write and keep consuming subsequent output.
    Continue,
}
37
/// Identifies which kind of write was in flight when a sink write failed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SinkWriteOperation {
    /// Writing the (mapped) bytes of a raw output chunk.
    Chunk,

    /// Writing the (mapped) bytes of a single parsed line.
    Line,

    /// Writing the `\n` that is appended after a line in `LineWriteMode::AppendLf` mode.
    LineDelimiter,
}
50
51#[derive(Debug)]
53pub struct SinkWriteError {
54 stream_name: &'static str,
55 operation: SinkWriteOperation,
56 attempted_len: usize,
57 source: io::Error,
58}
59
60impl SinkWriteError {
61 pub(crate) fn new(
62 stream_name: &'static str,
63 operation: SinkWriteOperation,
64 attempted_len: usize,
65 source: io::Error,
66 ) -> Self {
67 Self {
68 stream_name,
69 operation,
70 attempted_len,
71 source,
72 }
73 }
74
75 #[must_use]
77 pub fn stream_name(&self) -> &'static str {
78 self.stream_name
79 }
80
81 #[must_use]
83 pub fn operation(&self) -> SinkWriteOperation {
84 self.operation
85 }
86
87 #[must_use]
89 pub fn attempted_len(&self) -> usize {
90 self.attempted_len
91 }
92
93 #[must_use]
95 pub fn source(&self) -> &io::Error {
96 &self.source
97 }
98}
99
100impl std::fmt::Display for SinkWriteError {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 write!(
103 f,
104 "Failed to write consumed output from stream '{}' to sink: {}",
105 self.stream_name, self.source
106 )
107 }
108}
109
110impl std::error::Error for SinkWriteError {
111 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
112 Some(&self.source)
113 }
114}
115
116pub trait SinkWriteErrorHandler: Send + 'static {
118 fn handle(&mut self, error: &SinkWriteError) -> SinkWriteErrorAction;
120}
121
122impl<F> SinkWriteErrorHandler for F
123where
124 F: FnMut(&SinkWriteError) -> SinkWriteErrorAction + Send + 'static,
125{
126 fn handle(&mut self, error: &SinkWriteError) -> SinkWriteErrorAction {
127 self(error)
128 }
129}
130
131#[derive(Debug, Clone, Copy)]
137pub struct WriteCollectionOptions<H = fn(&SinkWriteError) -> SinkWriteErrorAction> {
138 error_handler: H,
139}
140
141impl WriteCollectionOptions {
142 #[must_use]
144 pub fn fail_fast() -> Self {
145 Self {
146 error_handler: |_| SinkWriteErrorAction::Stop,
147 }
148 }
149
150 #[must_use]
152 pub fn log_and_continue() -> Self {
153 Self {
154 error_handler: |error| {
155 tracing::warn!(
156 stream = error.stream_name(),
157 operation = ?error.operation(),
158 attempted_len = error.attempted_len(),
159 source = %error.source(),
160 "Could not write collected output to write sink; continuing"
161 );
162 SinkWriteErrorAction::Continue
163 },
164 }
165 }
166
167 #[must_use]
169 pub fn with_error_handler<H>(handler: H) -> WriteCollectionOptions<H>
170 where
171 H: FnMut(&SinkWriteError) -> SinkWriteErrorAction + Send + 'static,
172 {
173 WriteCollectionOptions {
174 error_handler: handler,
175 }
176 }
177}
178
179impl<H> WriteCollectionOptions<H> {
180 pub(crate) fn into_error_handler(self) -> H {
181 self.error_handler
182 }
183}
184
185#[derive(TypedBuilder)]
186pub(crate) struct WriteChunks<W, H, F, B>
187where
188 W: Sink + AsyncWriteExt + Unpin,
189 H: SinkWriteErrorHandler,
190 B: AsRef<[u8]> + Send + 'static,
191 F: Fn(Chunk) -> B + Send + Sync + 'static,
192{
193 pub stream_name: &'static str,
194 pub writer: W,
195 pub error_handler: H,
196 pub mapper: F,
197 pub error: Option<SinkWriteError>,
198}
199
200impl<W, H, F, B> AsyncStreamVisitor for WriteChunks<W, H, F, B>
201where
202 W: Sink + AsyncWriteExt + Unpin,
203 H: SinkWriteErrorHandler,
204 B: AsRef<[u8]> + Send + 'static,
205 F: Fn(Chunk) -> B + Send + Sync + 'static,
206{
207 type Output = Result<W, SinkWriteError>;
208
209 async fn on_chunk(&mut self, chunk: Chunk) -> Next {
210 let mapped_output = (self.mapper)(chunk);
211 let bytes = mapped_output.as_ref();
212 let attempted_len = bytes.len();
213 let result = self.writer.write_all(bytes).await;
214 match handle_write_result(
215 self.stream_name,
216 &mut self.error_handler,
217 SinkWriteOperation::Chunk,
218 attempted_len,
219 result,
220 ) {
221 Ok(_) => Next::Continue,
222 Err(err) => {
223 self.error = Some(err);
224 Next::Break
225 }
226 }
227 }
228
229 fn into_output(self) -> Self::Output {
230 match self.error {
231 Some(err) => Err(err),
232 None => Ok(self.writer),
233 }
234 }
235}
236
237pub struct WriteLineSink<W, H, F, B>
244where
245 W: Sink + AsyncWriteExt + Unpin,
246 H: SinkWriteErrorHandler,
247 B: AsRef<[u8]> + Send + 'static,
248 F: Fn(Cow<'_, str>) -> B + Send + Sync + 'static,
249{
250 stream_name: &'static str,
251 writer: W,
252 error_handler: H,
253 mapper: F,
254 mode: LineWriteMode,
255 error: Option<SinkWriteError>,
256}
257
258impl<W, H, F, B> WriteLineSink<W, H, F, B>
259where
260 W: Sink + AsyncWriteExt + Unpin,
261 H: SinkWriteErrorHandler,
262 B: AsRef<[u8]> + Send + 'static,
263 F: Fn(Cow<'_, str>) -> B + Send + Sync + 'static,
264{
265 pub fn new(
269 stream_name: &'static str,
270 writer: W,
271 error_handler: H,
272 mapper: F,
273 mode: LineWriteMode,
274 ) -> Self {
275 Self {
276 stream_name,
277 writer,
278 error_handler,
279 mapper,
280 mode,
281 error: None,
282 }
283 }
284}
285
286impl<W, H, F, B> AsyncLineSink for WriteLineSink<W, H, F, B>
287where
288 W: Sink + AsyncWriteExt + Unpin,
289 H: SinkWriteErrorHandler,
290 B: AsRef<[u8]> + Send + 'static,
291 F: Fn(Cow<'_, str>) -> B + Send + Sync + 'static,
292{
293 type Output = Result<W, SinkWriteError>;
294
295 async fn on_line<'a>(&'a mut self, line: Cow<'a, str>) -> Next {
296 let mapped_output = (self.mapper)(line);
297 let bytes = mapped_output.as_ref();
298 match write_line(
299 self.stream_name,
300 &mut self.writer,
301 &mut self.error_handler,
302 bytes,
303 self.mode,
304 )
305 .await
306 {
307 Ok(()) => Next::Continue,
308 Err(err) => {
309 self.error = Some(err);
310 Next::Break
311 }
312 }
313 }
314
315 fn into_output(self) -> Self::Output {
316 match self.error {
317 Some(err) => Err(err),
318 None => Ok(self.writer),
319 }
320 }
321}
322
323async fn write_line<W, H>(
324 stream_name: &'static str,
325 write: &mut W,
326 error_handler: &mut H,
327 line: &[u8],
328 mode: LineWriteMode,
329) -> Result<(), SinkWriteError>
330where
331 W: AsyncWriteExt + Unpin,
332 H: SinkWriteErrorHandler,
333{
334 let line_write = write.write_all(line).await;
335 let line_written = handle_write_result(
336 stream_name,
337 error_handler,
338 SinkWriteOperation::Line,
339 line.len(),
340 line_write,
341 )?;
342 if !line_written || !matches!(mode, LineWriteMode::AppendLf) {
343 return Ok(());
344 }
345
346 handle_write_result(
347 stream_name,
348 error_handler,
349 SinkWriteOperation::LineDelimiter,
350 1,
351 write.write_all(b"\n").await,
352 )?;
353
354 Ok(())
355}
356
357fn handle_write_result<H>(
358 stream_name: &'static str,
359 error_handler: &mut H,
360 operation: SinkWriteOperation,
361 attempted_len: usize,
362 result: io::Result<()>,
363) -> Result<bool, SinkWriteError>
364where
365 H: SinkWriteErrorHandler,
366{
367 match result {
368 Ok(()) => Ok(true),
369 Err(source) => {
370 let error = SinkWriteError::new(stream_name, operation, attempted_len, source);
371 match error_handler.handle(&error) {
372 SinkWriteErrorAction::Stop => Err(error),
373 SinkWriteErrorAction::Continue => Ok(false),
374 }
375 }
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output_stream::Subscription;
    use crate::output_stream::consumer::Consumer;
    use crate::output_stream::consumer::driver::spawn_consumer_async;
    use crate::output_stream::event::StreamEvent;
    use crate::output_stream::event::tests::event_receiver;
    use crate::output_stream::line::adapter::LineAdapter;
    use crate::output_stream::line::options::LineParsingOptions;
    use assertr::prelude::*;
    use bytes::Bytes;
    use std::cell::Cell;
    use std::io;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};
    use std::task::{Context, Poll};
    use tokio::io::AsyncWrite;

    /// Spawns a consumer that writes every raw chunk of `subscription` into `write` unchanged.
    fn collect_chunks_into_write<S, W, H>(
        stream_name: &'static str,
        subscription: S,
        write: W,
        write_options: WriteCollectionOptions<H>,
    ) -> Consumer<Result<W, SinkWriteError>>
    where
        S: Subscription,
        W: Sink + AsyncWriteExt + Unpin,
        H: SinkWriteErrorHandler,
    {
        spawn_consumer_async(
            stream_name,
            subscription,
            WriteChunks::builder()
                .stream_name(stream_name)
                .writer(write)
                .error_handler(write_options.into_error_handler())
                .mapper((|chunk: Chunk| chunk) as fn(Chunk) -> Chunk)
                .error(None)
                .build(),
        )
    }

    /// Spawns a consumer that writes `mapper(chunk)` for every chunk into `write`.
    fn collect_chunks_into_write_mapped<S, W, B, F, H>(
        stream_name: &'static str,
        subscription: S,
        write: W,
        mapper: F,
        write_options: WriteCollectionOptions<H>,
    ) -> Consumer<Result<W, SinkWriteError>>
    where
        S: Subscription,
        W: Sink + AsyncWriteExt + Unpin,
        B: AsRef<[u8]> + Send + 'static,
        F: Fn(Chunk) -> B + Send + Sync + 'static,
        H: SinkWriteErrorHandler,
    {
        spawn_consumer_async(
            stream_name,
            subscription,
            WriteChunks::builder()
                .stream_name(stream_name)
                .writer(write)
                .error_handler(write_options.into_error_handler())
                .mapper(mapper)
                .error(None)
                .build(),
        )
    }

    /// Spawns a consumer that parses lines and writes each one (as UTF-8) into `write`.
    fn collect_lines_into_write<S, W, H>(
        stream_name: &'static str,
        subscription: S,
        write: W,
        options: LineParsingOptions,
        mode: LineWriteMode,
        write_options: WriteCollectionOptions<H>,
    ) -> Consumer<Result<W, SinkWriteError>>
    where
        S: Subscription,
        W: Sink + AsyncWriteExt + Unpin,
        H: SinkWriteErrorHandler,
    {
        spawn_consumer_async(
            stream_name,
            subscription,
            LineAdapter::new(
                options,
                WriteLineSink::new(
                    stream_name,
                    write,
                    write_options.into_error_handler(),
                    (|line: Cow<'_, str>| line.into_owned()) as fn(Cow<'_, str>) -> String,
                    mode,
                ),
            ),
        )
    }

    /// Spawns a consumer that parses lines and writes `mapper(line)` for each into `write`.
    fn collect_lines_into_write_mapped<S, W, B, F, H>(
        stream_name: &'static str,
        subscription: S,
        write: W,
        mapper: F,
        options: LineParsingOptions,
        mode: LineWriteMode,
        write_options: WriteCollectionOptions<H>,
    ) -> Consumer<Result<W, SinkWriteError>>
    where
        S: Subscription,
        W: Sink + AsyncWriteExt + Unpin,
        B: AsRef<[u8]> + Send + 'static,
        F: Fn(Cow<'_, str>) -> B + Send + Sync + 'static,
        H: SinkWriteErrorHandler,
    {
        spawn_consumer_async(
            stream_name,
            subscription,
            LineAdapter::new(
                options,
                WriteLineSink::new(
                    stream_name,
                    write,
                    write_options.into_error_handler(),
                    mapper,
                    mode,
                ),
            ),
        )
    }

    /// Writer whose `poll_write` calls succeed `fail_after_successful_writes` times and then
    /// always fail with `error_kind`.
    #[derive(Debug)]
    struct FailingWrite {
        fail_after_successful_writes: usize,
        error_kind: io::ErrorKind,
        write_calls: usize,
        bytes_written: usize,
    }

    impl FailingWrite {
        fn new(fail_after_successful_writes: usize, error_kind: io::ErrorKind) -> Self {
            Self {
                fail_after_successful_writes,
                error_kind,
                write_calls: 0,
                bytes_written: 0,
            }
        }
    }

    impl AsyncWrite for FailingWrite {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.write_calls += 1;
            if self.write_calls > self.fail_after_successful_writes {
                return Poll::Ready(Err(io::Error::new(
                    self.error_kind,
                    "injected write failure",
                )));
            }

            self.bytes_written += buf.len();
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// In-memory writer that is `Send` but not `Sync` (because of the `Cell`), used to prove
    /// the writers only require `Send`.
    #[derive(Default)]
    struct SendOnlyWrite {
        bytes: Vec<u8>,
        write_calls: Cell<usize>,
    }

    impl AsyncWrite for SendOnlyWrite {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.write_calls.set(self.write_calls.get() + 1);
            self.bytes.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn chunk_writer_reports_and_can_handle_sink_write_errors() {
        let collector = collect_chunks_into_write(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"abc"))),
                StreamEvent::Eof,
            ])
            .await,
            FailingWrite::new(0, io::ErrorKind::BrokenPipe),
            WriteCollectionOptions::fail_fast(),
        );

        match collector.wait().await {
            Ok(Err(err)) => {
                assert_that!(err.stream_name()).is_equal_to("custom");
                assert_that!(err.operation()).is_equal_to(SinkWriteOperation::Chunk);
                assert_that!(err.attempted_len()).is_equal_to(3);
                assert_that!(err.source().kind()).is_equal_to(io::ErrorKind::BrokenPipe);
            }
            other => {
                assert_that!(&other).fail(format_args!("expected sink write error, got {other:?}"));
            }
        }

        let handled_count = Arc::new(Mutex::new(0_usize));
        let count_for_handler = Arc::clone(&handled_count);
        let collector = collect_chunks_into_write(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"abc"))),
                StreamEvent::Eof,
            ])
            .await,
            FailingWrite::new(0, io::ErrorKind::BrokenPipe),
            WriteCollectionOptions::with_error_handler(move |err| {
                assert_that!(err.stream_name()).is_equal_to("custom");
                assert_that!(err.operation()).is_equal_to(SinkWriteOperation::Chunk);
                assert_that!(err.source().kind()).is_equal_to(io::ErrorKind::BrokenPipe);
                *count_for_handler.lock().unwrap() += 1;
                SinkWriteErrorAction::Continue
            }),
        );

        let write = collector.wait().await.unwrap().unwrap();
        assert_that!(write.bytes_written).is_equal_to(0);
        assert_that!(*handled_count.lock().unwrap()).is_equal_to(1);
    }

    #[tokio::test]
    async fn line_writer_reports_line_and_delimiter_write_errors() {
        // The very first write (the line bytes "line") fails.
        let line_error = collect_lines_into_write(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"line\n"))),
                StreamEvent::Eof,
            ])
            .await,
            FailingWrite::new(0, io::ErrorKind::BrokenPipe),
            LineParsingOptions::default(),
            LineWriteMode::AppendLf,
            WriteCollectionOptions::fail_fast(),
        )
        .wait()
        .await;
        match line_error {
            Ok(Err(err)) => {
                assert_that!(err.operation()).is_equal_to(SinkWriteOperation::Line);
                assert_that!(err.attempted_len()).is_equal_to(4);
                assert_that!(err.source().kind()).is_equal_to(io::ErrorKind::BrokenPipe);
            }
            other => {
                assert_that!(&other).fail(format_args!("expected line write error, got {other:?}"));
            }
        }

        // The line bytes are written, then the appended "\n" delimiter fails.
        let delimiter_error = collect_lines_into_write(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"line\n"))),
                StreamEvent::Eof,
            ])
            .await,
            FailingWrite::new(1, io::ErrorKind::WriteZero),
            LineParsingOptions::default(),
            LineWriteMode::AppendLf,
            WriteCollectionOptions::fail_fast(),
        )
        .wait()
        .await;
        match delimiter_error {
            Ok(Err(err)) => {
                assert_that!(err.operation()).is_equal_to(SinkWriteOperation::LineDelimiter);
                assert_that!(err.attempted_len()).is_equal_to(1);
                assert_that!(err.source().kind()).is_equal_to(io::ErrorKind::WriteZero);
            }
            other => {
                assert_that!(&other).fail(format_args!(
                    "expected delimiter write error, got {other:?}"
                ));
            }
        }
    }

    #[tokio::test]
    async fn line_writer_respects_requested_delimiter_mode() {
        let collector = collect_lines_into_write(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(
                    b"Cargo.lock\nCargo.toml\nREADME.md\nsrc\ntarget\n",
                ))),
                StreamEvent::Eof,
            ])
            .await,
            SendOnlyWrite::default(),
            LineParsingOptions::default(),
            LineWriteMode::AsIs,
            WriteCollectionOptions::fail_fast(),
        );

        let writer = collector.wait().await.unwrap().unwrap();
        assert_that!(writer.bytes).is_equal_to(b"Cargo.lockCargo.tomlREADME.mdsrctarget".to_vec());
    }

    #[tokio::test]
    async fn chunk_writer_accepts_send_only_writer() {
        let collector = collect_chunks_into_write(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"abc"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"def"))),
                StreamEvent::Eof,
            ])
            .await,
            SendOnlyWrite::default(),
            WriteCollectionOptions::fail_fast(),
        );

        let writer = collector.wait().await.unwrap().unwrap();
        assert_that!(writer.bytes).is_equal_to(b"abcdef".to_vec());
        assert_that!(writer.write_calls.get()).is_greater_than(0);
    }

    #[tokio::test]
    async fn chunk_writer_mapped_writes_mapped_output() {
        let collector = collect_chunks_into_write_mapped(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"Cargo.lock\n"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"Cargo.toml\n"))),
                StreamEvent::Eof,
            ])
            .await,
            SendOnlyWrite::default(),
            |chunk| String::from_utf8_lossy(chunk.as_ref()).to_string(),
            WriteCollectionOptions::fail_fast(),
        );

        let writer = collector.wait().await.unwrap().unwrap();
        assert_that!(writer.bytes).is_equal_to(b"Cargo.lock\nCargo.toml\n".to_vec());
    }

    #[tokio::test]
    async fn mapped_writers_return_sink_write_errors() {
        let chunk_error = collect_chunks_into_write_mapped(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"abc"))),
                StreamEvent::Eof,
            ])
            .await,
            FailingWrite::new(0, io::ErrorKind::ConnectionReset),
            |chunk| chunk,
            WriteCollectionOptions::fail_fast(),
        )
        .wait()
        .await;
        match chunk_error {
            Ok(Err(err)) => {
                assert_that!(err.operation()).is_equal_to(SinkWriteOperation::Chunk);
                assert_that!(err.attempted_len()).is_equal_to(3);
                assert_that!(err.source().kind()).is_equal_to(io::ErrorKind::ConnectionReset);
            }
            other => {
                assert_that!(&other).fail(format_args!("expected sink write error, got {other:?}"));
            }
        }

        let line_error = collect_lines_into_write_mapped(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"one\n"))),
                StreamEvent::Eof,
            ])
            .await,
            FailingWrite::new(0, io::ErrorKind::BrokenPipe),
            |line| line.into_owned().into_bytes(),
            LineParsingOptions::default(),
            LineWriteMode::AsIs,
            WriteCollectionOptions::fail_fast(),
        )
        .wait()
        .await;
        match line_error {
            Ok(Err(err)) => {
                assert_that!(err.operation()).is_equal_to(SinkWriteOperation::Line);
                assert_that!(err.attempted_len()).is_equal_to(3);
                assert_that!(err.source().kind()).is_equal_to(io::ErrorKind::BrokenPipe);
            }
            other => {
                assert_that!(&other).fail(format_args!("expected sink write error, got {other:?}"));
            }
        }
    }

    #[tokio::test]
    async fn line_write_error_handler_can_continue_after_sink_write_errors() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let handled_events = Arc::clone(&events);
        let collector = collect_lines_into_write(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"a\nb\n"))),
                StreamEvent::Eof,
            ])
            .await,
            FailingWrite::new(0, io::ErrorKind::BrokenPipe),
            LineParsingOptions::default(),
            LineWriteMode::AppendLf,
            WriteCollectionOptions::with_error_handler(move |err| {
                handled_events.lock().unwrap().push((
                    err.stream_name(),
                    err.operation(),
                    err.attempted_len(),
                    err.source().kind(),
                ));
                SinkWriteErrorAction::Continue
            }),
        );

        let write = collector.wait().await.unwrap().unwrap();
        assert_that!(write.bytes_written).is_equal_to(0);
        // Each failed line skips its delimiter, so only `Line` failures are reported.
        assert_that!(events.lock().unwrap().as_slice()).is_equal_to([
            (
                "custom",
                SinkWriteOperation::Line,
                1,
                io::ErrorKind::BrokenPipe,
            ),
            (
                "custom",
                SinkWriteOperation::Line,
                1,
                io::ErrorKind::BrokenPipe,
            ),
        ]);
    }

    #[tokio::test]
    async fn chunk_write_error_handler_can_continue_then_stop() {
        let handled_count = Arc::new(Mutex::new(0_usize));
        let count_for_handler = Arc::clone(&handled_count);
        let collector = collect_chunks_into_write(
            "custom",
            event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"a"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"b"))),
                StreamEvent::Eof,
            ])
            .await,
            FailingWrite::new(0, io::ErrorKind::BrokenPipe),
            WriteCollectionOptions::with_error_handler(move |err| {
                assert_that!(err.operation()).is_equal_to(SinkWriteOperation::Chunk);
                let mut count = count_for_handler.lock().unwrap();
                *count += 1;
                if *count == 1 {
                    SinkWriteErrorAction::Continue
                } else {
                    SinkWriteErrorAction::Stop
                }
            }),
        );

        match collector.wait().await {
            Ok(Err(err)) => {
                assert_that!(err.source().kind()).is_equal_to(io::ErrorKind::BrokenPipe);
            }
            other => {
                assert_that!(&other).fail(format_args!("expected sink write error, got {other:?}"));
            }
        }
        assert_that!(*handled_count.lock().unwrap()).is_equal_to(2);
    }
}