1use std::{collections::HashMap, path::PathBuf};
2
3use anyhow::Result;
4use clap::Subcommand;
5use nkeys::{KeyPair, KeyPairType};
6use serde_json::json;
7use wash_lib::cli::CommandOutput;
8use wash_lib::config::cfg_dir;
9use wash_lib::keys::{fs::KeyDir, KeyManager};
10
/// Suffix removed from a user-supplied key name before it is looked up in the key directory
/// (presumably the on-disk extension of stored key files — confirm against `KeyDir`).
const NKEYS_EXTENSION: &str = ".nk";
12
13#[derive(Debug, Clone, Subcommand)]
14#[allow(clippy::enum_variant_names)]
15pub enum KeysCliCommand {
16 #[clap(name = "gen", about = "Generates a keypair")]
17 GenCommand {
18 keytype: String,
20 },
21 #[clap(name = "get", about = "Retrieves a keypair and prints the contents")]
22 GetCommand {
23 #[clap(help = "The name of the key to output")]
24 keyname: String,
25 #[clap(
26 short = 'd',
27 long = "directory",
28 env = "WASH_KEYS",
29 hide_env_values = true,
30 help = "Absolute path to where keypairs are stored. Defaults to `$HOME/.wash/keys`"
31 )]
32 directory: Option<PathBuf>,
33 },
34 #[clap(name = "list", about = "Lists all keypairs in a directory")]
35 ListCommand {
36 #[clap(
37 short = 'd',
38 long = "directory",
39 env = "WASH_KEYS",
40 hide_env_values = true,
41 help = "Absolute path to where keypairs are stored. Defaults to `$HOME/.wash/keys`"
42 )]
43 directory: Option<PathBuf>,
44 },
45}
46
47pub fn handle_command(command: KeysCliCommand) -> Result<CommandOutput> {
48 match command {
49 KeysCliCommand::GenCommand { keytype } => {
50 let kt = keytype_parser(&keytype)?;
51 generate(&kt)
52 }
53 KeysCliCommand::GetCommand { keyname, directory } => get(&keyname, directory),
54 KeysCliCommand::ListCommand { directory } => list(directory),
55 }
56}
57
58pub fn keytype_parser(keytype: &str) -> Result<KeyPairType> {
59 match keytype.to_lowercase().as_str() {
60 "account" => Ok(KeyPairType::Account),
61 "user" => Ok(KeyPairType::User),
62 "module" | "component" => Ok(KeyPairType::Module),
63 "service" | "provider" => Ok(KeyPairType::Service),
64 "server" | "host" => Ok(KeyPairType::Server),
65 "operator" => Ok(KeyPairType::Operator),
66 "cluster" => Ok(KeyPairType::Cluster),
67 "x25519" | "curve" => Ok(KeyPairType::Curve),
68 _ => Err(anyhow::anyhow!(
69 "Invalid key type. Must be one of Account, User, Module (or Component), Service (or Provider), Server (or Host), Operator, Cluster, Curve (xkey)"
70 )),
71 }
72}
73pub fn generate(kt: &KeyPairType) -> Result<CommandOutput> {
75 let kp = KeyPair::new(kt.clone());
76 let seed = kp.seed()?;
77
78 let mut map = HashMap::new();
79 map.insert("public_key".to_string(), json!(kp.public_key()));
80 map.insert("seed".to_string(), json!(seed));
81 Ok(CommandOutput::new(
82 format!(
83 "Public Key: {}\nSeed: {}\n\nRemember that the seed is private, treat it as a secret.",
84 kp.public_key(),
85 seed,
86 ),
87 map,
88 ))
89}
90
91pub fn get(keyname: &str, directory: Option<PathBuf>) -> Result<CommandOutput> {
93 let key_dir = KeyDir::new(determine_directory(directory)?)?;
94 let key = key_dir
96 .get(keyname.trim_end_matches(NKEYS_EXTENSION))?
97 .ok_or_else(|| anyhow::anyhow!("Key {} doesn't exist", keyname))?;
98
99 Ok(CommandOutput::from_key_and_text("seed", key.seed()?))
100}
101
102pub fn list(directory: Option<PathBuf>) -> Result<CommandOutput> {
104 let key_dir = KeyDir::new(determine_directory(directory)?)?;
105
106 let keys = key_dir.list_names()?;
107
108 let mut map = HashMap::new();
109 map.insert("keys".to_string(), json!(keys));
110 Ok(CommandOutput::new(
111 format!(
112 "====== Keys found in {} ======\n{}",
113 key_dir.display(),
114 keys.join("\n")
115 ),
116 map,
117 ))
118}
119
120fn determine_directory(directory: Option<PathBuf>) -> Result<PathBuf> {
121 if let Some(d) = directory {
122 Ok(d)
123 } else {
124 let d = cfg_dir()?.join("keys");
125 Ok(d)
126 }
127}
128
#[cfg(test)]
mod tests {

    use super::{generate, keytype_parser, KeysCliCommand};
    use clap::Parser;
    use nkeys::KeyPairType;
    use serde::Deserialize;
    use std::path::PathBuf;

    // Reference values: only their LENGTHS are compared against generated keys.
    const SAMPLE_PUBLIC_KEY: &str = "MBBLAHS7MCGNQ6IR4ZDSGRIAF7NVS7FCKFTKGO5JJJKN2QQRVAH7BSIO";
    const SAMPLE_SEED: &str = "SMAH45IUULL57OSX23NOOOTLSVNQOORMDLE3Y3PQLJ4J5MY7MN2K7BIFI4";

    // Minimal top-level parser so the subcommand enum can be exercised through clap.
    // (Plain comment on purpose: clap derive would turn a doc comment into help text.)
    #[derive(Debug, Parser)]
    struct Cmd {
        #[clap(subcommand)]
        keys: KeysCliCommand,
    }

    /// Typed view of the map produced by `generate`.
    #[derive(Debug, Clone, Deserialize)]
    struct KeyPairJson {
        public_key: String,
        seed: String,
    }

    /// Generates a keypair of type `kt` and decodes the output map through a JSON round trip.
    fn generate_keypair_json(kt: KeyPairType) -> KeyPairJson {
        let output = generate(&kt).unwrap();
        let encoded = serde_json::to_string(&output.map).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// Asserts that `keypair` has the expected prefixes and the reference lengths.
    fn assert_keypair_shape(keypair: &KeyPairJson, public_prefix: char, seed_prefix: &str) {
        assert!(
            keypair.public_key.starts_with(public_prefix),
            "public key should start with {public_prefix:?}"
        );
        assert_eq!(keypair.public_key.len(), SAMPLE_PUBLIC_KEY.len());
        assert!(
            keypair.seed.starts_with(seed_prefix),
            "seed should start with {seed_prefix:?}"
        );
        assert_eq!(keypair.seed.len(), SAMPLE_SEED.len());
    }

    #[test]
    fn test_generate_basic_test() {
        let kt = KeyPairType::Account;

        let keypair = generate(&kt).unwrap();

        assert!(keypair.text.contains("Public Key: "));
        assert!(keypair.text.contains("Seed: "));
        assert!(keypair
            .text
            .contains("Remember that the seed is private, treat it as a secret."));
        assert_ne!(keypair.text, "");
        assert!(!keypair.map.is_empty());
    }

    #[test]
    fn test_generate_valid_keypair() {
        let keypair = generate_keypair_json(KeyPairType::Module);

        assert_keypair_shape(&keypair, 'M', "SM");
    }

    #[test]
    fn test_generate_all_types() {
        // (key type, public key prefix, seed prefix)
        let cases = [
            (KeyPairType::Account, 'A', "SA"),
            (KeyPairType::User, 'U', "SU"),
            (KeyPairType::Module, 'M', "SM"),
            (KeyPairType::Service, 'V', "SV"),
            (KeyPairType::Server, 'N', "SN"),
            (KeyPairType::Operator, 'O', "SO"),
            (KeyPairType::Cluster, 'C', "SC"),
        ];

        for (kt, public_prefix, seed_prefix) in cases {
            let keypair = generate_keypair_json(kt);
            assert_keypair_shape(&keypair, public_prefix, seed_prefix);
        }
    }

    #[test]
    fn test_gen_comprehensive() {
        let key_gen_types = [
            "acCount",
            "usEr",
            "module",
            "COMPONENT",
            "SERVICE",
            "provider",
            "server",
            "HOST",
            "operator",
            "CLUSTER",
            "curve",
            "X25519",
        ];

        for raw in key_gen_types {
            // Pass the input through untouched so the parser's own case-insensitivity is
            // what gets tested; lowercase only the value we compare against.
            let gen_cmd = Cmd::try_parse_from(["keys", "gen", raw]).unwrap();
            let expected = raw.to_lowercase();
            match gen_cmd.keys {
                KeysCliCommand::GenCommand { keytype } => {
                    use KeyPairType::*;
                    match keytype_parser(&keytype).unwrap() {
                        Account => assert_eq!(expected, "account"),
                        User => assert_eq!(expected, "user"),
                        Module => assert!(expected == "module" || expected == "component"),
                        Service => assert!(expected == "service" || expected == "provider"),
                        Server => assert!(expected == "server" || expected == "host"),
                        Operator => assert_eq!(expected, "operator"),
                        Cluster => assert_eq!(expected, "cluster"),
                        Curve => assert!(expected == "curve" || expected == "x25519"),
                    }
                }
                _ => panic!("`keys gen` constructed incorrect command"),
            }
        }
    }

    #[test]
    fn test_invalid_keytype_input() {
        let key_gen_types = [
            "accout", "USE", "moDUl", "actors", "SEVICE", "provder", "srver", "hos", "opERtoR",
            "cluter",
        ];

        for raw in key_gen_types {
            let gen_cmd = Cmd::try_parse_from(["keys", "gen", raw]).unwrap();
            match gen_cmd.keys {
                KeysCliCommand::GenCommand { keytype } => {
                    assert!(
                        keytype_parser(&keytype).is_err(),
                        "Invalid keytype {raw:?} parsed successfully"
                    );
                }
                _ => panic!("`keys gen` constructed incorrect command"),
            }
        }
    }

    #[test]
    fn test_get_basic() {
        const KEYNAME: &str = "get_basic_test.nk";
        const KEYPATH: &str = "./tests/fixtures";

        let get_basic =
            Cmd::try_parse_from(["keys", "get", KEYNAME, "--directory", KEYPATH]).unwrap();
        match get_basic.keys {
            KeysCliCommand::GetCommand { keyname, .. } => assert_eq!(keyname, KEYNAME),
            other_cmd => panic!("keys get generated other command {other_cmd:?}"),
        }
    }

    #[test]
    fn test_get_comprehensive() {
        const KEYPATH: &str = "./tests/fixtures";
        const KEYNAME: &str = "get_comprehensive_test.nk";

        let get_all_flags = Cmd::try_parse_from(["keys", "get", KEYNAME, "-d", KEYPATH]).unwrap();
        match get_all_flags.keys {
            KeysCliCommand::GetCommand { keyname, directory } => {
                assert_eq!(keyname, KEYNAME);
                assert_eq!(directory, Some(PathBuf::from(KEYPATH)));
            }
            other_cmd => panic!("keys get generated other command {other_cmd:?}"),
        }
    }

    #[test]
    fn test_list_comprehensive() {
        const KEYPATH: &str = "./";

        // Both spellings of the directory flag must produce the same command.
        for flag in ["--directory", "-d"] {
            let list_cmd = Cmd::try_parse_from(["keys", "list", flag, KEYPATH]).unwrap();
            match list_cmd.keys {
                KeysCliCommand::ListCommand { directory } => {
                    assert_eq!(directory, Some(PathBuf::from(KEYPATH)));
                }
                other_cmd => panic!("keys list generated other command {other_cmd:?}"),
            }
        }
    }
}