1use std::fmt;
33use std::sync::Arc;
34use thiserror::Error;
35
36use crate::handler::CommandContext;
37use clap::ArgMatches;
38
/// The rendered form of a command's output, as passed through post-output hooks.
///
/// Derives `PartialEq`/`Eq` so rendered outputs can be compared directly
/// (e.g. with `assert_eq!`); both payload types are themselves `Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RenderedOutput {
    /// Textual output.
    Text(String),
    /// Raw bytes paired with a file name (e.g. `"file.bin"`).
    Binary(Vec<u8>, String),
    /// No output at all.
    Silent,
}

impl RenderedOutput {
    /// Returns `true` if this is [`RenderedOutput::Text`].
    pub fn is_text(&self) -> bool {
        matches!(self, RenderedOutput::Text(_))
    }

    /// Returns `true` if this is [`RenderedOutput::Binary`].
    pub fn is_binary(&self) -> bool {
        matches!(self, RenderedOutput::Binary(_, _))
    }

    /// Returns `true` if this is [`RenderedOutput::Silent`].
    pub fn is_silent(&self) -> bool {
        matches!(self, RenderedOutput::Silent)
    }

    /// Borrows the text payload, or returns `None` for non-text variants.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            RenderedOutput::Text(s) => Some(s.as_str()),
            _ => None,
        }
    }

    /// Borrows the bytes and file name of a binary payload, or returns `None`
    /// for non-binary variants.
    pub fn as_binary(&self) -> Option<(&[u8], &str)> {
        match self {
            RenderedOutput::Binary(bytes, filename) => Some((bytes.as_slice(), filename.as_str())),
            _ => None,
        }
    }
}
84
/// Identifies which group of hooks an item belongs to (or an error came from).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookPhase {
    /// Hooks that receive the parsed arguments and a mutable command context.
    PreDispatch,
    /// Hooks that transform a `serde_json::Value`.
    PostDispatch,
    /// Hooks that transform a `RenderedOutput`.
    PostOutput,
}

impl fmt::Display for HookPhase {
    /// Writes the phase as a lowercase, hyphenated label (e.g. `pre-dispatch`),
    /// which is how it appears inside hook error messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::PreDispatch => "pre-dispatch",
            Self::PostDispatch => "post-dispatch",
            Self::PostOutput => "post-output",
        };
        f.write_str(label)
    }
}
105
/// Error returned by a hook; the `run_*` methods on `Hooks` stop at the first one.
///
/// Displays as `hook error (<phase>): <message>`.
#[derive(Debug, Error)]
#[error("hook error ({phase}): {message}")]
pub struct HookError {
    /// Human-readable description of the failure.
    pub message: String,
    /// The hook phase that produced the error.
    pub phase: HookPhase,
    /// Optional underlying cause, reported through `std::error::Error::source`.
    #[source]
    pub source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
}
118
119impl HookError {
120 pub fn pre_dispatch(message: impl Into<String>) -> Self {
122 Self {
123 message: message.into(),
124 phase: HookPhase::PreDispatch,
125 source: None,
126 }
127 }
128
129 pub fn post_dispatch(message: impl Into<String>) -> Self {
131 Self {
132 message: message.into(),
133 phase: HookPhase::PostDispatch,
134 source: None,
135 }
136 }
137
138 pub fn post_output(message: impl Into<String>) -> Self {
140 Self {
141 message: message.into(),
142 phase: HookPhase::PostOutput,
143 source: None,
144 }
145 }
146
147 pub fn with_source<E>(mut self, source: E) -> Self
149 where
150 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
151 {
152 self.source = Some(source.into());
153 self
154 }
155}
156
/// Shared pre-dispatch hook: receives the parsed arguments and may mutate the
/// [`CommandContext`] (for example, inserting values into its extensions).
pub type PreDispatchFn =
    Arc<dyn Fn(&ArgMatches, &mut CommandContext) -> Result<(), HookError> + Send + Sync>;

/// Shared post-dispatch hook: receives a JSON value and returns a (possibly
/// transformed) value, which becomes the input of the next post-dispatch hook.
pub type PostDispatchFn = Arc<
    dyn Fn(&ArgMatches, &CommandContext, serde_json::Value) -> Result<serde_json::Value, HookError>
        + Send
        + Sync,
>;

/// Shared post-output hook: receives a [`RenderedOutput`] and returns a (possibly
/// transformed) output, which becomes the input of the next post-output hook.
pub type PostOutputFn = Arc<
    dyn Fn(&ArgMatches, &CommandContext, RenderedOutput) -> Result<RenderedOutput, HookError>
        + Send
        + Sync,
>;
177
/// Ordered lists of hooks for each [`HookPhase`].
///
/// Hooks within a phase run in the order they were registered. Cloning copies the
/// vectors but shares the underlying closures, since each one is stored in an `Arc`.
#[derive(Clone, Default)]
pub struct Hooks {
    // Hooks run by `run_pre_dispatch`.
    pre_dispatch: Vec<PreDispatchFn>,
    // Hooks run by `run_post_dispatch`.
    post_dispatch: Vec<PostDispatchFn>,
    // Hooks run by `run_post_output`.
    post_output: Vec<PostOutputFn>,
}
187
188impl Hooks {
189 pub fn new() -> Self {
191 Self::default()
192 }
193
194 pub fn is_empty(&self) -> bool {
196 self.pre_dispatch.is_empty() && self.post_dispatch.is_empty() && self.post_output.is_empty()
197 }
198
199 pub fn pre_dispatch<F>(mut self, f: F) -> Self
220 where
221 F: Fn(&ArgMatches, &mut CommandContext) -> Result<(), HookError> + Send + Sync + 'static,
222 {
223 self.pre_dispatch.push(Arc::new(f));
224 self
225 }
226
227 pub fn post_dispatch<F>(mut self, f: F) -> Self
229 where
230 F: Fn(
231 &ArgMatches,
232 &CommandContext,
233 serde_json::Value,
234 ) -> Result<serde_json::Value, HookError>
235 + Send
236 + Sync
237 + 'static,
238 {
239 self.post_dispatch.push(Arc::new(f));
240 self
241 }
242
243 pub fn post_output<F>(mut self, f: F) -> Self
245 where
246 F: Fn(&ArgMatches, &CommandContext, RenderedOutput) -> Result<RenderedOutput, HookError>
247 + Send
248 + Sync
249 + 'static,
250 {
251 self.post_output.push(Arc::new(f));
252 self
253 }
254
255 pub fn run_pre_dispatch(
259 &self,
260 matches: &ArgMatches,
261 ctx: &mut CommandContext,
262 ) -> Result<(), HookError> {
263 for hook in &self.pre_dispatch {
264 hook(matches, ctx)?;
265 }
266 Ok(())
267 }
268
269 pub fn run_post_dispatch(
271 &self,
272 matches: &ArgMatches,
273 ctx: &CommandContext,
274 data: serde_json::Value,
275 ) -> Result<serde_json::Value, HookError> {
276 let mut current = data;
277 for hook in &self.post_dispatch {
278 current = hook(matches, ctx, current)?;
279 }
280 Ok(current)
281 }
282
283 pub fn run_post_output(
285 &self,
286 matches: &ArgMatches,
287 ctx: &CommandContext,
288 output: RenderedOutput,
289 ) -> Result<RenderedOutput, HookError> {
290 let mut current = output;
291 for hook in &self.post_output {
292 current = hook(matches, ctx, current)?;
293 }
294 Ok(current)
295 }
296}
297
298impl fmt::Debug for Hooks {
299 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
300 f.debug_struct("Hooks")
301 .field("pre_dispatch_count", &self.pre_dispatch.len())
302 .field("post_dispatch_count", &self.post_dispatch.len())
303 .field("post_output_count", &self.post_output.len())
304 .finish()
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal command context with a single-segment command path.
    fn test_context() -> CommandContext {
        CommandContext {
            command_path: vec!["test".into()],
            ..Default::default()
        }
    }

    // Argument matches for a bare `test` command invoked with no arguments.
    fn test_matches() -> ArgMatches {
        clap::Command::new("test").get_matches_from(vec!["test"])
    }

    #[test]
    fn test_rendered_output_variants() {
        let text = RenderedOutput::Text("hello".into());
        assert!(text.is_text());
        assert!(!text.is_binary());
        assert!(!text.is_silent());
        assert_eq!(text.as_text(), Some("hello"));
        assert_eq!(text.as_binary(), None);

        let binary = RenderedOutput::Binary(vec![1, 2, 3], "file.bin".into());
        assert!(!binary.is_text());
        assert!(binary.is_binary());
        assert_eq!(binary.as_binary(), Some((&[1u8, 2, 3][..], "file.bin")));
        assert_eq!(binary.as_text(), None);

        let silent = RenderedOutput::Silent;
        assert!(silent.is_silent());
        assert_eq!(silent.as_text(), None);
        assert_eq!(silent.as_binary(), None);
    }

    #[test]
    fn test_hook_error_creation() {
        let err = HookError::pre_dispatch("test error");
        assert_eq!(err.phase, HookPhase::PreDispatch);
        assert_eq!(err.message, "test error");
        assert!(err.source.is_none());

        assert_eq!(HookError::post_dispatch("x").phase, HookPhase::PostDispatch);
        assert_eq!(HookError::post_output("x").phase, HookPhase::PostOutput);
    }

    #[test]
    fn test_hook_error_display_and_source() {
        let err = HookError::post_output("write failed").with_source("disk full");

        assert_eq!(err.to_string(), "hook error (post-output): write failed");
        let source = std::error::Error::source(&err).expect("source was attached");
        assert_eq!(source.to_string(), "disk full");
    }

    #[test]
    fn test_hooks_empty() {
        let hooks = Hooks::new();
        assert!(hooks.is_empty());

        let hooks = Hooks::new().post_output(|_, _, output| Ok(output));
        assert!(!hooks.is_empty());
    }

    #[test]
    fn test_empty_hooks_pass_values_through() {
        use serde_json::json;

        let hooks = Hooks::new();
        let ctx = test_context();
        let matches = test_matches();

        let data = hooks
            .run_post_dispatch(&matches, &ctx, json!({"a": 1}))
            .unwrap();
        assert_eq!(data, json!({"a": 1}));

        let output = hooks
            .run_post_output(&matches, &ctx, RenderedOutput::Text("x".into()))
            .unwrap();
        assert_eq!(output.as_text(), Some("x"));
    }

    #[test]
    fn test_pre_dispatch_success() {
        let called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
        let called_clone = called.clone();

        let hooks = Hooks::new().pre_dispatch(move |_, _| {
            called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(())
        });

        let mut ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_pre_dispatch(&matches, &mut ctx);

        assert!(result.is_ok());
        assert!(called.load(std::sync::atomic::Ordering::SeqCst));
    }

    #[test]
    fn test_pre_dispatch_error_aborts() {
        let hooks = Hooks::new()
            .pre_dispatch(|_, _| Err(HookError::pre_dispatch("first fails")))
            .pre_dispatch(|_, _| panic!("should not be called"));

        let mut ctx = test_context();
        let matches = test_matches();
        let err = hooks.run_pre_dispatch(&matches, &mut ctx).unwrap_err();

        assert_eq!(err.phase, HookPhase::PreDispatch);
        assert_eq!(err.message, "first fails");
    }

    #[test]
    fn test_pre_dispatch_injects_extensions() {
        struct TestState {
            value: i32,
        }

        let hooks = Hooks::new().pre_dispatch(|_, ctx| {
            ctx.extensions.insert(TestState { value: 42 });
            Ok(())
        });

        let mut ctx = test_context();
        let matches = test_matches();

        assert!(!ctx.extensions.contains::<TestState>());

        hooks.run_pre_dispatch(&matches, &mut ctx).unwrap();

        let state = ctx.extensions.get::<TestState>().unwrap();
        assert_eq!(state.value, 42);
    }

    #[test]
    fn test_pre_dispatch_multiple_hooks_share_context() {
        struct Counter {
            count: i32,
        }

        let hooks = Hooks::new()
            .pre_dispatch(|_, ctx| {
                ctx.extensions.insert(Counter { count: 1 });
                Ok(())
            })
            .pre_dispatch(|_, ctx| {
                if let Some(counter) = ctx.extensions.get_mut::<Counter>() {
                    counter.count += 10;
                }
                Ok(())
            });

        let mut ctx = test_context();
        let matches = test_matches();
        hooks.run_pre_dispatch(&matches, &mut ctx).unwrap();

        let counter = ctx.extensions.get::<Counter>().unwrap();
        assert_eq!(counter.count, 11);
    }

    #[test]
    fn test_post_dispatch_transformation() {
        use serde_json::json;

        let hooks = Hooks::new().post_dispatch(|_, _, mut data| {
            if let Some(obj) = data.as_object_mut() {
                obj.insert("modified".into(), json!(true));
            }
            Ok(data)
        });

        let ctx = test_context();
        let matches = test_matches();
        let data = json!({"value": 42});
        let result = hooks.run_post_dispatch(&matches, &ctx, data);

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output["value"], 42);
        assert_eq!(output["modified"], true);
    }

    #[test]
    fn test_post_dispatch_hooks_run_in_registration_order() {
        use serde_json::json;

        let hooks = Hooks::new()
            .post_dispatch(|_, _, mut data| {
                data["steps"].as_array_mut().unwrap().push(json!("first"));
                Ok(data)
            })
            .post_dispatch(|_, _, mut data| {
                data["steps"].as_array_mut().unwrap().push(json!("second"));
                Ok(data)
            });

        let ctx = test_context();
        let matches = test_matches();
        let output = hooks
            .run_post_dispatch(&matches, &ctx, json!({ "steps": [] }))
            .unwrap();

        assert_eq!(output["steps"], json!(["first", "second"]));
    }

    #[test]
    fn test_post_output_transformation() {
        let hooks = Hooks::new().post_output(|_, _, output| {
            if let RenderedOutput::Text(text) = output {
                Ok(RenderedOutput::Text(text.to_uppercase()))
            } else {
                Ok(output)
            }
        });

        let ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_post_output(&matches, &ctx, RenderedOutput::Text("hello".into()));

        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_text(), Some("HELLO"));
    }

    #[test]
    fn test_post_output_error_aborts() {
        let hooks = Hooks::new()
            .post_output(|_, _, _| Err(HookError::post_output("first fails")))
            .post_output(|_, _, _| panic!("should not be called"));

        let ctx = test_context();
        let matches = test_matches();
        let err = hooks
            .run_post_output(&matches, &ctx, RenderedOutput::Silent)
            .unwrap_err();

        assert_eq!(err.phase, HookPhase::PostOutput);
        assert_eq!(err.message, "first fails");
    }
}