1use std::fmt;
33use std::sync::Arc;
34use thiserror::Error;
35
36use crate::handler::CommandContext;
37use clap::ArgMatches;
38
/// A command's output after rendering, as handed to post-output hooks.
#[derive(Debug, Clone)]
pub enum RenderedOutput {
    /// Rendered text.
    Text(String),
    /// Raw bytes paired with a file name.
    Binary(Vec<u8>, String),
    /// No output at all.
    Silent,
}

impl RenderedOutput {
    /// Returns `true` for [`RenderedOutput::Text`].
    pub fn is_text(&self) -> bool {
        self.as_text().is_some()
    }

    /// Returns `true` for [`RenderedOutput::Binary`].
    pub fn is_binary(&self) -> bool {
        self.as_binary().is_some()
    }

    /// Returns `true` for [`RenderedOutput::Silent`].
    pub fn is_silent(&self) -> bool {
        matches!(*self, Self::Silent)
    }

    /// Borrows the text if this is [`RenderedOutput::Text`], else `None`.
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// Borrows `(bytes, file name)` if this is [`RenderedOutput::Binary`],
    /// else `None`.
    pub fn as_binary(&self) -> Option<(&[u8], &str)> {
        match self {
            Self::Binary(data, name) => Some((data.as_slice(), name.as_str())),
            Self::Text(_) | Self::Silent => None,
        }
    }
}
84
/// The stage of the command pipeline a hook (or hook error) belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookPhase {
    /// Before dispatch; hooks see the parsed arguments and context.
    PreDispatch,
    /// After dispatch; hooks transform the JSON data value.
    PostDispatch,
    /// After rendering; hooks transform the [`RenderedOutput`].
    PostOutput,
}

impl fmt::Display for HookPhase {
    // Writes the kebab-case label (e.g. `pre-dispatch`). Like the plain
    // `write!` of a literal, this ignores width/fill formatter flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::PreDispatch => "pre-dispatch",
            Self::PostDispatch => "post-dispatch",
            Self::PostOutput => "post-output",
        };
        f.write_str(label)
    }
}
105
106#[derive(Debug, Error)]
108#[error("hook error ({phase}): {message}")]
109pub struct HookError {
110 pub message: String,
112 pub phase: HookPhase,
114 #[source]
116 pub source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
117}
118
119impl HookError {
120 pub fn pre_dispatch(message: impl Into<String>) -> Self {
122 Self {
123 message: message.into(),
124 phase: HookPhase::PreDispatch,
125 source: None,
126 }
127 }
128
129 pub fn post_dispatch(message: impl Into<String>) -> Self {
131 Self {
132 message: message.into(),
133 phase: HookPhase::PostDispatch,
134 source: None,
135 }
136 }
137
138 pub fn post_output(message: impl Into<String>) -> Self {
140 Self {
141 message: message.into(),
142 phase: HookPhase::PostOutput,
143 source: None,
144 }
145 }
146
147 pub fn with_source<E>(mut self, source: E) -> Self
149 where
150 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
151 {
152 self.source = Some(source.into());
153 self
154 }
155}
156
157pub type PreDispatchFn =
159 Arc<dyn Fn(&ArgMatches, &CommandContext) -> Result<(), HookError> + Send + Sync>;
160
161pub type PostDispatchFn = Arc<
163 dyn Fn(&ArgMatches, &CommandContext, serde_json::Value) -> Result<serde_json::Value, HookError>
164 + Send
165 + Sync,
166>;
167
168pub type PostOutputFn = Arc<
170 dyn Fn(&ArgMatches, &CommandContext, RenderedOutput) -> Result<RenderedOutput, HookError>
171 + Send
172 + Sync,
173>;
174
175#[derive(Clone, Default)]
179pub struct Hooks {
180 pre_dispatch: Vec<PreDispatchFn>,
181 post_dispatch: Vec<PostDispatchFn>,
182 post_output: Vec<PostOutputFn>,
183}
184
185impl Hooks {
186 pub fn new() -> Self {
188 Self::default()
189 }
190
191 pub fn is_empty(&self) -> bool {
193 self.pre_dispatch.is_empty() && self.post_dispatch.is_empty() && self.post_output.is_empty()
194 }
195
196 pub fn pre_dispatch<F>(mut self, f: F) -> Self
198 where
199 F: Fn(&ArgMatches, &CommandContext) -> Result<(), HookError> + Send + Sync + 'static,
200 {
201 self.pre_dispatch.push(Arc::new(f));
202 self
203 }
204
205 pub fn post_dispatch<F>(mut self, f: F) -> Self
207 where
208 F: Fn(
209 &ArgMatches,
210 &CommandContext,
211 serde_json::Value,
212 ) -> Result<serde_json::Value, HookError>
213 + Send
214 + Sync
215 + 'static,
216 {
217 self.post_dispatch.push(Arc::new(f));
218 self
219 }
220
221 pub fn post_output<F>(mut self, f: F) -> Self
223 where
224 F: Fn(&ArgMatches, &CommandContext, RenderedOutput) -> Result<RenderedOutput, HookError>
225 + Send
226 + Sync
227 + 'static,
228 {
229 self.post_output.push(Arc::new(f));
230 self
231 }
232
233 pub fn run_pre_dispatch(
235 &self,
236 matches: &ArgMatches,
237 ctx: &CommandContext,
238 ) -> Result<(), HookError> {
239 for hook in &self.pre_dispatch {
240 hook(matches, ctx)?;
241 }
242 Ok(())
243 }
244
245 pub fn run_post_dispatch(
247 &self,
248 matches: &ArgMatches,
249 ctx: &CommandContext,
250 data: serde_json::Value,
251 ) -> Result<serde_json::Value, HookError> {
252 let mut current = data;
253 for hook in &self.post_dispatch {
254 current = hook(matches, ctx, current)?;
255 }
256 Ok(current)
257 }
258
259 pub fn run_post_output(
261 &self,
262 matches: &ArgMatches,
263 ctx: &CommandContext,
264 output: RenderedOutput,
265 ) -> Result<RenderedOutput, HookError> {
266 let mut current = output;
267 for hook in &self.post_output {
268 current = hook(matches, ctx, current)?;
269 }
270 Ok(current)
271 }
272}
273
274impl fmt::Debug for Hooks {
275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276 f.debug_struct("Hooks")
277 .field("pre_dispatch_count", &self.pre_dispatch.len())
278 .field("post_dispatch_count", &self.post_dispatch.len())
279 .field("post_output_count", &self.post_output.len())
280 .finish()
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal context for driving hooks in tests.
    fn test_context() -> CommandContext {
        CommandContext {
            command_path: vec!["test".into()],
        }
    }

    /// Parsed matches for a command with no arguments.
    fn test_matches() -> ArgMatches {
        clap::Command::new("test").get_matches_from(vec!["test"])
    }

    #[test]
    fn test_rendered_output_variants() {
        let text = RenderedOutput::Text("hello".into());
        assert!(text.is_text());
        assert!(!text.is_binary());
        assert!(!text.is_silent());
        assert_eq!(text.as_text(), Some("hello"));

        let binary = RenderedOutput::Binary(vec![1, 2, 3], "file.bin".into());
        assert!(!binary.is_text());
        assert!(binary.is_binary());
        assert_eq!(binary.as_binary(), Some((&[1u8, 2, 3][..], "file.bin")));

        let silent = RenderedOutput::Silent;
        assert!(silent.is_silent());
    }

    #[test]
    fn test_hook_phase_display() {
        assert_eq!(HookPhase::PreDispatch.to_string(), "pre-dispatch");
        assert_eq!(HookPhase::PostDispatch.to_string(), "post-dispatch");
        assert_eq!(HookPhase::PostOutput.to_string(), "post-output");
    }

    #[test]
    fn test_hook_error_creation() {
        let err = HookError::pre_dispatch("test error");
        assert_eq!(err.phase, HookPhase::PreDispatch);
        assert_eq!(err.message, "test error");
        assert!(err.source.is_none());
    }

    #[test]
    fn test_hook_error_display_and_source() {
        let err = HookError::post_output("boom");
        assert_eq!(err.to_string(), "hook error (post-output): boom");
        assert!(std::error::Error::source(&err).is_none());

        let err = HookError::post_dispatch("wrapped").with_source("disk");
        assert_eq!(err.phase, HookPhase::PostDispatch);
        assert_eq!(
            std::error::Error::source(&err).map(|cause| cause.to_string()),
            Some("disk".to_string())
        );
    }

    #[test]
    fn test_hooks_empty() {
        let hooks = Hooks::new();
        assert!(hooks.is_empty());
    }

    #[test]
    fn test_hooks_not_empty_after_registration() {
        let hooks = Hooks::new().post_output(|_, _, out| Ok(out));
        assert!(!hooks.is_empty());
        assert_eq!(
            format!("{hooks:?}"),
            "Hooks { pre_dispatch_count: 0, post_dispatch_count: 0, post_output_count: 1 }"
        );
    }

    #[test]
    fn test_empty_hooks_pass_through() {
        let hooks = Hooks::new();
        let ctx = test_context();
        let matches = test_matches();

        assert!(hooks.run_pre_dispatch(&matches, &ctx).is_ok());

        let data = serde_json::json!({"k": [1, 2]});
        let result = hooks
            .run_post_dispatch(&matches, &ctx, data.clone())
            .expect("no hooks cannot fail");
        assert_eq!(result, data);

        let out = hooks
            .run_post_output(&matches, &ctx, RenderedOutput::Binary(vec![7], "x".into()))
            .expect("no hooks cannot fail");
        assert_eq!(out.as_binary(), Some((&[7u8][..], "x")));
    }

    #[test]
    fn test_pre_dispatch_success() {
        let called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
        let called_clone = called.clone();

        let hooks = Hooks::new().pre_dispatch(move |_, _| {
            called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(())
        });

        let ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_pre_dispatch(&matches, &ctx);

        assert!(result.is_ok());
        assert!(called.load(std::sync::atomic::Ordering::SeqCst));
    }

    #[test]
    fn test_pre_dispatch_error_aborts() {
        let hooks = Hooks::new()
            .pre_dispatch(|_, _| Err(HookError::pre_dispatch("first fails")))
            .pre_dispatch(|_, _| panic!("should not be called"));

        let ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_pre_dispatch(&matches, &ctx);

        assert!(result.is_err());
    }

    #[test]
    fn test_post_dispatch_transformation() {
        use serde_json::json;

        let hooks = Hooks::new().post_dispatch(|_, _, mut data| {
            if let Some(obj) = data.as_object_mut() {
                obj.insert("modified".into(), json!(true));
            }
            Ok(data)
        });

        let ctx = test_context();
        let matches = test_matches();
        let data = json!({"value": 42});
        let output = hooks
            .run_post_dispatch(&matches, &ctx, data)
            .expect("hook returns Ok");

        assert_eq!(output["value"], 42);
        assert_eq!(output["modified"], true);
    }

    #[test]
    fn test_post_dispatch_runs_in_registration_order() {
        use serde_json::json;

        // (2 * 10) + 1 = 21 in registration order; reversed would give 30.
        let hooks = Hooks::new()
            .post_dispatch(|_, _, v| Ok(json!(v.as_i64().unwrap_or(0) * 10)))
            .post_dispatch(|_, _, v| Ok(json!(v.as_i64().unwrap_or(0) + 1)));

        let ctx = test_context();
        let matches = test_matches();
        let result = hooks
            .run_post_dispatch(&matches, &ctx, json!(2))
            .expect("hooks return Ok");

        assert_eq!(result.as_i64(), Some(21));
    }

    #[test]
    fn test_post_output_transformation() {
        let hooks = Hooks::new().post_output(|_, _, output| {
            if let RenderedOutput::Text(text) = output {
                Ok(RenderedOutput::Text(text.to_uppercase()))
            } else {
                Ok(output)
            }
        });

        let ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_post_output(&matches, &ctx, RenderedOutput::Text("hello".into()));

        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_text(), Some("HELLO"));
    }

    #[test]
    fn test_post_output_error_short_circuits() {
        let hooks = Hooks::new()
            .post_output(|_, _, _| Err(HookError::post_output("rejected")))
            .post_output(|_, _, _| panic!("should not be called"));

        let ctx = test_context();
        let matches = test_matches();
        let err = hooks
            .run_post_output(&matches, &ctx, RenderedOutput::Silent)
            .unwrap_err();

        assert_eq!(err.phase, HookPhase::PostOutput);
        assert_eq!(err.message, "rejected");
    }
}