Skip to main content

umbral_testing/
lib.rs

1//! umbral-testing — test helpers for umbral apps.
2//!
3//! Test-case + client ergonomics, in the Rust shape. The
4//! repeated work in every plugin's `tests/integration.rs` was four
5//! things: spin up a fresh sqlite pool, build the router, send
6//! requests, read the response. This crate collapses those into:
7//!
8//! - [`TempPool`] — a tempfile-backed SQLite pool that's dropped
9//!   when the guard goes out of scope.
10//! - [`TestClient`] — wraps an [`axum::Router`] with HTTP-verb-
11//!   shaped methods, a per-client cookie jar (so a session set on
12//!   one request rides on the next), and JSON helpers.
13//! - [`TestResponse`] — owns the response bytes and headers and
14//!   exposes assertion helpers (`assert_status`, `body_json`,
15//!   `assert_body_contains`).
16//!
17//! This crate is **NOT** a plugin. It's a sibling utility library
18//! consumed by test code — drop `umbral-testing` into a crate's
19//! `[dev-dependencies]` and you don't carry it into release builds.
20//!
21//! ```ignore
22//! use umbral_testing::{TempPool, TestClient};
23//!
24//! #[tokio::test]
25//! async fn list_endpoint_returns_seeded_rows() {
26//!     let pool = TempPool::new().await;
27//!     // ... build router using pool.handle() ...
28//!     let client = TestClient::new(router);
29//!     let resp = client.get("/api/notes").await;
30//!     resp.assert_status_ok();
31//!     let notes: Vec<Note> = resp.body_json();
32//!     assert_eq!(notes.len(), 2);
33//! }
34//! ```
35
36use std::sync::Mutex;
37
38use axum::Router;
39use axum::body::Body;
40use http::header::{COOKIE, HeaderName, HeaderValue, SET_COOKIE};
41use http::{HeaderMap, Method, Request, StatusCode};
42use http_body_util::BodyExt;
43use serde::Serialize;
44use serde::de::DeserializeOwned;
45use sqlx::SqlitePool;
46use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
47use tempfile::TempDir;
48use tower::ServiceExt;
49
50/// A tempfile-backed SQLite pool. Holding the [`TempPool`] keeps the
51/// underlying directory alive; dropping it deletes the database file
52/// and every WAL artefact alongside.
53///
54/// In-memory SQLite (`sqlite::memory:`) would be the obvious choice
55/// but it isolates per-connection: pool size > 1 means different
56/// connections see different databases. The tempfile path
57/// sidesteps that completely.
58pub struct TempPool {
59    pool: SqlitePool,
60    _dir: TempDir,
61}
62
63impl TempPool {
64    /// Build a fresh pool with `max_connections = 5`.
65    pub async fn new() -> Self {
66        Self::with_max_connections(5).await
67    }
68
69    pub async fn with_max_connections(n: u32) -> Self {
70        let dir = tempfile::tempdir().expect("tempdir for TempPool");
71        let path = dir.path().join("umbral_test.sqlite");
72        let pool = SqlitePoolOptions::new()
73            .max_connections(n)
74            .connect_with(
75                SqliteConnectOptions::new()
76                    .filename(&path)
77                    .create_if_missing(true),
78            )
79            .await
80            .expect("connect to tempfile sqlite");
81        Self { pool, _dir: dir }
82    }
83
84    /// Borrow the underlying pool. Clone for ownership.
85    pub fn handle(&self) -> &SqlitePool {
86        &self.pool
87    }
88
89    /// Clone the pool out. Each clone shares the same backing
90    /// connection pool.
91    pub fn clone_handle(&self) -> SqlitePool {
92        self.pool.clone()
93    }
94}
95
/// A simple cookie jar: an ordered list of `name=value` pairs. Good
/// enough for end-to-end test flows that exchange session and CSRF
/// cookies. It is not RFC 6265 compliant: there is no domain or path
/// tracking. Expiry is honoured only in the "delete now" sense — a
/// `Max-Age` of zero or less removes the cookie.
#[derive(Default)]
struct CookieJar {
    cookies: Vec<(String, String)>,
}

impl CookieJar {
    /// Apply one `Set-Cookie` header value to the jar.
    ///
    /// Only the leading `name=value` pair is stored, with surrounding
    /// whitespace trimmed. A pair without `=` or with an empty name is
    /// ignored (RFC 6265 §5.2).
    ///
    /// Among the attributes, only `Max-Age` is inspected. A value of
    /// zero or less is how servers delete a cookie (e.g. on logout), so
    /// it removes `name` from the jar instead of storing a stale value.
    fn set_from_header(&mut self, header: &str) {
        // Server `Set-Cookie` shape: `name=value; Path=/; ...`.
        let mut segments = header.split(';');
        // `split` always yields at least one item; the fallback is defensive.
        let pair = segments.next().unwrap_or("");
        let (name, value) = match pair.split_once('=') {
            Some((n, v)) => (n.trim(), v.trim()),
            None => return,
        };
        if name.is_empty() {
            return;
        }

        // Attribute names are case-insensitive (`Max-Age`, `max-age`, ...).
        let expired = segments.any(|attr| match attr.split_once('=') {
            Some((key, val)) => {
                key.trim().eq_ignore_ascii_case("max-age")
                    && matches!(val.trim().parse::<i64>(), Ok(secs) if secs <= 0)
            }
            None => false,
        });

        self.cookies.retain(|(n, _)| n != name);
        if !expired {
            self.cookies.push((name.to_string(), value.to_string()));
        }
    }

    /// Render every stored cookie as a single `Cookie` request-header
    /// value (`a=1; b=2`), or `None` when the jar is empty.
    fn cookie_header(&self) -> Option<String> {
        if self.cookies.is_empty() {
            return None;
        }
        Some(
            self.cookies
                .iter()
                .map(|(n, v)| format!("{n}={v}"))
                .collect::<Vec<_>>()
                .join("; "),
        )
    }

    /// Look up a stored cookie value by exact (case-sensitive) name.
    fn get(&self, name: &str) -> Option<&str> {
        self.cookies
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, v)| v.as_str())
    }
}
136
137/// A test client over an axum [`Router`]. Stateful: cookies set on
138/// one response automatically ride on the next request.
139pub struct TestClient {
140    router: Router,
141    jar: Mutex<CookieJar>,
142    default_headers: Mutex<HeaderMap>,
143}
144
145impl TestClient {
146    pub fn new(router: Router) -> Self {
147        Self {
148            router,
149            jar: Mutex::new(CookieJar::default()),
150            default_headers: Mutex::new(HeaderMap::new()),
151        }
152    }
153
154    /// Add a header that rides on every subsequent request. Useful
155    /// for setting an `Authorization` once per test.
156    pub fn set_default_header(&self, name: HeaderName, value: HeaderValue) {
157        self.default_headers
158            .lock()
159            .expect("default headers poisoned")
160            .insert(name, value);
161    }
162
163    /// Read a cookie the server has set on the jar.
164    pub fn cookie(&self, name: &str) -> Option<String> {
165        self.jar
166            .lock()
167            .expect("cookie jar poisoned")
168            .get(name)
169            .map(str::to_string)
170    }
171
172    pub async fn get(&self, uri: &str) -> TestResponse {
173        self.request(Method::GET, uri, Body::empty(), None).await
174    }
175
176    pub async fn post(&self, uri: &str, body: Body) -> TestResponse {
177        self.request(Method::POST, uri, body, None).await
178    }
179
180    /// POST a value serialized to JSON with `Content-Type:
181    /// application/json`.
182    pub async fn post_json<T: Serialize + ?Sized>(&self, uri: &str, body: &T) -> TestResponse {
183        let bytes = serde_json::to_vec(body).expect("serialize body");
184        self.request(
185            Method::POST,
186            uri,
187            Body::from(bytes),
188            Some(("content-type", "application/json")),
189        )
190        .await
191    }
192
193    pub async fn put_json<T: Serialize + ?Sized>(&self, uri: &str, body: &T) -> TestResponse {
194        let bytes = serde_json::to_vec(body).expect("serialize body");
195        self.request(
196            Method::PUT,
197            uri,
198            Body::from(bytes),
199            Some(("content-type", "application/json")),
200        )
201        .await
202    }
203
204    pub async fn delete(&self, uri: &str) -> TestResponse {
205        self.request(Method::DELETE, uri, Body::empty(), None).await
206    }
207
208    /// Send a fully-formed request. Use for verbs without a typed
209    /// helper or for unusual headers.
210    pub async fn send(&self, method: Method, uri: &str, body: Body) -> TestResponse {
211        self.request(method, uri, body, None).await
212    }
213
214    async fn request(
215        &self,
216        method: Method,
217        uri: &str,
218        body: Body,
219        content_type: Option<(&str, &str)>,
220    ) -> TestResponse {
221        let mut builder = Request::builder().method(method).uri(uri);
222
223        // Replay default headers.
224        for (k, v) in self.default_headers.lock().expect("dh").iter() {
225            builder = builder.header(k, v);
226        }
227        if let Some((k, v)) = content_type {
228            builder = builder.header(k, v);
229        }
230        if let Some(c) = self.jar.lock().expect("jar").cookie_header() {
231            builder = builder.header(COOKIE, c);
232        }
233
234        let req = builder.body(body).expect("build request");
235        let resp = self
236            .router
237            .clone()
238            .oneshot(req)
239            .await
240            .expect("router oneshot");
241
242        // Harvest set-cookies into the jar before stripping the body.
243        let status = resp.status();
244        let headers = resp.headers().clone();
245        for v in headers.get_all(SET_COOKIE) {
246            if let Ok(s) = v.to_str() {
247                self.jar.lock().expect("jar set").set_from_header(s);
248            }
249        }
250        let bytes = resp
251            .into_body()
252            .collect()
253            .await
254            .expect("collect body")
255            .to_bytes();
256
257        TestResponse {
258            status,
259            headers,
260            body: bytes.to_vec(),
261        }
262    }
263}
264
265/// The result of one round trip. Owns the response bytes so the
266/// caller can read them more than once (e.g. snapshot the raw body
267/// before parsing JSON, then assert).
268pub struct TestResponse {
269    pub status: StatusCode,
270    pub headers: HeaderMap,
271    pub body: Vec<u8>,
272}
273
274impl TestResponse {
275    pub fn status(&self) -> StatusCode {
276        self.status
277    }
278
279    pub fn headers(&self) -> &HeaderMap {
280        &self.headers
281    }
282
283    pub fn body_bytes(&self) -> &[u8] {
284        &self.body
285    }
286
287    pub fn body_text(&self) -> String {
288        String::from_utf8_lossy(&self.body).into_owned()
289    }
290
291    /// Parse the body as JSON. Panics with the raw body in the
292    /// message on a parse error — much friendlier in a failing test
293    /// than a bare serde error.
294    pub fn body_json<T: DeserializeOwned>(&self) -> T {
295        serde_json::from_slice(&self.body).unwrap_or_else(|e| {
296            panic!(
297                "body_json: failed to parse response as JSON ({e}). raw body:\n{}",
298                self.body_text()
299            )
300        })
301    }
302
303    /// Read the value of a single response header. None if missing
304    /// or non-UTF-8.
305    pub fn header(&self, name: &str) -> Option<String> {
306        self.headers
307            .get(name)
308            .and_then(|v| v.to_str().ok())
309            .map(str::to_string)
310    }
311
312    pub fn assert_status(&self, expected: StatusCode) -> &Self {
313        assert_eq!(
314            self.status,
315            expected,
316            "expected status {expected}, got {} with body:\n{}",
317            self.status,
318            self.body_text()
319        );
320        self
321    }
322
323    pub fn assert_status_ok(&self) -> &Self {
324        self.assert_status(StatusCode::OK)
325    }
326
327    pub fn assert_body_contains(&self, needle: &str) -> &Self {
328        let body = self.body_text();
329        assert!(
330            body.contains(needle),
331            "expected body to contain {needle:?}\n--- got ---\n{body}\n-----------"
332        );
333        self
334    }
335
336    pub fn assert_header(&self, name: &str, expected: &str) -> &Self {
337        let actual = self.header(name);
338        assert_eq!(
339            actual.as_deref(),
340            Some(expected),
341            "expected header {name} to be {expected:?}, got {actual:?}"
342        );
343        self
344    }
345}
346
347// =========================================================================
348// Factory — realistic test data (feature #79).
349// =========================================================================
350
351/// Re-export of the [`fake`] crate so factories can reach its generators
352/// (`umbral_testing::fake::faker::...`, the `Fake` trait) without adding a
353/// direct dependency of their own.
354pub use fake;
355
356use std::sync::atomic::{AtomicU64, Ordering};
357
/// Return the next value of a process-wide counter; the first call
/// yields 1.
///
/// Every call, from any thread, gets a distinct value. That makes it a
/// handy suffix for `unique` columns (slugs, emails, crate names), so
/// rows from one `create_batch` don't collide:
///
/// ```ignore
/// slug: format!("plugin-{}", umbral_testing::seq()),
/// ```
pub fn seq() -> u64 {
    // Only uniqueness matters, not ordering relative to other memory
    // operations, so `Relaxed` suffices.
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let previous = COUNTER.fetch_add(1, Ordering::Relaxed);
    previous + 1
}
369
370/// Error from a [`Factory`] persistence call.
371#[derive(Debug)]
372pub enum FactoryError {
373    /// The ORM write failed (constraint violation, missing table, an FK
374    /// that doesn't exist yet, …).
375    Write(umbral::orm::write::WriteError),
376}
377
378impl std::fmt::Display for FactoryError {
379    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        match self {
381            FactoryError::Write(e) => write!(f, "factory write failed: {e}"),
382        }
383    }
384}
385
386impl std::error::Error for FactoryError {}
387
388impl From<umbral::orm::write::WriteError> for FactoryError {
389    fn from(e: umbral::orm::write::WriteError) -> Self {
390        FactoryError::Write(e)
391    }
392}
393
394/// A factory for producing realistic instances of a model — the
395/// factory_boy / FactoryBot shape, in Rust.
396///
397/// You define a zero-sized marker type and point it at a [`Model`] through
398/// the associated type. The orphan rule is why the impl lives on a marker
399/// rather than on the model: in a downstream test crate both the model and
400/// this trait are foreign, so `impl Factory for Plugin` wouldn't compile —
401/// but `impl Factory for PluginFactory` (a local marker) does.
402///
403/// ```ignore
404/// use umbral_testing::{Factory, fake::{Fake, faker::{lorem::en::*, company::en::*}}, seq};
405///
406/// struct PluginFactory;
407/// impl Factory for PluginFactory {
408///     type Model = Plugin;
409///     fn build() -> Plugin {
410///         let mut p = Plugin::default();
411///         p.name = CompanyName().fake();
412///         p.slug = format!("plugin-{}", seq());          // unique per call
413///         p.short_description = Sentence(4..8).fake();
414///         p
415///     }
416/// }
417///
418/// // In a test, after `App::builder()...build()` has set the ambient pool
419/// // and the tables exist:
420/// let one      = PluginFactory::create().await?;                    // one row
421/// let many     = PluginFactory::create_batch(5).await?;             // five rows
422/// let featured = PluginFactory::create_with(|p| p.featured = true).await?;
423/// ```
424///
425/// [`build`](Factory::build) is pure (no I/O); the `create*` methods
426/// persist through the ORM against the ambient pool, so a built app must
427/// be in scope. Combine with [`TestClient`] to then exercise a handler
428/// against the rows the factory produced.
429///
430/// [`Model`]: umbral::orm::Model
431#[async_trait::async_trait]
432pub trait Factory {
433    /// The model this factory produces. The bound set is exactly what
434    /// `#[derive(Model)]` already provides on every model (the ORM's
435    /// `create` path needs `Serialize` + `FromRow` + `HydrateRelated`), so
436    /// in practice you only ever write `type Model = YourModel;`.
437    type Model: umbral::orm::Model
438        + serde::Serialize
439        + for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
440        + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
441        + umbral::orm::HydrateRelated;
442
443    /// A fresh, unsaved instance with realistic fake values. Pure — no
444    /// database I/O. Override `unique` fields with [`seq`] so a batch
445    /// doesn't collide.
446    fn build() -> Self::Model;
447
448    /// Build and persist one row through the ORM.
449    async fn create() -> Result<Self::Model, FactoryError> {
450        Self::create_with(|_| {}).await
451    }
452
453    /// Build one row, apply `tweak` to override specific fields, then
454    /// persist. This is the `create(featured = true)` override hook.
455    async fn create_with<F>(tweak: F) -> Result<Self::Model, FactoryError>
456    where
457        F: FnOnce(&mut Self::Model) + Send,
458    {
459        let mut instance = Self::build();
460        tweak(&mut instance);
461        umbral::orm::Manager::<Self::Model>::default()
462            .create(instance)
463            .await
464            .map_err(FactoryError::Write)
465    }
466
467    /// Build and persist `n` rows.
468    async fn create_batch(n: usize) -> Result<Vec<Self::Model>, FactoryError> {
469        let mut out = Vec::with_capacity(n);
470        for _ in 0..n {
471            out.push(Self::create().await?);
472        }
473        Ok(out)
474    }
475}