// sunset_stdasync/agent.rs

1#[allow(unused_imports)]
2use {
3    log::{debug, error, info, log, trace, warn},
4    sunset::{Error, Result},
5};
6
7use std::path::Path;
8
9use pretty_hex::PrettyHex;
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::UnixStream;
12
13use sunset_sshwire_derive::*;
14
15use crate::*;
16use sshwire::{
17    BinString, Blob, SSHDecode, SSHEncode, SSHSink, SSHSource, TextString,
18    WireError, WireResult,
19};
20use sshwire::{SSHDecodeEnum, SSHEncodeEnum};
21use sunset::sshnames::*;
22use sunset::sshwire;
23use sunset::{AuthSigMsg, OwnedSig, PubKey, SignKey, Signature};
24
/// Largest agent response body, in bytes, that the client will accept.
///
/// Must be sufficient for the list of all public keys returned by an
/// identities request; a response announcing a longer body is rejected
/// before the body is read.
const MAX_RESPONSE: usize = 200 * 1000;
27
/// Payload of an `SSH_AGENTC_SIGN_REQUEST` message.
///
/// The message-number byte is not part of this struct; it is written
/// separately by the `SSHEncode` impl for `AgentRequest`. The fields are
/// presumably encoded in declaration order by the `SSHEncode` derive, so do
/// not reorder them (NOTE(review): confirm against sunset_sshwire_derive).
#[derive(Debug, SSHEncode)]
struct AgentSignRequest<'a> {
    /// Public key identifying which agent-held key should sign.
    pub key_blob: Blob<PubKey<'a>>,
    /// The authentication message to be signed.
    pub msg: Blob<&'a AuthSigMsg<'a>>,
    /// Signing flags. `sign_auth` sets `SSH_AGENT_FLAG_RSA_SHA2_256` for
    /// `SignKey::AgentRSA` keys (when the `rsa` feature is enabled) and 0
    /// for everything else.
    pub flags: u32,
}
34
/// Payload of an `SSH_AGENT_SIGN_RESPONSE` message.
///
/// Decoded by `AgentResponse`'s `SSHDecode` impl after it has consumed the
/// message-number byte.
#[derive(Debug, SSHDecode)]
struct AgentSignResponse<'a> {
    /// Signature returned by the agent; `sign_auth` converts it into an
    /// `OwnedSig`.
    pub sig: Blob<Signature<'a>>,
}
39
/// Payload of an `SSH_AGENT_IDENTITIES_ANSWER` message.
///
/// Decoded by a hand-written `SSHDecode` impl rather than a derive.
#[derive(Debug)]
struct AgentIdentitiesAnswer<'a> {
    /// One `(public key, comment)` pair per identity the agent reported,
    /// in the order they appeared on the wire.
    pub keys: Vec<(PubKey<'a>, TextString<'a>)>,
}
45
46#[derive(Debug)]
47enum AgentRequest<'a> {
48    SignRequest(AgentSignRequest<'a>),
49    RequestIdentities,
50}
51
52impl SSHEncode for AgentRequest<'_> {
53    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
54        match self {
55            Self::SignRequest(a) => {
56                let n = AgentMessageNum::SSH_AGENTC_SIGN_REQUEST as u8;
57                n.enc(s)?;
58                a.enc(s)?;
59            }
60            Self::RequestIdentities => {
61                let n = AgentMessageNum::SSH_AGENTC_REQUEST_IDENTITIES as u8;
62                n.enc(s)?;
63            }
64        }
65        Ok(())
66    }
67}
68
69/// The subset of responses we recognise
70#[derive(Debug)]
71enum AgentResponse<'a> {
72    IdentitiesAnswer(AgentIdentitiesAnswer<'a>),
73    SignResponse(AgentSignResponse<'a>),
74}
75
76impl<'de: 'a, 'a> SSHDecode<'de> for AgentResponse<'a> {
77    fn dec<S>(s: &mut S) -> WireResult<Self>
78    where
79        S: SSHSource<'de>,
80    {
81        let number = u8::dec(s)?;
82        if number == AgentMessageNum::SSH_AGENT_IDENTITIES_ANSWER as u8 {
83            Ok(Self::IdentitiesAnswer(AgentIdentitiesAnswer::dec(s)?))
84        } else if number == AgentMessageNum::SSH_AGENT_SIGN_RESPONSE as u8 {
85            Ok(Self::SignResponse(AgentSignResponse::dec(s)?))
86        } else {
87            Err(WireError::UnknownPacket { number })
88        }
89    }
90}
91
92impl<'de: 'a, 'a> SSHDecode<'de> for AgentIdentitiesAnswer<'a> {
93    fn dec<S>(s: &mut S) -> WireResult<Self>
94    where
95        S: SSHSource<'de>,
96    {
97        //     uint32                  nkeys
98        // Where "nkeys" indicates the number of keys to follow.  Following the
99        // preamble are zero or more keys, each encoded as:
100        //     string                  key blob
101        //     string                  comment
102        let l = u32::dec(s)?;
103        let mut keys = vec![];
104        for _ in 0..l {
105            let kb = Blob::<PubKey>::dec(s)?;
106            let comment = TextString::dec(s)?;
107            keys.push((kb.0, comment))
108        }
109        Ok(AgentIdentitiesAnswer { keys })
110    }
111}
112
/// An SSH agent client.
///
/// Holds a connection to a running ssh-agent plus a scratch buffer that
/// receives each reply.
pub struct AgentClient {
    /// Unix-domain socket connected to the agent.
    conn: UnixStream,
    /// Reusable buffer holding the most recent response body. It is resized
    /// to each response's length, and responses decoded by `response()`
    /// borrow from it.
    buf: Vec<u8>,
}
118
119impl AgentClient {
120    /// Create a new client
121    ///
122    /// `path` is a Unix socket to a ssh-agent, such as that from `$SSH_AUTH_SOCK`.
123    pub async fn new(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
124        let conn = UnixStream::connect(path).await?;
125        Ok(Self { conn, buf: vec![] })
126    }
127
128    async fn request(&mut self, r: AgentRequest<'_>) -> Result<AgentResponse<'_>> {
129        let mut b = vec![];
130        sshwire::ssh_push_vec(&mut b, &Blob(r))?;
131
132        trace!("agent request {:?}", b.hex_dump());
133
134        self.conn.write_all(&b).await?;
135        self.response().await
136    }
137
138    async fn response(&mut self) -> Result<AgentResponse<'_>> {
139        let mut l = [0u8; 4];
140        self.conn.read_exact(&mut l).await?;
141        let l = u32::from_be_bytes(l) as usize;
142        if l > MAX_RESPONSE {
143            error!("Response is {l} bytes long");
144            return Err(Error::msg("Too large response"));
145        }
146        self.buf.resize(l, 0);
147        self.conn.read_exact(&mut self.buf).await?;
148        let r: AgentResponse = sshwire::read_ssh(&self.buf, None)?;
149        Ok(r)
150    }
151
152    pub async fn keys(&mut self) -> Result<Vec<SignKey>> {
153        match self.request(AgentRequest::RequestIdentities).await? {
154            AgentResponse::IdentitiesAnswer(i) => {
155                let mut keys = vec![];
156                for (pk, comment) in i.keys.iter() {
157                    match SignKey::from_agent_pubkey(pk) {
158                        Ok(k) => keys.push(k),
159                        Err(e) => debug!("skipping agent key {comment:?}: {e}"),
160                    }
161                }
162                Ok(keys)
163            }
164            resp => {
165                debug!("response: {resp:?}");
166                Err(Error::msg("Unexpected agent response"))
167            }
168        }
169    }
170
171    pub async fn sign_auth(
172        &mut self,
173        key: &SignKey,
174        msg: &AuthSigMsg<'_>,
175    ) -> Result<OwnedSig> {
176        let flags = match key {
177            #[cfg(feature = "rsa")]
178            SignKey::AgentRSA(_) => SSH_AGENT_FLAG_RSA_SHA2_256,
179            _ => 0,
180        };
181        trace!("flags {flags:?}");
182        let r = AgentRequest::SignRequest(AgentSignRequest {
183            key_blob: Blob(key.pubkey()),
184            msg: Blob(msg),
185            flags,
186        });
187
188        match self.request(r).await? {
189            AgentResponse::SignResponse(s) => s.sig.0.try_into(),
190            resp => {
191                debug!("response: {resp:?}");
192                Err(Error::msg("Unexpected agent response"))
193            }
194        }
195    }
196}