1use std::any::{Any, TypeId};
28use std::borrow::Cow;
29use std::collections::HashMap;
30use std::fmt;
31
32use crate::collector::{InputSourceKind, ResolvedInput};
33
34#[derive(Default)]
45pub struct Inputs {
46 entries: HashMap<Cow<'static, str>, Entry>,
47}
48
49struct Entry {
50 type_id: TypeId,
51 type_name: &'static str,
52 source: InputSourceKind,
53 value: Box<dyn Any + Send + Sync>,
54}
55
56impl Inputs {
57 pub fn new() -> Self {
59 Self {
60 entries: HashMap::new(),
61 }
62 }
63
64 pub fn insert<T>(
72 &mut self,
73 name: impl Into<Cow<'static, str>>,
74 resolved: ResolvedInput<T>,
75 ) -> Option<InputSourceKind>
76 where
77 T: Send + Sync + 'static,
78 {
79 let prev = self.entries.insert(
80 name.into(),
81 Entry {
82 type_id: TypeId::of::<T>(),
83 type_name: std::any::type_name::<T>(),
84 source: resolved.source,
85 value: Box::new(resolved.value),
86 },
87 );
88 prev.map(|e| e.source)
89 }
90
91 pub fn get<T: 'static>(&self, name: &str) -> Option<&T> {
96 let entry = self.entries.get(name)?;
97 if entry.type_id != TypeId::of::<T>() {
98 return None;
99 }
100 entry.value.downcast_ref::<T>()
101 }
102
103 pub fn get_required<T: 'static>(&self, name: &str) -> Result<&T, MissingInput> {
106 let Some(entry) = self.entries.get(name) else {
107 return Err(MissingInput::NotRegistered {
108 name: name.to_string(),
109 });
110 };
111 if entry.type_id != TypeId::of::<T>() {
112 return Err(MissingInput::TypeMismatch {
113 name: name.to_string(),
114 expected: std::any::type_name::<T>(),
115 actual: entry.type_name,
116 });
117 }
118 entry
119 .value
120 .downcast_ref::<T>()
121 .ok_or_else(|| MissingInput::TypeMismatch {
122 name: name.to_string(),
123 expected: std::any::type_name::<T>(),
124 actual: entry.type_name,
125 })
126 }
127
128 pub fn source_of(&self, name: &str) -> Option<InputSourceKind> {
130 self.entries.get(name).map(|e| e.source)
131 }
132
133 pub fn contains(&self, name: &str) -> bool {
135 self.entries.contains_key(name)
136 }
137
138 pub fn len(&self) -> usize {
140 self.entries.len()
141 }
142
143 pub fn is_empty(&self) -> bool {
145 self.entries.is_empty()
146 }
147
148 pub fn iter_sources(&self) -> impl Iterator<Item = (&str, InputSourceKind)> + '_ {
150 self.entries
151 .iter()
152 .map(|(name, entry)| (name.as_ref(), entry.source))
153 }
154}
155
156impl fmt::Debug for Inputs {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 let mut s = f.debug_struct("Inputs");
159 for (name, entry) in &self.entries {
160 s.field(
161 name.as_ref(),
162 &format_args!("{} from {}", entry.type_name, entry.source),
163 );
164 }
165 s.finish()
166 }
167}
168
169#[derive(Debug, thiserror::Error)]
171pub enum MissingInput {
172 #[error("no input named `{name}` was registered for this command")]
174 NotRegistered {
175 name: String,
177 },
178 #[error("input `{name}` is registered as `{actual}`, not `{expected}`")]
180 TypeMismatch {
181 name: String,
183 expected: &'static str,
185 actual: &'static str,
187 },
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `value` as a resolved input sourced from a positional argument.
    fn arg<T>(value: T) -> ResolvedInput<T> {
        ResolvedInput {
            value,
            source: InputSourceKind::Arg,
        }
    }

    #[test]
    fn insert_and_get() {
        let mut inputs = Inputs::new();
        inputs.insert("body", arg("hello".to_string()));

        let body: &String = inputs.get("body").unwrap();
        assert_eq!(body, "hello");
    }

    #[test]
    fn get_missing_returns_none() {
        let inputs = Inputs::new();
        assert!(inputs.get::<String>("missing").is_none());
    }

    #[test]
    fn get_wrong_type_returns_none() {
        let mut inputs = Inputs::new();
        inputs.insert("body", arg("hello".to_string()));
        assert!(inputs.get::<u32>("body").is_none());
    }

    #[test]
    fn get_required_returns_value() {
        // Success path of `get_required`, which the error-path tests don't cover.
        let mut inputs = Inputs::new();
        inputs.insert("count", arg(3u32));
        assert_eq!(inputs.get_required::<u32>("count").unwrap(), &3);
    }

    #[test]
    fn get_required_reports_missing() {
        let inputs = Inputs::new();
        let err = inputs.get_required::<String>("body").unwrap_err();
        assert!(matches!(err, MissingInput::NotRegistered { .. }));
        assert!(err.to_string().contains("body"));
    }

    #[test]
    fn get_required_reports_type_mismatch() {
        let mut inputs = Inputs::new();
        inputs.insert("body", arg("hello".to_string()));
        let err = inputs.get_required::<u32>("body").unwrap_err();
        match err {
            MissingInput::TypeMismatch {
                ref name,
                expected,
                actual,
            } => {
                assert_eq!(name, "body");
                assert!(expected.contains("u32"));
                assert!(actual.contains("String"));
            }
            other => panic!("expected TypeMismatch, got {:?}", other),
        }
    }

    #[test]
    fn accepts_owned_string_name() {
        let mut inputs = Inputs::new();
        let runtime_name: String = format!("input_{}", 42);
        inputs.insert(runtime_name.clone(), arg("x".to_string()));

        assert_eq!(inputs.get::<String>(runtime_name.as_str()).unwrap(), "x");
    }

    #[test]
    fn two_inputs_of_same_type_do_not_collide() {
        let mut inputs = Inputs::new();
        inputs.insert("body", arg("the body".to_string()));
        inputs.insert("title", arg("the title".to_string()));

        assert_eq!(inputs.get::<String>("body").unwrap(), "the body");
        assert_eq!(inputs.get::<String>("title").unwrap(), "the title");
    }

    #[test]
    fn insert_returns_previous_source() {
        let mut inputs = Inputs::new();
        assert!(inputs.insert("body", arg("first".to_string())).is_none());
        let prev = inputs.insert(
            "body",
            ResolvedInput {
                value: "second".to_string(),
                source: InputSourceKind::Stdin,
            },
        );
        assert_eq!(prev, Some(InputSourceKind::Arg));
        assert_eq!(inputs.source_of("body"), Some(InputSourceKind::Stdin));
    }

    #[test]
    fn replacing_an_entry_can_change_its_type() {
        // The replacement's type wins: the old type is no longer retrievable.
        let mut inputs = Inputs::new();
        inputs.insert("body", arg("text".to_string()));
        inputs.insert("body", arg(7u32));

        assert_eq!(inputs.len(), 1);
        assert_eq!(inputs.get::<u32>("body"), Some(&7));
        assert!(inputs.get::<String>("body").is_none());
    }

    #[test]
    fn source_of_and_contains() {
        let mut inputs = Inputs::new();
        assert!(!inputs.contains("body"));
        inputs.insert("body", arg("x".to_string()));
        assert!(inputs.contains("body"));
        assert_eq!(inputs.source_of("body"), Some(InputSourceKind::Arg));
        assert_eq!(inputs.source_of("missing"), None);
    }

    #[test]
    fn iter_sources_yields_all_entries() {
        let mut inputs = Inputs::new();
        inputs.insert("body", arg("x".to_string()));
        inputs.insert(
            "yes",
            ResolvedInput {
                value: true,
                source: InputSourceKind::Flag,
            },
        );

        let mut pairs: Vec<_> = inputs.iter_sources().collect();
        pairs.sort_by_key(|(name, _)| *name);
        assert_eq!(
            pairs,
            vec![
                ("body", InputSourceKind::Arg),
                ("yes", InputSourceKind::Flag)
            ]
        );
    }

    #[test]
    fn len_and_is_empty() {
        let mut inputs = Inputs::new();
        assert!(inputs.is_empty());
        assert_eq!(inputs.len(), 0);
        inputs.insert("body", arg("x".to_string()));
        assert!(!inputs.is_empty());
        assert_eq!(inputs.len(), 1);
    }
}