1use std::fmt;
33use std::rc::Rc;
34use thiserror::Error;
35
36use crate::handler::CommandContext;
37use clap::ArgMatches;
38
/// Text produced by a command, kept in two forms.
///
/// `formatted` is the form returned by [`RenderedOutput::as_text`]; `raw` is the
/// form returned by [`RenderedOutput::as_raw_text`]. The two may differ
/// (NOTE(review): presumably `formatted` can carry display styling — confirm
/// against the renderer), or be identical when built with [`TextOutput::plain`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextOutput {
    /// Display form of the text.
    pub formatted: String,
    /// Unformatted form of the text.
    pub raw: String,
}

impl TextOutput {
    /// Builds a `TextOutput` from separate formatted and raw strings.
    pub fn new(formatted: String, raw: String) -> Self {
        Self { formatted, raw }
    }

    /// Builds a `TextOutput` whose formatted and raw forms are both `text`.
    pub fn plain(text: String) -> Self {
        // Both fields need an owned copy, so one clone is unavoidable.
        Self {
            formatted: text.clone(),
            raw: text,
        }
    }
}

/// Output of a command after rendering, as passed through post-output hooks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RenderedOutput {
    /// Textual output in formatted and raw forms.
    Text(TextOutput),
    /// Binary payload paired with a file name: `(bytes, filename)`.
    Binary(Vec<u8>, String),
    /// No output.
    Silent,
}

impl RenderedOutput {
    /// Returns `true` if this is [`RenderedOutput::Text`].
    pub fn is_text(&self) -> bool {
        matches!(self, RenderedOutput::Text(_))
    }

    /// Returns `true` if this is [`RenderedOutput::Binary`].
    pub fn is_binary(&self) -> bool {
        matches!(self, RenderedOutput::Binary(_, _))
    }

    /// Returns `true` if this is [`RenderedOutput::Silent`].
    pub fn is_silent(&self) -> bool {
        matches!(self, RenderedOutput::Silent)
    }

    /// Returns the formatted text, or `None` for non-text output.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            RenderedOutput::Text(t) => Some(&t.formatted),
            _ => None,
        }
    }

    /// Returns the raw (unformatted) text, or `None` for non-text output.
    pub fn as_raw_text(&self) -> Option<&str> {
        match self {
            RenderedOutput::Text(t) => Some(&t.raw),
            _ => None,
        }
    }

    /// Returns the whole [`TextOutput`], or `None` for non-text output.
    pub fn as_text_output(&self) -> Option<&TextOutput> {
        match self {
            RenderedOutput::Text(t) => Some(t),
            _ => None,
        }
    }

    /// Returns `(bytes, filename)` for binary output, or `None` otherwise.
    pub fn as_binary(&self) -> Option<(&[u8], &str)> {
        match self {
            RenderedOutput::Binary(bytes, filename) => Some((bytes, filename)),
            _ => None,
        }
    }
}
135
/// Stage of the hook pipeline a [`HookError`] was raised in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HookPhase {
    /// Pre-dispatch stage; displayed as `pre-dispatch`.
    PreDispatch,
    /// Post-dispatch stage; displayed as `post-dispatch`.
    PostDispatch,
    /// Post-output stage; displayed as `post-output`.
    PostOutput,
}

impl fmt::Display for HookPhase {
    /// Writes the kebab-case label for the phase (width/fill flags are not applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::PreDispatch => "pre-dispatch",
            Self::PostDispatch => "post-dispatch",
            Self::PostOutput => "post-output",
        };
        f.write_str(label)
    }
}
156
157#[derive(Debug, Error)]
159#[error("hook error ({phase}): {message}")]
160pub struct HookError {
161 pub message: String,
163 pub phase: HookPhase,
165 #[source]
167 pub source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
168}
169
170impl HookError {
171 pub fn pre_dispatch(message: impl Into<String>) -> Self {
173 Self {
174 message: message.into(),
175 phase: HookPhase::PreDispatch,
176 source: None,
177 }
178 }
179
180 pub fn post_dispatch(message: impl Into<String>) -> Self {
182 Self {
183 message: message.into(),
184 phase: HookPhase::PostDispatch,
185 source: None,
186 }
187 }
188
189 pub fn post_output(message: impl Into<String>) -> Self {
191 Self {
192 message: message.into(),
193 phase: HookPhase::PostOutput,
194 source: None,
195 }
196 }
197
198 pub fn with_source<E>(mut self, source: E) -> Self
200 where
201 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
202 {
203 self.source = Some(source.into());
204 self
205 }
206}
207
208pub type PreDispatchFn = Rc<dyn Fn(&ArgMatches, &mut CommandContext) -> Result<(), HookError>>;
213
214pub type PostDispatchFn = Rc<
216 dyn Fn(&ArgMatches, &CommandContext, serde_json::Value) -> Result<serde_json::Value, HookError>,
217>;
218
219pub type PostOutputFn =
221 Rc<dyn Fn(&ArgMatches, &CommandContext, RenderedOutput) -> Result<RenderedOutput, HookError>>;
222
223#[derive(Clone, Default)]
227pub struct Hooks {
228 pre_dispatch: Vec<PreDispatchFn>,
229 post_dispatch: Vec<PostDispatchFn>,
230 post_output: Vec<PostOutputFn>,
231}
232
233impl Hooks {
234 pub fn new() -> Self {
236 Self::default()
237 }
238
239 pub fn is_empty(&self) -> bool {
241 self.pre_dispatch.is_empty() && self.post_dispatch.is_empty() && self.post_output.is_empty()
242 }
243
244 pub fn pre_dispatch<F>(mut self, f: F) -> Self
265 where
266 F: Fn(&ArgMatches, &mut CommandContext) -> Result<(), HookError> + 'static,
267 {
268 self.pre_dispatch.push(Rc::new(f));
269 self
270 }
271
272 pub fn post_dispatch<F>(mut self, f: F) -> Self
274 where
275 F: Fn(
276 &ArgMatches,
277 &CommandContext,
278 serde_json::Value,
279 ) -> Result<serde_json::Value, HookError>
280 + 'static,
281 {
282 self.post_dispatch.push(Rc::new(f));
283 self
284 }
285
286 pub fn post_output<F>(mut self, f: F) -> Self
288 where
289 F: Fn(&ArgMatches, &CommandContext, RenderedOutput) -> Result<RenderedOutput, HookError>
290 + 'static,
291 {
292 self.post_output.push(Rc::new(f));
293 self
294 }
295
296 pub fn run_pre_dispatch(
300 &self,
301 matches: &ArgMatches,
302 ctx: &mut CommandContext,
303 ) -> Result<(), HookError> {
304 for hook in &self.pre_dispatch {
305 hook(matches, ctx)?;
306 }
307 Ok(())
308 }
309
310 pub fn run_post_dispatch(
312 &self,
313 matches: &ArgMatches,
314 ctx: &CommandContext,
315 data: serde_json::Value,
316 ) -> Result<serde_json::Value, HookError> {
317 let mut current = data;
318 for hook in &self.post_dispatch {
319 current = hook(matches, ctx, current)?;
320 }
321 Ok(current)
322 }
323
324 pub fn run_post_output(
326 &self,
327 matches: &ArgMatches,
328 ctx: &CommandContext,
329 output: RenderedOutput,
330 ) -> Result<RenderedOutput, HookError> {
331 let mut current = output;
332 for hook in &self.post_output {
333 current = hook(matches, ctx, current)?;
334 }
335 Ok(current)
336 }
337}
338
339impl fmt::Debug for Hooks {
340 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341 f.debug_struct("Hooks")
342 .field("pre_dispatch_count", &self.pre_dispatch.len())
343 .field("post_dispatch_count", &self.post_dispatch.len())
344 .field("post_output_count", &self.post_output.len())
345 .finish()
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `CommandContext` for exercising hooks.
    fn test_context() -> CommandContext {
        CommandContext {
            command_path: vec!["test".into()],
            ..Default::default()
        }
    }

    /// Parses an empty command line for a bare `test` command.
    fn test_matches() -> ArgMatches {
        clap::Command::new("test").get_matches_from(vec!["test"])
    }

    #[test]
    fn test_rendered_output_variants() {
        let text = RenderedOutput::Text(TextOutput::new("formatted".into(), "raw".into()));
        assert!(text.is_text());
        assert!(!text.is_binary());
        assert!(!text.is_silent());
        assert_eq!(text.as_text(), Some("formatted"));
        assert_eq!(text.as_raw_text(), Some("raw"));

        let plain = RenderedOutput::Text(TextOutput::plain("hello".into()));
        assert_eq!(plain.as_text(), Some("hello"));
        assert_eq!(plain.as_raw_text(), Some("hello"));

        let binary = RenderedOutput::Binary(vec![1, 2, 3], "file.bin".into());
        assert!(!binary.is_text());
        assert!(binary.is_binary());
        assert_eq!(binary.as_binary(), Some((&[1u8, 2, 3][..], "file.bin")));

        let silent = RenderedOutput::Silent;
        assert!(silent.is_silent());
    }

    #[test]
    fn test_hook_error_creation() {
        let err = HookError::pre_dispatch("test error");
        assert_eq!(err.phase, HookPhase::PreDispatch);
        assert_eq!(err.message, "test error");
    }

    #[test]
    fn test_hook_phase_display() {
        assert_eq!(HookPhase::PreDispatch.to_string(), "pre-dispatch");
        assert_eq!(HookPhase::PostDispatch.to_string(), "post-dispatch");
        assert_eq!(HookPhase::PostOutput.to_string(), "post-output");
    }

    #[test]
    fn test_hook_error_display_and_source() {
        use std::error::Error as _;

        let plain = HookError::post_dispatch("boom");
        assert_eq!(plain.to_string(), "hook error (post-dispatch): boom");
        assert!(plain.source().is_none());

        let wrapped = HookError::post_output("write failed").with_source("disk full");
        assert_eq!(wrapped.phase, HookPhase::PostOutput);
        assert_eq!(
            wrapped.source().map(|e| e.to_string()),
            Some("disk full".to_string())
        );
    }

    #[test]
    fn test_hooks_empty() {
        let hooks = Hooks::new();
        assert!(hooks.is_empty());
    }

    #[test]
    fn test_hooks_not_empty_after_registration() {
        assert!(!Hooks::new().pre_dispatch(|_, _| Ok(())).is_empty());
        assert!(!Hooks::new().post_dispatch(|_, _, data| Ok(data)).is_empty());
        assert!(!Hooks::new().post_output(|_, _, output| Ok(output)).is_empty());
    }

    #[test]
    fn test_hooks_debug_reports_counts() {
        let hooks = Hooks::new()
            .pre_dispatch(|_, _| Ok(()))
            .post_output(|_, _, output| Ok(output))
            .post_output(|_, _, output| Ok(output));
        assert_eq!(
            format!("{:?}", hooks),
            "Hooks { pre_dispatch_count: 1, post_dispatch_count: 0, post_output_count: 2 }"
        );
    }

    #[test]
    fn test_pre_dispatch_success() {
        use std::cell::Cell;

        let called = Rc::new(Cell::new(false));
        let called_clone = called.clone();

        let hooks = Hooks::new().pre_dispatch(move |_, _| {
            called_clone.set(true);
            Ok(())
        });

        let mut ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_pre_dispatch(&matches, &mut ctx);

        assert!(result.is_ok());
        assert!(called.get());
    }

    #[test]
    fn test_pre_dispatch_error_aborts() {
        let hooks = Hooks::new()
            .pre_dispatch(|_, _| Err(HookError::pre_dispatch("first fails")))
            .pre_dispatch(|_, _| panic!("should not be called"));

        let mut ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_pre_dispatch(&matches, &mut ctx);

        assert!(result.is_err());
    }

    #[test]
    fn test_pre_dispatch_injects_extensions() {
        struct TestState {
            value: i32,
        }

        let hooks = Hooks::new().pre_dispatch(|_, ctx| {
            ctx.extensions.insert(TestState { value: 42 });
            Ok(())
        });

        let mut ctx = test_context();
        let matches = test_matches();

        assert!(!ctx.extensions.contains::<TestState>());

        hooks.run_pre_dispatch(&matches, &mut ctx).unwrap();

        let state = ctx.extensions.get::<TestState>().unwrap();
        assert_eq!(state.value, 42);
    }

    #[test]
    fn test_pre_dispatch_multiple_hooks_share_context() {
        struct Counter {
            count: i32,
        }

        let hooks = Hooks::new()
            .pre_dispatch(|_, ctx| {
                ctx.extensions.insert(Counter { count: 1 });
                Ok(())
            })
            .pre_dispatch(|_, ctx| {
                if let Some(counter) = ctx.extensions.get_mut::<Counter>() {
                    counter.count += 10;
                }
                Ok(())
            });

        let mut ctx = test_context();
        let matches = test_matches();
        hooks.run_pre_dispatch(&matches, &mut ctx).unwrap();

        let counter = ctx.extensions.get::<Counter>().unwrap();
        assert_eq!(counter.count, 11);
    }

    #[test]
    fn test_post_dispatch_transformation() {
        use serde_json::json;

        let hooks = Hooks::new().post_dispatch(|_, _, mut data| {
            if let Some(obj) = data.as_object_mut() {
                obj.insert("modified".into(), json!(true));
            }
            Ok(data)
        });

        let ctx = test_context();
        let matches = test_matches();
        let data = json!({"value": 42});
        let result = hooks.run_post_dispatch(&matches, &ctx, data);

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output["value"], 42);
        assert_eq!(output["modified"], true);
    }

    #[test]
    fn test_post_dispatch_runs_in_registration_order() {
        use serde_json::json;

        let hooks = Hooks::new()
            .post_dispatch(|_, _, data| {
                Ok(json!(format!("{}a", data.as_str().unwrap_or_default())))
            })
            .post_dispatch(|_, _, data| {
                Ok(json!(format!("{}b", data.as_str().unwrap_or_default())))
            });

        let ctx = test_context();
        let matches = test_matches();
        let output = hooks.run_post_dispatch(&matches, &ctx, json!("")).unwrap();

        assert_eq!(output, json!("ab"));
    }

    #[test]
    fn test_post_output_transformation() {
        let hooks = Hooks::new().post_output(|_, _, output| {
            if let RenderedOutput::Text(text_output) = output {
                Ok(RenderedOutput::Text(TextOutput::new(
                    text_output.formatted.to_uppercase(),
                    text_output.raw.to_uppercase(),
                )))
            } else {
                Ok(output)
            }
        });

        let ctx = test_context();
        let matches = test_matches();
        let input = RenderedOutput::Text(TextOutput::plain("hello".into()));
        let result = hooks.run_post_output(&matches, &ctx, input);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_text(), Some("HELLO"));
    }

    #[test]
    fn test_post_output_error_aborts_chain() {
        let hooks = Hooks::new()
            .post_output(|_, _, _| Err(HookError::post_output("rejected")))
            .post_output(|_, _, _| panic!("should not be called"));

        let ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_post_output(&matches, &ctx, RenderedOutput::Silent);

        let err = result.unwrap_err();
        assert_eq!(err.phase, HookPhase::PostOutput);
        assert_eq!(err.message, "rejected");
    }
}