Skip to main content

test_better_core/
context.rs

//! [`ContextExt`]: attach "while doing X" context to a fallible value.
//!
//! `ContextExt` is what makes `?` carry a story. A bare `?` propagates a
//! failure as-is; `.context("loading the fixture")?` propagates the same
//! failure with a frame explaining what the test was attempting.
//!
//! When the error path already holds a [`TestError`], the context frame is
//! pushed onto it directly: the original kind, location, and payload are kept,
//! and the error is *not* re-wrapped as a [`Payload::Other`].
//!
//! On an [`Option`], a `None` becomes a fresh error with the message
//! `value was None`, and the context frame is attached to that.
10
11use std::borrow::Cow;
12use std::error::Error;
13
14use crate::error::{ContextFrame, ErrorKind, Payload, TestError};
15use crate::result::TestResult;
16
17/// Attaches context to the failure path of a [`Result`] or the [`None`] of an
18/// [`Option`].
19pub trait ContextExt<T> {
20    /// Adds a context frame describing the operation that was being attempted.
21    ///
22    /// On the success path the value is returned unchanged.
23    fn context(self, message: impl Into<Cow<'static, str>>) -> TestResult<T>;
24
25    /// Like [`context`](ContextExt::context), but the message is computed by
26    /// `f`, which runs only on the failure path.
27    fn with_context<F, S>(self, f: F) -> TestResult<T>
28    where
29        F: FnOnce() -> S,
30        S: Into<Cow<'static, str>>;
31}
32
33/// Coerces an arbitrary error into a [`TestError`].
34///
35/// If `error` already *is* a `TestError` it is returned untouched (no
36/// double-wrapping); otherwise it becomes the [`Payload::Other`] of a fresh
37/// [`ErrorKind::Custom`] error, so its source chain stays walkable.
38#[track_caller]
39pub(crate) fn coerce<E>(error: E) -> TestError
40where
41    E: Error + Send + Sync + 'static,
42{
43    let boxed: Box<dyn Error + Send + Sync> = Box::new(error);
44    match boxed.downcast::<TestError>() {
45        Ok(test_error) => *test_error,
46        Err(other) => TestError::new(ErrorKind::Custom).with_payload(Payload::Other(other)),
47    }
48}
49
50/// The error produced when context is attached to a [`None`].
51#[track_caller]
52fn none_error() -> TestError {
53    TestError::new(ErrorKind::Custom).with_message("value was None")
54}
55
56impl<T, E> ContextExt<T> for Result<T, E>
57where
58    E: Error + Send + Sync + 'static,
59{
60    #[track_caller]
61    fn context(self, message: impl Into<Cow<'static, str>>) -> TestResult<T> {
62        match self {
63            Ok(value) => Ok(value),
64            Err(error) => Err(coerce(error).with_context_frame(ContextFrame::new(message))),
65        }
66    }
67
68    #[track_caller]
69    fn with_context<F, S>(self, f: F) -> TestResult<T>
70    where
71        F: FnOnce() -> S,
72        S: Into<Cow<'static, str>>,
73    {
74        match self {
75            Ok(value) => Ok(value),
76            Err(error) => Err(coerce(error).with_context_frame(ContextFrame::new(f()))),
77        }
78    }
79}
80
81impl<T> ContextExt<T> for Option<T> {
82    #[track_caller]
83    fn context(self, message: impl Into<Cow<'static, str>>) -> TestResult<T> {
84        match self {
85            Some(value) => Ok(value),
86            None => Err(none_error().with_context_frame(ContextFrame::new(message))),
87        }
88    }
89
90    #[track_caller]
91    fn with_context<F, S>(self, f: F) -> TestResult<T>
92    where
93        F: FnOnce() -> S,
94        S: Into<Cow<'static, str>>,
95    {
96        match self {
97            Some(value) => Ok(value),
98            None => Err(none_error().with_context_frame(ContextFrame::new(f()))),
99        }
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{OrFail, TestResult};
    use std::cell::Cell;
    use test_better_matchers::{check, eq, is_true};

    /// A foreign (non-`TestError`) error used to exercise the wrapping path.
    fn io_error() -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::NotFound, "missing file")
    }

    #[test]
    fn context_passes_through_ok() -> TestResult {
        let value: TestResult<i32> = Ok::<i32, std::io::Error>(7).context("unused");
        check!(value?).satisfies(eq(7)).or_fail()?;
        Ok(())
    }

    #[test]
    fn context_passes_through_some() -> TestResult {
        let value: TestResult<i32> = Some(7).context("unused");
        check!(value?).satisfies(eq(7)).or_fail()?;
        Ok(())
    }

    #[test]
    fn context_wraps_foreign_error_as_other_payload() -> TestResult {
        let failing: Result<(), std::io::Error> = Err(io_error());
        // The call below must stay on the very next line for this to hold.
        let line = line!() + 1;
        let result = failing.context("reading the fixture");
        let error = result.expect_err("err path");
        check!(error.kind)
            .satisfies(eq(ErrorKind::Custom))
            .or_fail()?;
        check!(error.location.line())
            .satisfies(eq(line))
            .or_fail()?;
        check!(matches!(error.payload.as_deref(), Some(Payload::Other(_))))
            .satisfies(is_true())
            .or_fail()?;
        check!(error.context.len()).satisfies(eq(1)).or_fail()?;
        check!(error.context[0].message.as_ref())
            .satisfies(eq("reading the fixture"))
            .or_fail()?;
        Ok(())
    }

    #[test]
    fn context_does_not_double_wrap_a_test_error() -> TestResult {
        let original = TestError::assertion("values differ");
        let original_line = original.location.line();
        let error = Err::<(), _>(original)
            .context("comparing the results")
            .expect_err("err path");
        // Kind, location, and the (absent) payload of the original are kept.
        check!(error.kind)
            .satisfies(eq(ErrorKind::Assertion))
            .or_fail()?;
        check!(error.location.line())
            .satisfies(eq(original_line))
            .or_fail()?;
        check!(error.payload.is_none())
            .satisfies(is_true())
            .or_fail()?;
        check!(error.message.as_deref())
            .satisfies(eq(Some("values differ")))
            .or_fail()?;
        check!(error.context.len()).satisfies(eq(1)).or_fail()?;
        check!(error.context[0].message.as_ref())
            .satisfies(eq("comparing the results"))
            .or_fail()?;
        Ok(())
    }

    #[test]
    fn with_context_does_not_double_wrap_a_test_error() -> TestResult {
        let original = TestError::assertion("values differ");
        let original_line = original.location.line();
        let error = Err::<(), _>(original)
            .with_context(|| "comparing lazily")
            .expect_err("err path");
        // The lazy path must preserve the original exactly like `context` does.
        check!(error.kind)
            .satisfies(eq(ErrorKind::Assertion))
            .or_fail()?;
        check!(error.location.line())
            .satisfies(eq(original_line))
            .or_fail()?;
        check!(error.context.len()).satisfies(eq(1)).or_fail()?;
        check!(error.context[0].message.as_ref())
            .satisfies(eq("comparing lazily"))
            .or_fail()?;
        Ok(())
    }

    #[test]
    fn context_frames_accumulate_in_order() -> TestResult {
        let error = Err::<(), _>(io_error())
            .context("inner step")
            .context("outer step")
            .expect_err("err path");
        let messages: Vec<_> = error.context.iter().map(|f| f.message.as_ref()).collect();
        check!(messages)
            .satisfies(eq(vec!["inner step", "outer step"]))
            .or_fail()?;
        Ok(())
    }

    #[test]
    fn none_gains_context_and_caller_location() -> TestResult {
        let missing: Option<i32> = None;
        // The call below must stay on the very next line for this to hold.
        let line = line!() + 1;
        let result = missing.context("looking up the user");
        let error = result.expect_err("err path");
        check!(error.kind)
            .satisfies(eq(ErrorKind::Custom))
            .or_fail()?;
        check!(error.location.line())
            .satisfies(eq(line))
            .or_fail()?;
        check!(error.message.as_deref())
            .satisfies(eq(Some("value was None")))
            .or_fail()?;
        // Check the length first so a missing frame fails cleanly instead of
        // panicking on the index below.
        check!(error.context.len()).satisfies(eq(1)).or_fail()?;
        check!(error.context[0].message.as_ref())
            .satisfies(eq("looking up the user"))
            .or_fail()?;
        Ok(())
    }

    #[test]
    fn with_context_runs_the_closure_only_on_failure() -> TestResult {
        let calls = Cell::new(0);
        let ok: TestResult<i32> = Ok::<i32, std::io::Error>(1).with_context(|| {
            calls.set(calls.get() + 1);
            "unused"
        });
        check!(ok?).satisfies(eq(1)).or_fail()?;
        check!(calls.get()).satisfies(eq(0)).or_fail()?;

        let err = Err::<(), _>(io_error())
            .with_context(|| {
                calls.set(calls.get() + 1);
                "computed context"
            })
            .expect_err("err path");
        check!(calls.get()).satisfies(eq(1)).or_fail()?;
        check!(err.context.len()).satisfies(eq(1)).or_fail()?;
        check!(err.context[0].message.as_ref())
            .satisfies(eq("computed context"))
            .or_fail()?;
        Ok(())
    }

    #[test]
    fn option_with_context_runs_the_closure_only_on_none() -> TestResult {
        let calls = Cell::new(0);
        let present: TestResult<i32> = Some(3).with_context(|| {
            calls.set(calls.get() + 1);
            "unused"
        });
        check!(present?).satisfies(eq(3)).or_fail()?;
        check!(calls.get()).satisfies(eq(0)).or_fail()?;

        // A named closure keeps the `with_context` call on a single line so the
        // caller-location check below is unambiguous.
        let describe = || {
            calls.set(calls.get() + 1);
            "computed for None"
        };
        let missing: Option<i32> = None;
        let line = line!() + 1;
        let result = missing.with_context(describe);
        let error = result.expect_err("err path");
        check!(calls.get()).satisfies(eq(1)).or_fail()?;
        check!(error.location.line())
            .satisfies(eq(line))
            .or_fail()?;
        check!(error.context.len()).satisfies(eq(1)).or_fail()?;
        check!(error.context[0].message.as_ref())
            .satisfies(eq("computed for None"))
            .or_fail()?;
        Ok(())
    }
}