russh/client/
mod.rs

1// Copyright 2016 Pierre-Étienne Meunier
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15
16//! # Implementing clients
17//!
//! Maybe surprisingly, the data types used by Russh to implement
//! clients are relatively more complicated than for servers. This is
//! mostly because clients are generally used both in a synchronous
//! way (in the case of SSH, for example, sending a shell command) and
//! asynchronously (because the server may send unsolicited messages),
//! and hence need to handle multiple
//! interfaces.
25//!
26//! The [Session](client::Session) is passed to the [Handler](client::Handler)
27//! when the client receives data.
28//!
29//! Check out the following examples:
30//!
31//! * [Client that connects to a server, runs a command and prints its output](https://github.com/warp-tech/russh/blob/main/russh/examples/client_exec_simple.rs)
32//! * [Client that connects to a server, runs a command in a PTY and provides interactive input/output](https://github.com/warp-tech/russh/blob/main/russh/examples/client_exec_interactive.rs)
33//! * [SFTP client (with `russh-sftp`)](https://github.com/warp-tech/russh/blob/main/russh/examples/sftp_client.rs)
34//!
35//! [Session]: client::Session
36
37use std::collections::{HashMap, VecDeque};
38use std::convert::TryInto;
39use std::num::Wrapping;
40use std::pin::Pin;
41use std::sync::Arc;
42#[cfg(not(target_arch = "wasm32"))]
43use std::time::Duration;
44
45use futures::task::{Context, Poll};
46use futures::Future;
47use kex::ClientKex;
48use log::{debug, error, trace, warn};
49use russh_util::time::Instant;
50use ssh_encoding::Decode;
51use ssh_key::{Algorithm, Certificate, HashAlg, PrivateKey, PublicKey};
52use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
53use tokio::pin;
54use tokio::sync::mpsc::{
55    channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender,
56};
57use tokio::sync::oneshot;
58
59pub use crate::auth::AuthResult;
60use crate::channels::{
61    Channel, ChannelMsg, ChannelReadHalf, ChannelRef, ChannelWriteHalf, WindowSizeRef,
62};
63use crate::cipher::{self, clear, OpeningKey};
64use crate::kex::{KexCause, KexProgress, SessionKexState};
65use crate::keys::PrivateKeyWithHashAlg;
66use crate::msg::{is_kex_msg, validate_server_msg_strict_kex};
67use crate::session::{CommonSession, EncryptedState, GlobalRequestResponse, NewKeys};
68use crate::ssh_read::SshRead;
69use crate::sshbuffer::{IncomingSshPacket, PacketWriter, SSHBuffer, SshId};
70use crate::{
71    auth, map_err, msg, negotiation, ChannelId, ChannelOpenFailure, CryptoVec, Disconnect, Error,
72    Limits, MethodSet, Sig,
73};
74
75mod encrypted;
76mod kex;
77mod session;
78
79#[cfg(test)]
80mod test;
81
82/// Actual client session's state.
83///
84/// It is in charge of multiplexing and keeping track of various channels
85/// that may get opened and closed during the lifetime of an SSH session and
86/// allows sending messages to the server.
87#[derive(Debug)]
88pub struct Session {
89    kex: SessionKexState<ClientKex>,
90    common: CommonSession<Arc<Config>>,
91    receiver: Receiver<Msg>,
92    sender: UnboundedSender<Reply>,
93    channels: HashMap<ChannelId, ChannelRef>,
94    target_window_size: u32,
95    pending_reads: Vec<CryptoVec>,
96    pending_len: u32,
97    inbound_channel_sender: Sender<Msg>,
98    inbound_channel_receiver: Receiver<Msg>,
99    open_global_requests: VecDeque<GlobalRequestResponse>,
100    server_sig_algs: Option<Vec<Algorithm>>,
101}
102
103impl Drop for Session {
104    fn drop(&mut self) {
105        debug!("drop session")
106    }
107}
108
109#[derive(Debug)]
110#[allow(clippy::large_enum_variant)]
111enum Reply {
112    AuthSuccess,
113    AuthFailure {
114        proceed_with_methods: MethodSet,
115        partial_success: bool,
116    },
117    ChannelOpenFailure,
118    SignRequest {
119        key: ssh_key::PublicKey,
120        data: CryptoVec,
121    },
122    AuthInfoRequest {
123        name: String,
124        instructions: String,
125        prompts: Vec<Prompt>,
126    },
127}
128
129#[derive(Debug)]
130pub enum Msg {
131    Authenticate {
132        user: String,
133        method: auth::Method,
134    },
135    AuthInfoResponse {
136        responses: Vec<String>,
137    },
138    Signed {
139        data: CryptoVec,
140    },
141    ChannelOpenSession {
142        channel_ref: ChannelRef,
143    },
144    ChannelOpenX11 {
145        originator_address: String,
146        originator_port: u32,
147        channel_ref: ChannelRef,
148    },
149    ChannelOpenDirectTcpIp {
150        host_to_connect: String,
151        port_to_connect: u32,
152        originator_address: String,
153        originator_port: u32,
154        channel_ref: ChannelRef,
155    },
156    ChannelOpenDirectStreamLocal {
157        socket_path: String,
158        channel_ref: ChannelRef,
159    },
160    TcpIpForward {
161        /// Provide a channel for the reply result to request a reply from the server
162        reply_channel: Option<oneshot::Sender<Option<u32>>>,
163        address: String,
164        port: u32,
165    },
166    CancelTcpIpForward {
167        /// Provide a channel for the reply result to request a reply from the server
168        reply_channel: Option<oneshot::Sender<bool>>,
169        address: String,
170        port: u32,
171    },
172    StreamLocalForward {
173        /// Provide a channel for the reply result to request a reply from the server
174        reply_channel: Option<oneshot::Sender<bool>>,
175        socket_path: String,
176    },
177    CancelStreamLocalForward {
178        /// Provide a channel for the reply result to request a reply from the server
179        reply_channel: Option<oneshot::Sender<bool>>,
180        socket_path: String,
181    },
182    Close {
183        id: ChannelId,
184    },
185    Disconnect {
186        reason: Disconnect,
187        description: String,
188        language_tag: String,
189    },
190    Channel(ChannelId, ChannelMsg),
191    Rekey,
192    AwaitExtensionInfo {
193        extension_name: String,
194        reply_channel: oneshot::Sender<()>,
195    },
196    GetServerSigAlgs {
197        reply_channel: oneshot::Sender<Option<Vec<Algorithm>>>,
198    },
199    /// Send a keepalive packet to the remote
200    Keepalive {
201        want_reply: bool,
202    },
203    Ping {
204        reply_channel: oneshot::Sender<()>,
205    },
206    NoMoreSessions {
207        want_reply: bool,
208    },
209}
210
211impl From<(ChannelId, ChannelMsg)> for Msg {
212    fn from((id, msg): (ChannelId, ChannelMsg)) -> Self {
213        Msg::Channel(id, msg)
214    }
215}
216
217#[derive(Debug)]
218pub enum KeyboardInteractiveAuthResponse {
219    Success,
220    Failure {
221        /// The server suggests to proceed with these auth methods
222        remaining_methods: MethodSet,
223        /// The server says that though auth method has been accepted,
224        /// further authentication is required
225        partial_success: bool,
226    },
227    InfoRequest {
228        name: String,
229        instructions: String,
230        prompts: Vec<Prompt>,
231    },
232}
233
/// One keyboard-interactive prompt sent by the server.
#[derive(Debug)]
pub struct Prompt {
    /// Prompt text to show to the user.
    pub prompt: String,
    /// The server's echo flag for this prompt (per RFC 4256, whether the
    /// user's input should be echoed while typed).
    pub echo: bool,
}
239
240#[derive(Debug)]
241pub struct RemoteDisconnectInfo {
242    pub reason_code: crate::Disconnect,
243    pub message: String,
244    pub lang_tag: String,
245}
246
247#[derive(Debug)]
248pub enum DisconnectReason<E: From<crate::Error> + Send> {
249    ReceivedDisconnect(RemoteDisconnectInfo),
250    Error(E),
251}
252
253/// Handle to a session, used to send messages to a client outside of
254/// the request/response cycle.
255pub struct Handle<H: Handler> {
256    sender: Sender<Msg>,
257    receiver: UnboundedReceiver<Reply>,
258    join: russh_util::runtime::JoinHandle<Result<(), H::Error>>,
259    channel_buffer_size: usize,
260}
261
262impl<H: Handler> Drop for Handle<H> {
263    fn drop(&mut self) {
264        debug!("drop handle")
265    }
266}
267
268impl<H: Handler> Handle<H> {
269    pub fn is_closed(&self) -> bool {
270        self.sender.is_closed()
271    }
272
273    /// Perform no authentication. This is useful for testing, but should not be
274    /// used in most other circumstances.
275    pub async fn authenticate_none<U: Into<String>>(
276        &mut self,
277        user: U,
278    ) -> Result<AuthResult, crate::Error> {
279        let user = user.into();
280        self.sender
281            .send(Msg::Authenticate {
282                user,
283                method: auth::Method::None,
284            })
285            .await
286            .map_err(|_| crate::Error::SendError)?;
287        self.wait_recv_reply().await
288    }
289
290    /// Perform password-based SSH authentication.
291    pub async fn authenticate_password<U: Into<String>, P: Into<String>>(
292        &mut self,
293        user: U,
294        password: P,
295    ) -> Result<AuthResult, crate::Error> {
296        let user = user.into();
297        self.sender
298            .send(Msg::Authenticate {
299                user,
300                method: auth::Method::Password {
301                    password: password.into(),
302                },
303            })
304            .await
305            .map_err(|_| crate::Error::SendError)?;
306        self.wait_recv_reply().await
307    }
308
309    /// Initiate Keyboard-Interactive based SSH authentication.
310    ///
311    /// * `submethods` - Hints to the server the preferred methods to be used for authentication
312    pub async fn authenticate_keyboard_interactive_start<
313        U: Into<String>,
314        S: Into<Option<String>>,
315    >(
316        &mut self,
317        user: U,
318        submethods: S,
319    ) -> Result<KeyboardInteractiveAuthResponse, crate::Error> {
320        self.sender
321            .send(Msg::Authenticate {
322                user: user.into(),
323                method: auth::Method::KeyboardInteractive {
324                    submethods: submethods.into().unwrap_or_else(|| "".to_owned()),
325                },
326            })
327            .await
328            .map_err(|_| crate::Error::SendError)?;
329        self.wait_recv_keyboard_interactive_reply().await
330    }
331
332    /// Respond to AuthInfoRequests from the server. A server can send any number of these Requests
333    /// including empty requests. You may have to call this function multple times in order to
334    /// complete Keyboard-Interactive based SSH authentication.
335    ///
336    /// * `responses` - The responses to each prompt. The number of responses must match the number
337    ///   of prompts. If a prompt has an empty string, then the response should be an empty string.
338    pub async fn authenticate_keyboard_interactive_respond(
339        &mut self,
340        responses: Vec<String>,
341    ) -> Result<KeyboardInteractiveAuthResponse, crate::Error> {
342        self.sender
343            .send(Msg::AuthInfoResponse { responses })
344            .await
345            .map_err(|_| crate::Error::SendError)?;
346        self.wait_recv_keyboard_interactive_reply().await
347    }
348
349    async fn wait_recv_keyboard_interactive_reply(
350        &mut self,
351    ) -> Result<KeyboardInteractiveAuthResponse, crate::Error> {
352        loop {
353            match self.receiver.recv().await {
354                Some(Reply::AuthSuccess) => return Ok(KeyboardInteractiveAuthResponse::Success),
355                Some(Reply::AuthFailure {
356                    proceed_with_methods: remaining_methods,
357                    partial_success,
358                }) => {
359                    return Ok(KeyboardInteractiveAuthResponse::Failure {
360                        remaining_methods,
361                        partial_success,
362                    })
363                }
364                Some(Reply::AuthInfoRequest {
365                    name,
366                    instructions,
367                    prompts,
368                }) => {
369                    return Ok(KeyboardInteractiveAuthResponse::InfoRequest {
370                        name,
371                        instructions,
372                        prompts,
373                    });
374                }
375                None => return Err(crate::Error::RecvError),
376                _ => {}
377            }
378        }
379    }
380
381    async fn wait_recv_reply(&mut self) -> Result<AuthResult, crate::Error> {
382        loop {
383            match self.receiver.recv().await {
384                Some(Reply::AuthSuccess) => return Ok(AuthResult::Success),
385                Some(Reply::AuthFailure {
386                    proceed_with_methods: remaining_methods,
387                    partial_success,
388                }) => {
389                    return Ok(AuthResult::Failure {
390                        remaining_methods,
391                        partial_success,
392                    })
393                }
394                None => {
395                    return Ok(AuthResult::Failure {
396                        remaining_methods: MethodSet::empty(),
397                        partial_success: false,
398                    })
399                }
400                _ => {}
401            }
402        }
403    }
404
405    /// Perform public key-based SSH authentication.
406    ///
407    /// For RSA keys, you'll need to decide on which hash algorithm to use.
408    /// This is the difference between what is also known as
409    /// `ssh-rsa`, `rsa-sha2-256`, and `rsa-sha2-512` "keys" in OpenSSH.
410    /// You can use [Handle::best_supported_rsa_hash] to automatically
411    /// figure out the best hash algorithm for RSA keys.
412    pub async fn authenticate_publickey<U: Into<String>>(
413        &mut self,
414        user: U,
415        key: PrivateKeyWithHashAlg,
416    ) -> Result<AuthResult, crate::Error> {
417        let user = user.into();
418        self.sender
419            .send(Msg::Authenticate {
420                user,
421                method: auth::Method::PublicKey { key },
422            })
423            .await
424            .map_err(|_| crate::Error::SendError)?;
425        self.wait_recv_reply().await
426    }
427
428    /// Perform public OpenSSH Certificate-based SSH authentication
429    pub async fn authenticate_openssh_cert<U: Into<String>>(
430        &mut self,
431        user: U,
432        key: Arc<PrivateKey>,
433        cert: Certificate,
434    ) -> Result<AuthResult, crate::Error> {
435        let user = user.into();
436        self.sender
437            .send(Msg::Authenticate {
438                user,
439                method: auth::Method::OpenSshCertificate { key, cert },
440            })
441            .await
442            .map_err(|_| crate::Error::SendError)?;
443        self.wait_recv_reply().await
444    }
445
446    /// Authenticate using a custom method that implements the
447    /// [`Signer`][auth::Signer] trait. Currently, this crate only provides an
448    /// implementation for an [SSH agent][crate::keys::agent::client::AgentClient].
449    pub async fn authenticate_publickey_with<U: Into<String>, S: auth::Signer>(
450        &mut self,
451        user: U,
452        key: ssh_key::PublicKey,
453        hash_alg: Option<HashAlg>,
454        signer: &mut S,
455    ) -> Result<AuthResult, S::Error> {
456        let user = user.into();
457        if self
458            .sender
459            .send(Msg::Authenticate {
460                user,
461                method: auth::Method::FuturePublicKey { key, hash_alg },
462            })
463            .await
464            .is_err()
465        {
466            return Err((crate::SendError {}).into());
467        }
468        loop {
469            let reply = self.receiver.recv().await;
470            match reply {
471                Some(Reply::AuthSuccess) => return Ok(AuthResult::Success),
472                Some(Reply::AuthFailure {
473                    proceed_with_methods: remaining_methods,
474                    partial_success,
475                }) => {
476                    return Ok(AuthResult::Failure {
477                        remaining_methods,
478                        partial_success,
479                    })
480                }
481                Some(Reply::SignRequest { key, data }) => {
482                    let data = signer.auth_publickey_sign(&key, hash_alg, data).await;
483                    let data = match data {
484                        Ok(data) => data,
485                        Err(e) => return Err(e),
486                    };
487                    if self.sender.send(Msg::Signed { data }).await.is_err() {
488                        return Err((crate::SendError {}).into());
489                    }
490                }
491                None => {
492                    return Ok(AuthResult::Failure {
493                        remaining_methods: MethodSet::empty(),
494                        partial_success: false,
495                    })
496                }
497                _ => {}
498            }
499        }
500    }
501
502    /// Wait for confirmation that a channel is open
503    async fn wait_channel_confirmation(
504        &self,
505        mut receiver: Receiver<ChannelMsg>,
506        window_size_ref: WindowSizeRef,
507    ) -> Result<Channel<Msg>, crate::Error> {
508        loop {
509            match receiver.recv().await {
510                Some(ChannelMsg::Open {
511                    id,
512                    max_packet_size,
513                    window_size,
514                }) => {
515                    window_size_ref.update(window_size).await;
516
517                    return Ok(Channel {
518                        write_half: ChannelWriteHalf {
519                            id,
520                            sender: self.sender.clone(),
521                            max_packet_size,
522                            window_size: window_size_ref,
523                        },
524                        read_half: ChannelReadHalf { receiver },
525                    });
526                }
527                Some(ChannelMsg::OpenFailure(reason)) => {
528                    return Err(crate::Error::ChannelOpenFailure(reason));
529                }
530                None => {
531                    debug!("channel confirmation sender was dropped");
532                    return Err(crate::Error::Disconnect);
533                }
534                msg => {
535                    debug!("msg = {:?}", msg);
536                }
537            }
538        }
539    }
540
541    /// See [`Handle::best_supported_rsa_hash`].
542    #[cfg(not(target_arch = "wasm32"))]
543    async fn await_extension_info(&self, extension_name: String) -> Result<(), crate::Error> {
544        let (sender, receiver) = oneshot::channel();
545        self.sender
546            .send(Msg::AwaitExtensionInfo {
547                extension_name,
548                reply_channel: sender,
549            })
550            .await
551            .map_err(|_| crate::Error::SendError)?;
552        let _ = tokio::time::timeout(Duration::from_secs(1), receiver).await;
553        Ok(())
554    }
555
556    /// Returns the best RSA hash algorithm supported by the server,
557    /// as indicated by the `server-sig-algs` extension.
558    /// If the server does not support the extension,
559    /// `None` is returned. In this case you may still attempt an authentication
560    /// with `rsa-sha2-256` or `rsa-sha2-512` and hope for the best.
561    /// If the server supports the extension, but does not support `rsa-sha2-*`,
562    /// `Some(None)` is returned.
563    ///
564    /// Note that this method will wait for up to 1 second for the server to
565    /// send the extension info if it hasn't done so yet (except when running under
566    /// WebAssembly). Unfortunately the timing of the EXT_INFO message cannot be known
567    /// in advance (RFC 8308).
568    ///
569    /// If this method returns `None` once, then for most SSH servers
570    /// you can assume that it will return `None` every time.
571    pub async fn best_supported_rsa_hash(&self) -> Result<Option<Option<HashAlg>>, Error> {
572        // Wait for the extension info from the server
573        #[cfg(not(target_arch = "wasm32"))]
574        self.await_extension_info("server-sig-algs".into()).await?;
575
576        let (sender, receiver) = oneshot::channel();
577
578        self.sender
579            .send(Msg::GetServerSigAlgs {
580                reply_channel: sender,
581            })
582            .await
583            .map_err(|_| crate::Error::SendError)?;
584
585        if let Some(ssa) = receiver.await.map_err(|_| Error::Inconsistent)? {
586            let possible_algs = [
587                Some(ssh_key::HashAlg::Sha512),
588                Some(ssh_key::HashAlg::Sha256),
589                None,
590            ];
591            for alg in possible_algs.into_iter() {
592                if ssa.contains(&Algorithm::Rsa { hash: alg }) {
593                    return Ok(Some(alg));
594                }
595            }
596        }
597
598        Ok(None)
599    }
600
601    /// Request a session channel (the most basic type of
602    /// channel). This function returns `Some(..)` immediately if the
603    /// connection is authenticated, but the channel only becomes
604    /// usable when it's confirmed by the server, as indicated by the
605    /// `confirmed` field of the corresponding `Channel`.
606    pub async fn channel_open_session(&self) -> Result<Channel<Msg>, crate::Error> {
607        let (sender, receiver) = channel(self.channel_buffer_size);
608        let channel_ref = ChannelRef::new(sender);
609        let window_size_ref = channel_ref.window_size().clone();
610
611        self.sender
612            .send(Msg::ChannelOpenSession { channel_ref })
613            .await
614            .map_err(|_| crate::Error::SendError)?;
615        self.wait_channel_confirmation(receiver, window_size_ref)
616            .await
617    }
618
619    /// Request an X11 channel, on which the X11 protocol may be tunneled.
620    pub async fn channel_open_x11<A: Into<String>>(
621        &self,
622        originator_address: A,
623        originator_port: u32,
624    ) -> Result<Channel<Msg>, crate::Error> {
625        let (sender, receiver) = channel(self.channel_buffer_size);
626        let channel_ref = ChannelRef::new(sender);
627        let window_size_ref = channel_ref.window_size().clone();
628
629        self.sender
630            .send(Msg::ChannelOpenX11 {
631                originator_address: originator_address.into(),
632                originator_port,
633                channel_ref,
634            })
635            .await
636            .map_err(|_| crate::Error::SendError)?;
637        self.wait_channel_confirmation(receiver, window_size_ref)
638            .await
639    }
640
641    /// Open a TCP/IP forwarding channel. This is usually done when a
642    /// connection comes to a locally forwarded TCP/IP port. See
643    /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-7). The
644    /// TCP/IP packets can then be tunneled through the channel using
645    /// `.data()`. After writing a stream to a channel using
646    /// [`.data()`][Channel::data], be sure to call [`.eof()`][Channel::eof] to
647    /// indicate that no more data will be sent, or you may see hangs when
648    /// writing large streams.
649    pub async fn channel_open_direct_tcpip<A: Into<String>, B: Into<String>>(
650        &self,
651        host_to_connect: A,
652        port_to_connect: u32,
653        originator_address: B,
654        originator_port: u32,
655    ) -> Result<Channel<Msg>, crate::Error> {
656        let (sender, receiver) = channel(self.channel_buffer_size);
657        let channel_ref = ChannelRef::new(sender);
658        let window_size_ref = channel_ref.window_size().clone();
659
660        self.sender
661            .send(Msg::ChannelOpenDirectTcpIp {
662                host_to_connect: host_to_connect.into(),
663                port_to_connect,
664                originator_address: originator_address.into(),
665                originator_port,
666                channel_ref,
667            })
668            .await
669            .map_err(|_| crate::Error::SendError)?;
670        self.wait_channel_confirmation(receiver, window_size_ref)
671            .await
672    }
673
674    pub async fn channel_open_direct_streamlocal<S: Into<String>>(
675        &self,
676        socket_path: S,
677    ) -> Result<Channel<Msg>, crate::Error> {
678        let (sender, receiver) = channel(self.channel_buffer_size);
679        let channel_ref = ChannelRef::new(sender);
680        let window_size_ref = channel_ref.window_size().clone();
681
682        self.sender
683            .send(Msg::ChannelOpenDirectStreamLocal {
684                socket_path: socket_path.into(),
685                channel_ref,
686            })
687            .await
688            .map_err(|_| crate::Error::SendError)?;
689        self.wait_channel_confirmation(receiver, window_size_ref)
690            .await
691    }
692
693    /// Requests the server to open a TCP/IP forward channel
694    ///
695    /// If port == 0 the server will choose a port that will be returned, returns 0 otherwise
696    pub async fn tcpip_forward<A: Into<String>>(
697        &mut self,
698        address: A,
699        port: u32,
700    ) -> Result<u32, crate::Error> {
701        let (reply_send, reply_recv) = oneshot::channel();
702        self.sender
703            .send(Msg::TcpIpForward {
704                reply_channel: Some(reply_send),
705                address: address.into(),
706                port,
707            })
708            .await
709            .map_err(|_| crate::Error::SendError)?;
710
711        match reply_recv.await {
712            Ok(Some(port)) => Ok(port),
713            Ok(None) => Err(crate::Error::RequestDenied),
714            Err(e) => {
715                error!("Unable to receive TcpIpForward result: {e:?}");
716                Err(crate::Error::Disconnect)
717            }
718        }
719    }
720
721    // Requests the server to close a TCP/IP forward channel
722    pub async fn cancel_tcpip_forward<A: Into<String>>(
723        &self,
724        address: A,
725        port: u32,
726    ) -> Result<(), crate::Error> {
727        let (reply_send, reply_recv) = oneshot::channel();
728        self.sender
729            .send(Msg::CancelTcpIpForward {
730                reply_channel: Some(reply_send),
731                address: address.into(),
732                port,
733            })
734            .await
735            .map_err(|_| crate::Error::SendError)?;
736
737        match reply_recv.await {
738            Ok(true) => Ok(()),
739            Ok(false) => Err(crate::Error::RequestDenied),
740            Err(e) => {
741                error!("Unable to receive CancelTcpIpForward result: {e:?}");
742                Err(crate::Error::Disconnect)
743            }
744        }
745    }
746
747    // Requests the server to open a UDS forward channel
748    pub async fn streamlocal_forward<A: Into<String>>(
749        &mut self,
750        socket_path: A,
751    ) -> Result<(), crate::Error> {
752        let (reply_send, reply_recv) = oneshot::channel();
753        self.sender
754            .send(Msg::StreamLocalForward {
755                reply_channel: Some(reply_send),
756                socket_path: socket_path.into(),
757            })
758            .await
759            .map_err(|_| crate::Error::SendError)?;
760
761        match reply_recv.await {
762            Ok(true) => Ok(()),
763            Ok(false) => Err(crate::Error::RequestDenied),
764            Err(e) => {
765                error!("Unable to receive StreamLocalForward result: {e:?}");
766                Err(crate::Error::Disconnect)
767            }
768        }
769    }
770
771    // Requests the server to close a UDS forward channel
772    pub async fn cancel_streamlocal_forward<A: Into<String>>(
773        &self,
774        socket_path: A,
775    ) -> Result<(), crate::Error> {
776        let (reply_send, reply_recv) = oneshot::channel();
777        self.sender
778            .send(Msg::CancelStreamLocalForward {
779                reply_channel: Some(reply_send),
780                socket_path: socket_path.into(),
781            })
782            .await
783            .map_err(|_| crate::Error::SendError)?;
784
785        match reply_recv.await {
786            Ok(true) => Ok(()),
787            Ok(false) => Err(crate::Error::RequestDenied),
788            Err(e) => {
789                error!("Unable to receive CancelStreamLocalForward result: {e:?}");
790                Err(crate::Error::Disconnect)
791            }
792        }
793    }
794
795    /// Sends a disconnect message.
796    pub async fn disconnect(
797        &self,
798        reason: Disconnect,
799        description: &str,
800        language_tag: &str,
801    ) -> Result<(), crate::Error> {
802        self.sender
803            .send(Msg::Disconnect {
804                reason,
805                description: description.into(),
806                language_tag: language_tag.into(),
807            })
808            .await
809            .map_err(|_| crate::Error::SendError)?;
810        Ok(())
811    }
812
813    /// Send data to the session referenced by this handler.
814    ///
815    /// This is useful for server-initiated channels; for channels created by
816    /// the client, prefer to use the Channel returned from the `open_*` methods.
817    pub async fn data(&self, id: ChannelId, data: CryptoVec) -> Result<(), CryptoVec> {
818        self.sender
819            .send(Msg::Channel(id, ChannelMsg::Data { data }))
820            .await
821            .map_err(|e| match e.0 {
822                Msg::Channel(_, ChannelMsg::Data { data, .. }) => data,
823                _ => unreachable!(),
824            })
825    }
826
827    /// Asynchronously perform a session re-key at the next opportunity
828    pub async fn rekey_soon(&self) -> Result<(), Error> {
829        self.sender
830            .send(Msg::Rekey)
831            .await
832            .map_err(|_| Error::SendError)?;
833
834        Ok(())
835    }
836
837    /// Send a keepalive package to the remote peer.
838    pub async fn send_keepalive(&self, want_reply: bool) -> Result<(), Error> {
839        self.sender
840            .send(Msg::Keepalive { want_reply })
841            .await
842            .map_err(|_| Error::SendError)
843    }
844
845    /// Send a keepalive/ping package to the remote peer, and wait for the reply/pong.
846    pub async fn send_ping(&self) -> Result<(), Error> {
847        let (sender, receiver) = oneshot::channel();
848        self.sender
849            .send(Msg::Ping {
850                reply_channel: sender,
851            })
852            .await
853            .map_err(|_| Error::SendError)?;
854        let _ = receiver.await;
855        Ok(())
856    }
857
858    /// Send a no-more-sessions request to the remote peer.
859    pub async fn no_more_sessions(&self, want_reply: bool) -> Result<(), Error> {
860        self.sender
861            .send(Msg::NoMoreSessions { want_reply })
862            .await
863            .map_err(|_| Error::SendError)
864    }
865}
866
867impl<H: Handler> Future for Handle<H> {
868    type Output = Result<(), H::Error>;
869    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
870        match Future::poll(Pin::new(&mut self.join), cx) {
871            Poll::Ready(r) => Poll::Ready(match r {
872                Ok(Ok(x)) => Ok(x),
873                Err(e) => Err(crate::Error::from(e).into()),
874                Ok(Err(e)) => Err(e),
875            }),
876            Poll::Pending => Poll::Pending,
877        }
878    }
879}
880
881/// Connect to a server at the address specified, using the [`Handler`]
882/// (implemented by you) and [`Config`] specified. Returns a future that
883/// resolves to a [`Handle`]. This handle can then be used to create channels,
884/// which in turn can be used to tunnel TCP connections, request a PTY, execute
885/// commands, etc. The future will resolve to an error if the connection fails.
886/// This function creates a connection to the `addr` specified using a
887/// [`tokio::net::TcpStream`] and then calls [`connect_stream`] under the hood.
888#[cfg(not(target_arch = "wasm32"))]
889pub async fn connect<H: Handler + Send + 'static, A: tokio::net::ToSocketAddrs>(
890    config: Arc<Config>,
891    addrs: A,
892    handler: H,
893) -> Result<Handle<H>, H::Error> {
894    let socket = map_err!(tokio::net::TcpStream::connect(addrs).await)?;
895    if config.as_ref().nodelay {
896        if let Err(e) = socket.set_nodelay(true) {
897            warn!("set_nodelay() failed: {e:?}");
898        }
899    }
900
901    connect_stream(config, socket, handler).await
902}
903
904/// Connect a stream to a server. This stream must implement
905/// [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`], as well as [`Unpin`]
906/// and [`Send`]. Typically, you may prefer to use [`connect`], which uses a
907/// [`tokio::net::TcpStream`] and then calls this function under the hood.
908pub async fn connect_stream<H, R>(
909    config: Arc<Config>,
910    mut stream: R,
911    handler: H,
912) -> Result<Handle<H>, H::Error>
913where
914    H: Handler + Send + 'static,
915    R: AsyncRead + AsyncWrite + Unpin + Send + 'static,
916{
917    // Writing SSH id.
918    let mut write_buffer = SSHBuffer::new();
919
920    debug!("ssh id = {:?}", config.as_ref().client_id);
921
922    write_buffer.send_ssh_id(&config.as_ref().client_id);
923    map_err!(stream.write_all(&write_buffer.buffer).await)?;
924
925    // Reading SSH id and allocating a session if correct.
926    let mut stream = SshRead::new(stream);
927    let sshid = stream.read_ssh_id().await?;
928
929    let (handle_sender, session_receiver) = channel(10);
930    let (session_sender, handle_receiver) = unbounded_channel();
931    if config.maximum_packet_size > 65535 {
932        error!(
933            "Maximum packet size ({:?}) should not larger than a TCP packet (65535)",
934            config.maximum_packet_size
935        );
936    }
937    let channel_buffer_size = config.channel_buffer_size;
938    let mut session = Session::new(
939        config.window_size,
940        CommonSession {
941            packet_writer: PacketWriter::clear(),
942            auth_user: String::new(),
943            auth_attempts: 0,
944            auth_method: None, // Client only.
945            remote_to_local: Box::new(clear::Key),
946            encrypted: None,
947            config,
948            wants_reply: false,
949            disconnected: false,
950            buffer: CryptoVec::new(),
951            strict_kex: false,
952            alive_timeouts: 0,
953            received_data: false,
954            remote_sshid: sshid.into(),
955        },
956        session_receiver,
957        session_sender,
958    );
959    session.begin_rekey()?;
960    let (kex_done_signal, kex_done_signal_rx) = oneshot::channel();
961    let join = russh_util::runtime::spawn(session.run(stream, handler, Some(kex_done_signal)));
962
963    if let Err(err) = kex_done_signal_rx.await {
964        // kex_done_signal Sender is dropped when the session
965        // fails before a succesful key exchange
966        debug!("kex_done_signal sender was dropped {err:?}");
967        join.await.map_err(crate::Error::Join)??;
968        return Err(H::Error::from(crate::Error::Disconnect));
969    }
970
971    Ok(Handle {
972        sender: handle_sender,
973        receiver: handle_receiver,
974        join,
975        channel_buffer_size,
976    })
977}
978
979async fn start_reading<R: AsyncRead + Unpin>(
980    mut stream_read: R,
981    mut buffer: SSHBuffer,
982    mut cipher: Box<dyn OpeningKey + Send>,
983) -> Result<(usize, R, SSHBuffer, Box<dyn OpeningKey + Send>), crate::Error> {
984    buffer.buffer.clear();
985    let n = cipher::read(&mut stream_read, &mut buffer, &mut *cipher).await?;
986    Ok((n, stream_read, buffer, cipher))
987}
988
989impl Session {
990    fn maybe_decompress(&mut self, buffer: &SSHBuffer) -> Result<IncomingSshPacket, Error> {
991        if let Some(ref mut enc) = self.common.encrypted {
992            let mut decomp = CryptoVec::new();
993            Ok(IncomingSshPacket {
994                #[allow(clippy::indexing_slicing)] // length checked
995                buffer: enc.decompress.decompress(
996                    &buffer.buffer[5..],
997                    &mut decomp,
998                )?.into(),
999                seqn: buffer.seqn,
1000            })
1001        } else {
1002            Ok(IncomingSshPacket {
1003                #[allow(clippy::indexing_slicing)] // length checked
1004                buffer: buffer.buffer[5..].into(),
1005                seqn: buffer.seqn,
1006            })
1007        }
1008    }
1009
1010    fn new(
1011        target_window_size: u32,
1012        common: CommonSession<Arc<Config>>,
1013        receiver: Receiver<Msg>,
1014        sender: UnboundedSender<Reply>,
1015    ) -> Self {
1016        let (inbound_channel_sender, inbound_channel_receiver) = channel(10);
1017        Self {
1018            common,
1019            receiver,
1020            sender,
1021            kex: SessionKexState::Idle,
1022            target_window_size,
1023            inbound_channel_sender,
1024            inbound_channel_receiver,
1025            channels: HashMap::new(),
1026            pending_reads: Vec::new(),
1027            pending_len: 0,
1028            open_global_requests: VecDeque::new(),
1029            server_sig_algs: None,
1030        }
1031    }
1032
1033    async fn run<H: Handler + Send, R: AsyncRead + AsyncWrite + Unpin + Send>(
1034        mut self,
1035        stream: SshRead<R>,
1036        mut handler: H,
1037        mut kex_done_signal: Option<oneshot::Sender<()>>,
1038    ) -> Result<(), H::Error> {
1039        let (stream_read, mut stream_write) = stream.split();
1040        let result = self
1041            .run_inner(
1042                stream_read,
1043                &mut stream_write,
1044                &mut handler,
1045                &mut kex_done_signal,
1046            )
1047            .await;
1048        trace!("disconnected");
1049        self.receiver.close();
1050        self.inbound_channel_receiver.close();
1051        map_err!(stream_write.shutdown().await)?;
1052        match result {
1053            Ok(v) => {
1054                handler
1055                    .disconnected(DisconnectReason::ReceivedDisconnect(v))
1056                    .await?;
1057                Ok(())
1058            }
1059            Err(e) => {
1060                if kex_done_signal.is_some() {
1061                    // The kex signal has not been consumed yet,
1062                    // so we can send return the concrete error to be propagated
1063                    // into the JoinHandle and returned from `connect_stream`
1064                    Err(e)
1065                } else {
1066                    // The kex signal has been consumed, so no one is
1067                    // awaiting the result of this coroutine
1068                    // We're better off passing the error into the Handler
1069                    debug!("disconnected {e:?}");
1070                    handler.disconnected(DisconnectReason::Error(e)).await?;
1071                    Err(H::Error::from(crate::Error::Disconnect))
1072                }
1073            }
1074        }
1075    }
1076
1077    async fn run_inner<H: Handler + Send, R: AsyncRead + AsyncWrite + Unpin + Send>(
1078        &mut self,
1079        stream_read: SshRead<ReadHalf<R>>,
1080        stream_write: &mut WriteHalf<R>,
1081        handler: &mut H,
1082        kex_done_signal: &mut Option<tokio::sync::oneshot::Sender<()>>,
1083    ) -> Result<RemoteDisconnectInfo, H::Error> {
1084        let mut result: Result<RemoteDisconnectInfo, H::Error> = Err(Error::Disconnect.into());
1085        self.flush()?;
1086
1087        map_err!(self.common.packet_writer.flush_into(stream_write).await)?;
1088
1089        let buffer = SSHBuffer::new();
1090
1091        // Allow handing out references to the cipher
1092        let mut opening_cipher = Box::new(clear::Key) as Box<dyn OpeningKey + Send>;
1093        std::mem::swap(&mut opening_cipher, &mut self.common.remote_to_local);
1094
1095        let keepalive_timer =
1096            crate::future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep);
1097        pin!(keepalive_timer);
1098
1099        let inactivity_timer =
1100            crate::future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep);
1101        pin!(inactivity_timer);
1102
1103        let reading = start_reading(stream_read, buffer, opening_cipher);
1104        pin!(reading);
1105
1106        #[allow(clippy::panic)] // false positive in select! macro
1107        while !self.common.disconnected {
1108            self.common.received_data = false;
1109            let mut sent_keepalive = false;
1110            tokio::select! {
1111                r = &mut reading => {
1112                    let (stream_read, mut buffer, mut opening_cipher) = match r {
1113                        Ok((_, stream_read, buffer, opening_cipher)) => (stream_read, buffer, opening_cipher),
1114                        Err(e) => return Err(e.into())
1115                    };
1116
1117                    std::mem::swap(&mut opening_cipher, &mut self.common.remote_to_local);
1118
1119                    if buffer.buffer.len() < 5 {
1120                        break
1121                    }
1122
1123                    let mut pkt = self.maybe_decompress(&buffer)?;
1124                    if !pkt.buffer.is_empty() {
1125                        #[allow(clippy::indexing_slicing)] // length checked
1126                        if pkt.buffer[0] == crate::msg::DISCONNECT {
1127                            debug!("received disconnect");
1128                            result = self.process_disconnect(&pkt).map_err(H::Error::from);
1129                        } else {
1130                            self.common.received_data = true;
1131                            reply(self, handler, kex_done_signal, &mut pkt).await?;
1132                            buffer.seqn = pkt.seqn; // TODO reply changes seqn internall, find cleaner way
1133                        }
1134                    }
1135
1136                    std::mem::swap(&mut opening_cipher, &mut self.common.remote_to_local);
1137                    reading.set(start_reading(stream_read, buffer, opening_cipher));
1138                }
1139                () = &mut keepalive_timer => {
1140                    self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1);
1141                    if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max {
1142                        debug!("Timeout, server not responding to keepalives");
1143                        return Err(crate::Error::KeepaliveTimeout.into());
1144                    }
1145                    sent_keepalive = true;
1146                    self.send_keepalive(true)?;
1147                }
1148                () = &mut inactivity_timer => {
1149                    debug!("timeout");
1150                    return Err(crate::Error::InactivityTimeout.into());
1151                }
1152                msg = self.receiver.recv(), if !self.kex.active() => {
1153                    match msg {
1154                        Some(msg) => self.handle_msg(msg)?,
1155                        None => {
1156                            self.common.disconnected = true;
1157                            break
1158                        }
1159                    };
1160
1161                    // eagerly take all outgoing messages so writes are batched
1162                    while !self.kex.active() {
1163                        match self.receiver.try_recv() {
1164                            Ok(next) => self.handle_msg(next)?,
1165                            Err(_) => break
1166                        }
1167                    }
1168                }
1169                msg = self.inbound_channel_receiver.recv(), if !self.kex.active() => {
1170                    match msg {
1171                        Some(msg) => self.handle_msg(msg)?,
1172                        None => (),
1173                    }
1174
1175                    // eagerly take all outgoing messages so writes are batched
1176                    while !self.kex.active() {
1177                        match self.inbound_channel_receiver.try_recv() {
1178                            Ok(next) => self.handle_msg(next)?,
1179                            Err(_) => break
1180                        }
1181                    }
1182                }
1183            };
1184
1185            self.flush()?;
1186            map_err!(self.common.packet_writer.flush_into(stream_write).await)?;
1187
1188            if let Some(ref mut enc) = self.common.encrypted {
1189                if let EncryptedState::InitCompression = enc.state {
1190                    enc.client_compression
1191                        .init_compress(self.common.packet_writer.compress());
1192                    enc.state = EncryptedState::Authenticated;
1193                }
1194            }
1195
1196            if self.common.received_data {
1197                // Reset the number of failed keepalive attempts. We don't
1198                // bother detecting keepalive response messages specifically
1199                // (OpenSSH_9.6p1 responds with REQUEST_FAILURE aka 82). Instead
1200                // we assume that the server is still alive if we receive any
1201                // data from it.
1202                self.common.alive_timeouts = 0;
1203            }
1204            if self.common.received_data || sent_keepalive {
1205                if let (futures::future::Either::Right(ref mut sleep), Some(d)) = (
1206                    keepalive_timer.as_mut().as_pin_mut(),
1207                    self.common.config.keepalive_interval,
1208                ) {
1209                    sleep.as_mut().reset(tokio::time::Instant::now() + d);
1210                }
1211            }
1212            if !sent_keepalive {
1213                if let (futures::future::Either::Right(ref mut sleep), Some(d)) = (
1214                    inactivity_timer.as_mut().as_pin_mut(),
1215                    self.common.config.inactivity_timeout,
1216                ) {
1217                    sleep.as_mut().reset(tokio::time::Instant::now() + d);
1218                }
1219            }
1220        }
1221
1222        result
1223    }
1224
1225    fn process_disconnect(
1226        &mut self,
1227        pkt: &IncomingSshPacket,
1228    ) -> Result<RemoteDisconnectInfo, Error> {
1229        let mut r = &pkt.buffer[..];
1230        u8::decode(&mut r)?; // skip message type
1231        self.common.disconnected = true;
1232
1233        let reason_code = u32::decode(&mut r)?.try_into()?;
1234        let message = String::decode(&mut r)?;
1235        let lang_tag = String::decode(&mut r)?;
1236
1237        Ok(RemoteDisconnectInfo {
1238            reason_code,
1239            message,
1240            lang_tag,
1241        })
1242    }
1243
1244    fn handle_msg(&mut self, msg: Msg) -> Result<(), crate::Error> {
1245        match msg {
1246            Msg::Authenticate { user, method } => {
1247                self.write_auth_request_if_needed(&user, method)?;
1248            }
1249            Msg::Signed { .. } => {}
1250            Msg::AuthInfoResponse { .. } => {}
1251            Msg::ChannelOpenSession { channel_ref } => {
1252                let id = self.channel_open_session()?;
1253                self.channels.insert(id, channel_ref);
1254            }
1255            Msg::ChannelOpenX11 {
1256                originator_address,
1257                originator_port,
1258                channel_ref,
1259            } => {
1260                let id = self.channel_open_x11(&originator_address, originator_port)?;
1261                self.channels.insert(id, channel_ref);
1262            }
1263            Msg::ChannelOpenDirectTcpIp {
1264                host_to_connect,
1265                port_to_connect,
1266                originator_address,
1267                originator_port,
1268                channel_ref,
1269            } => {
1270                let id = self.channel_open_direct_tcpip(
1271                    &host_to_connect,
1272                    port_to_connect,
1273                    &originator_address,
1274                    originator_port,
1275                )?;
1276                self.channels.insert(id, channel_ref);
1277            }
1278            Msg::ChannelOpenDirectStreamLocal {
1279                socket_path,
1280                channel_ref,
1281            } => {
1282                let id = self.channel_open_direct_streamlocal(&socket_path)?;
1283                self.channels.insert(id, channel_ref);
1284            }
1285            Msg::TcpIpForward {
1286                reply_channel,
1287                address,
1288                port,
1289            } => self.tcpip_forward(reply_channel, &address, port)?,
1290            Msg::CancelTcpIpForward {
1291                reply_channel,
1292                address,
1293                port,
1294            } => self.cancel_tcpip_forward(reply_channel, &address, port)?,
1295            Msg::StreamLocalForward {
1296                reply_channel,
1297                socket_path,
1298            } => self.streamlocal_forward(reply_channel, &socket_path)?,
1299            Msg::CancelStreamLocalForward {
1300                reply_channel,
1301                socket_path,
1302            } => self.cancel_streamlocal_forward(reply_channel, &socket_path)?,
1303            Msg::Disconnect {
1304                reason,
1305                description,
1306                language_tag,
1307            } => self.disconnect(reason, &description, &language_tag)?,
1308            Msg::Channel(id, ChannelMsg::Data { data }) => self.data(id, data)?,
1309            Msg::Channel(id, ChannelMsg::Eof) => {
1310                self.eof(id)?;
1311            }
1312            Msg::Channel(id, ChannelMsg::ExtendedData { data, ext }) => {
1313                self.extended_data(id, ext, data)?;
1314            }
1315            Msg::Channel(
1316                id,
1317                ChannelMsg::RequestPty {
1318                    want_reply,
1319                    term,
1320                    col_width,
1321                    row_height,
1322                    pix_width,
1323                    pix_height,
1324                    terminal_modes,
1325                },
1326            ) => self.request_pty(
1327                id,
1328                want_reply,
1329                &term,
1330                col_width,
1331                row_height,
1332                pix_width,
1333                pix_height,
1334                &terminal_modes,
1335            )?,
1336            Msg::Channel(
1337                id,
1338                ChannelMsg::WindowChange {
1339                    col_width,
1340                    row_height,
1341                    pix_width,
1342                    pix_height,
1343                },
1344            ) => self.window_change(id, col_width, row_height, pix_width, pix_height)?,
1345            Msg::Channel(
1346                id,
1347                ChannelMsg::RequestX11 {
1348                    want_reply,
1349                    single_connection,
1350                    x11_authentication_protocol,
1351                    x11_authentication_cookie,
1352                    x11_screen_number,
1353                },
1354            ) => self.request_x11(
1355                id,
1356                want_reply,
1357                single_connection,
1358                &x11_authentication_protocol,
1359                &x11_authentication_cookie,
1360                x11_screen_number,
1361            )?,
1362            Msg::Channel(
1363                id,
1364                ChannelMsg::SetEnv {
1365                    want_reply,
1366                    variable_name,
1367                    variable_value,
1368                },
1369            ) => self.set_env(id, want_reply, &variable_name, &variable_value)?,
1370            Msg::Channel(id, ChannelMsg::RequestShell { want_reply }) => {
1371                self.request_shell(want_reply, id)?
1372            }
1373            Msg::Channel(
1374                id,
1375                ChannelMsg::Exec {
1376                    want_reply,
1377                    command,
1378                },
1379            ) => self.exec(id, want_reply, &command)?,
1380            Msg::Channel(id, ChannelMsg::Signal { signal }) => self.signal(id, signal)?,
1381            Msg::Channel(id, ChannelMsg::RequestSubsystem { want_reply, name }) => {
1382                self.request_subsystem(want_reply, id, &name)?
1383            }
1384            Msg::Channel(id, ChannelMsg::AgentForward { want_reply }) => {
1385                self.agent_forward(id, want_reply)?
1386            }
1387            Msg::Channel(id, ChannelMsg::Close) => self.close(id)?,
1388            Msg::Rekey => self.initiate_rekey()?,
1389            Msg::AwaitExtensionInfo {
1390                extension_name,
1391                reply_channel,
1392            } => {
1393                if let Some(ref mut enc) = self.common.encrypted {
1394                    // Drop if the extension has been seen already
1395                    if !enc.received_extensions.contains(&extension_name) {
1396                        // There will be no new extension info after authentication
1397                        // has succeeded
1398                        if !matches!(enc.state, EncryptedState::Authenticated) {
1399                            enc.extension_info_awaiters
1400                                .entry(extension_name)
1401                                .or_insert(vec![])
1402                                .push(reply_channel);
1403                        }
1404                    }
1405                }
1406            }
1407            Msg::GetServerSigAlgs { reply_channel } => {
1408                let _ = reply_channel.send(self.server_sig_algs.clone());
1409            }
1410            Msg::Keepalive { want_reply } => {
1411                let _ = self.send_keepalive(want_reply);
1412            }
1413            Msg::Ping { reply_channel } => {
1414                let _ = self.send_ping(reply_channel);
1415            }
1416            Msg::NoMoreSessions { want_reply } => {
1417                let _ = self.no_more_sessions(want_reply);
1418            }
1419            msg => {
1420                // should be unreachable, since the receiver only gets
1421                // messages from methods implemented within russh
1422                unimplemented!("unimplemented (server-only?) message: {:?}", msg)
1423            }
1424        }
1425        Ok(())
1426    }
1427
1428    fn begin_rekey(&mut self) -> Result<(), crate::Error> {
1429        debug!("beginning re-key");
1430        let mut kex = ClientKex::new(
1431            self.common.config.clone(),
1432            &self.common.config.client_id,
1433            &self.common.remote_sshid,
1434            match &self.common.encrypted {
1435                None => KexCause::Initial,
1436                Some(enc) => KexCause::Rekey {
1437                    strict: self.common.strict_kex,
1438                    session_id: enc.session_id.clone(),
1439                },
1440            },
1441        );
1442
1443        kex.kexinit(&mut self.common.packet_writer)?;
1444        self.kex = SessionKexState::InProgress(kex);
1445        Ok(())
1446    }
1447
1448    /// Flush the temporary cleartext buffer into the encryption
1449    /// buffer. This does *not* flush to the socket.
1450    fn flush(&mut self) -> Result<(), crate::Error> {
1451        if let Some(ref mut enc) = self.common.encrypted {
1452            if enc.flush(
1453                &self.common.config.as_ref().limits,
1454                &mut self.common.packet_writer,
1455            )? && !self.kex.active()
1456            {
1457                self.begin_rekey()?;
1458            }
1459        }
1460        Ok(())
1461    }
1462
1463    /// Immediately trigger a session re-key after flushing all pending packets
1464    pub fn initiate_rekey(&mut self) -> Result<(), Error> {
1465        if let Some(ref mut enc) = self.common.encrypted {
1466            enc.rekey_wanted = true;
1467            self.flush()?
1468        }
1469        Ok(())
1470    }
1471}
1472
1473async fn reply<H: Handler>(
1474    session: &mut Session,
1475    handler: &mut H,
1476    kex_done_signal: &mut Option<tokio::sync::oneshot::Sender<()>>,
1477    pkt: &mut IncomingSshPacket,
1478) -> Result<(), H::Error> {
1479    if let Some(message_type) = pkt.buffer.first() {
1480        debug!(
1481            "< msg type {message_type:?}, seqn {:?}, len {}",
1482            pkt.seqn.0,
1483            pkt.buffer.len()
1484        );
1485        if session.common.strict_kex && session.common.encrypted.is_none() {
1486            let seqno = pkt.seqn.0 - 1; // was incremented after read()
1487            validate_server_msg_strict_kex(*message_type, seqno as usize)?;
1488        }
1489
1490        if [msg::IGNORE, msg::UNIMPLEMENTED, msg::DEBUG].contains(message_type) {
1491            return Ok(());
1492        }
1493    }
1494
1495    if pkt.buffer.first() == Some(&msg::KEXINIT) && session.kex == SessionKexState::Idle {
1496        // Not currently in a rekey but received KEXINIT
1497        debug!("server has initiated re-key");
1498        session.begin_rekey()?;
1499        // Kex will consume the packet right away
1500    }
1501
1502    let is_kex_msg = pkt.buffer.first().cloned().map(is_kex_msg).unwrap_or(false);
1503
1504    if is_kex_msg {
1505        if let SessionKexState::InProgress(kex) = session.kex.take() {
1506            let progress = kex.step(Some(pkt), &mut session.common.packet_writer)?;
1507
1508            match progress {
1509                KexProgress::NeedsReply { kex, reset_seqn } => {
1510                    debug!("kex impl continues: {kex:?}");
1511                    session.kex = SessionKexState::InProgress(kex);
1512                    if reset_seqn {
1513                        debug!("kex impl requests seqno reset");
1514                        session.common.reset_seqn();
1515                    }
1516                }
1517                KexProgress::Done {
1518                    server_host_key,
1519                    newkeys,
1520                } => {
1521                    debug!("kex impl has completed");
1522                    session.common.strict_kex =
1523                        session.common.strict_kex || newkeys.names.strict_kex();
1524
1525                    if let Some(ref mut enc) = session.common.encrypted {
1526                        // This is a rekey
1527                        enc.last_rekey = Instant::now();
1528                        session.common.packet_writer.buffer().bytes = 0;
1529                        enc.flush_all_pending()?;
1530                        let mut pending = std::mem::take(&mut session.pending_reads);
1531                        for p in pending.drain(..) {
1532                            session.process_packet(handler, &p).await?;
1533                        }
1534                        session.pending_reads = pending;
1535                        session.pending_len = 0;
1536                        session.common.newkeys(newkeys);
1537                    } else {
1538                        // This is the initial kex
1539                        if let Some(server_host_key) = &server_host_key {
1540                            let check = handler.check_server_key(server_host_key).await?;
1541                            if !check {
1542                                return Err(crate::Error::UnknownKey.into());
1543                            }
1544                        }
1545
1546                        session
1547                            .common
1548                            .encrypted(initial_encrypted_state(session), newkeys);
1549
1550                        if let Some(sender) = kex_done_signal.take() {
1551                            sender.send(()).unwrap_or(());
1552                        }
1553                    }
1554
1555                    session.kex = SessionKexState::Idle;
1556
1557                    if session.common.strict_kex {
1558                        pkt.seqn = Wrapping(0);
1559                    }
1560
1561                    debug!("kex done");
1562                }
1563            }
1564
1565            session.flush()?;
1566
1567            return Ok(());
1568        }
1569    }
1570
1571    session.client_read_encrypted(handler, pkt).await
1572}
1573
1574fn initial_encrypted_state(session: &Session) -> EncryptedState {
1575    if session.common.config.anonymous {
1576        EncryptedState::Authenticated
1577    } else {
1578        EncryptedState::WaitingAuthServiceRequest {
1579            accepted: false,
1580            sent: false,
1581        }
1582    }
1583}
1584
/// Parameters for dynamic group Diffie-Hellman key exchanges.
///
/// Build one with [`GexParams::new`], which validates the sizes, or use the
/// [`Default`] implementation (3072 / 8192 / 8192 bits). Fields are private;
/// read them through the getter methods of the same name.
#[derive(Debug, Clone)]
pub struct GexParams {
    /// Minimum DH group size (in bits); `new` rejects values below 2048
    min_group_size: usize,
    /// Preferred DH group size (in bits); must be at least `min_group_size`
    preferred_group_size: usize,
    /// Maximum DH group size (in bits); must be at least `preferred_group_size`
    max_group_size: usize,
}
1595
1596impl GexParams {
1597    pub fn new(
1598        min_group_size: usize,
1599        preferred_group_size: usize,
1600        max_group_size: usize,
1601    ) -> Result<Self, Error> {
1602        let this = Self {
1603            min_group_size,
1604            preferred_group_size,
1605            max_group_size,
1606        };
1607        this.validate()?;
1608        Ok(this)
1609    }
1610
1611    pub(crate) fn validate(&self) -> Result<(), Error> {
1612        if self.min_group_size < 2048 {
1613            return Err(Error::InvalidConfig(
1614                "min_group_size must be at least 2048 bits".into(),
1615            ));
1616        }
1617        if self.preferred_group_size < self.min_group_size {
1618            return Err(Error::InvalidConfig(
1619                "preferred_group_size must be at least as large as min_group_size".into(),
1620            ));
1621        }
1622        if self.max_group_size < self.preferred_group_size {
1623            return Err(Error::InvalidConfig(
1624                "max_group_size must be at least as large as preferred_group_size".into(),
1625            ));
1626        }
1627        Ok(())
1628    }
1629
1630    pub fn min_group_size(&self) -> usize {
1631        self.min_group_size
1632    }
1633
1634    pub fn preferred_group_size(&self) -> usize {
1635        self.preferred_group_size
1636    }
1637
1638    pub fn max_group_size(&self) -> usize {
1639        self.max_group_size
1640    }
1641}
1642
1643impl Default for GexParams {
1644    fn default() -> GexParams {
1645        GexParams {
1646            min_group_size: 3072,
1647            preferred_group_size: 8192,
1648            max_group_size: 8192,
1649        }
1650    }
1651}
1652
1653/// The configuration of clients.
1654#[derive(Debug)]
1655pub struct Config {
1656    /// The client ID string sent at the beginning of the protocol.
1657    pub client_id: SshId,
1658    /// The bytes and time limits before key re-exchange.
1659    pub limits: Limits,
1660    /// The initial size of a channel (used for flow control).
1661    pub window_size: u32,
1662    /// The maximal size of a single packet.
1663    pub maximum_packet_size: u32,
1664    /// Buffer size for each channel (a number of unprocessed messages to store before propagating backpressure to the TCP stream)
1665    pub channel_buffer_size: usize,
1666    /// Lists of preferred algorithms.
1667    pub preferred: negotiation::Preferred,
1668    /// Time after which the connection is garbage-collected.
1669    pub inactivity_timeout: Option<std::time::Duration>,
1670    /// If nothing is received from the server for this amount of time, send a keepalive message.
1671    pub keepalive_interval: Option<std::time::Duration>,
1672    /// If this many keepalives have been sent without reply, close the connection.
1673    pub keepalive_max: usize,
1674    /// Whether to expect and wait for an authentication call.
1675    pub anonymous: bool,
1676    /// DH dynamic group exchange parameters.
1677    pub gex: GexParams,
1678    /// If active, invoke `set_nodelay(true)` on the ssh socket; disabled by default (i.e. Nagle's algorithm is active).
1679    pub nodelay: bool,
1680}
1681
1682impl Default for Config {
1683    fn default() -> Config {
1684        Config {
1685            client_id: SshId::Standard(format!(
1686                "SSH-2.0-{}_{}",
1687                env!("CARGO_PKG_NAME"),
1688                env!("CARGO_PKG_VERSION")
1689            )),
1690            limits: Limits::default(),
1691            window_size: 2097152,
1692            maximum_packet_size: 32768,
1693            channel_buffer_size: 100,
1694            preferred: Default::default(),
1695            inactivity_timeout: None,
1696            keepalive_interval: None,
1697            keepalive_max: 3,
1698            anonymous: false,
1699            gex: Default::default(),
1700            nodelay: false,
1701        }
1702    }
1703}
1704
1705/// A client handler. Note that messages can be received from the
1706/// server at any time during a session.
1707///
1708/// You must at the very least implement the `check_server_key` fn.
1709/// The default implementation rejects all keys.
1710///
1711/// Note: this is an async trait. The trait functions return `impl Future`,
1712/// and you can simply define them as `async fn` instead.
1713#[cfg_attr(feature = "async-trait", async_trait::async_trait)]
1714pub trait Handler: Sized + Send {
1715    type Error: From<crate::Error> + Send + core::fmt::Debug;
1716
1717    /// Called when the server sends us an authentication banner. This
1718    /// is usually meant to be shown to the user, see
1719    /// [RFC4252](https://tools.ietf.org/html/rfc4252#section-5.4) for
1720    /// more details.
1721    #[allow(unused_variables)]
1722    fn auth_banner(
1723        &mut self,
1724        banner: &str,
1725        session: &mut Session,
1726    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1727        async { Ok(()) }
1728    }
1729
1730    /// Called to check the server's public key. This is a very important
1731    /// step to help prevent man-in-the-middle attacks. The default
1732    /// implementation rejects all keys.
1733    #[allow(unused_variables)]
1734    fn check_server_key(
1735        &mut self,
1736        server_public_key: &ssh_key::PublicKey,
1737    ) -> impl Future<Output = Result<bool, Self::Error>> + Send {
1738        async { Ok(false) }
1739    }
1740
1741    /// Called when the server confirmed our request to open a
1742    /// channel. A channel can only be written to after receiving this
1743    /// message (this library panics otherwise).
1744    #[allow(unused_variables)]
1745    fn channel_open_confirmation(
1746        &mut self,
1747        id: ChannelId,
1748        max_packet_size: u32,
1749        window_size: u32,
1750        session: &mut Session,
1751    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1752        async { Ok(()) }
1753    }
1754
1755    /// Called when the server signals success.
1756    #[allow(unused_variables)]
1757    fn channel_success(
1758        &mut self,
1759        channel: ChannelId,
1760        session: &mut Session,
1761    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1762        async { Ok(()) }
1763    }
1764
1765    /// Called when the server signals failure.
1766    #[allow(unused_variables)]
1767    fn channel_failure(
1768        &mut self,
1769        channel: ChannelId,
1770        session: &mut Session,
1771    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1772        async { Ok(()) }
1773    }
1774
1775    /// Called when the server closes a channel.
1776    #[allow(unused_variables)]
1777    fn channel_close(
1778        &mut self,
1779        channel: ChannelId,
1780        session: &mut Session,
1781    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1782        async { Ok(()) }
1783    }
1784
1785    /// Called when the server sends EOF to a channel.
1786    #[allow(unused_variables)]
1787    fn channel_eof(
1788        &mut self,
1789        channel: ChannelId,
1790        session: &mut Session,
1791    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1792        async { Ok(()) }
1793    }
1794
1795    /// Called when the server rejected our request to open a channel.
1796    #[allow(unused_variables)]
1797    fn channel_open_failure(
1798        &mut self,
1799        channel: ChannelId,
1800        reason: ChannelOpenFailure,
1801        description: &str,
1802        language: &str,
1803        session: &mut Session,
1804    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1805        async { Ok(()) }
1806    }
1807
1808    /// Called when the server opens a channel for a new remote port forwarding connection
1809    #[allow(unused_variables)]
1810    fn server_channel_open_forwarded_tcpip(
1811        &mut self,
1812        channel: Channel<Msg>,
1813        connected_address: &str,
1814        connected_port: u32,
1815        originator_address: &str,
1816        originator_port: u32,
1817        session: &mut Session,
1818    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1819        async { Ok(()) }
1820    }
1821
1822    // Called when the server opens a channel for a new remote UDS forwarding connection
1823    #[allow(unused_variables)]
1824    fn server_channel_open_forwarded_streamlocal(
1825        &mut self,
1826        channel: Channel<Msg>,
1827        socket_path: &str,
1828        session: &mut Session,
1829    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1830        async { Ok(()) }
1831    }
1832
1833    /// Called when the server opens an agent forwarding channel
1834    #[allow(unused_variables)]
1835    fn server_channel_open_agent_forward(
1836        &mut self,
1837        channel: Channel<Msg>,
1838        session: &mut Session,
1839    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1840        async { Ok(()) }
1841    }
1842
1843    /// Called when the server attempts to open a channel of unknown type. It may return `true`,
1844    /// if the channel of unknown type should be accepted. In this case,
1845    /// [Handler::server_channel_open_unknown] will be called soon after. If it returns `false`,
1846    /// the channel will not be created and a rejection message will be sent to the server.
1847    #[allow(unused_variables)]
1848    fn should_accept_unknown_server_channel(
1849        &mut self,
1850        id: ChannelId,
1851        channel_type: &str,
1852    ) -> impl Future<Output = bool> + Send {
1853        async { false }
1854    }
1855
1856    /// Called when the server opens an unknown channel.
1857    #[allow(unused_variables)]
1858    fn server_channel_open_unknown(
1859        &mut self,
1860        channel: Channel<Msg>,
1861        session: &mut Session,
1862    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1863        async { Ok(()) }
1864    }
1865
1866    /// Called when the server opens a session channel.
1867    #[allow(unused_variables)]
1868    fn server_channel_open_session(
1869        &mut self,
1870        channel: Channel<Msg>,
1871        session: &mut Session,
1872    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1873        async { Ok(()) }
1874    }
1875
1876    /// Called when the server opens a direct tcp/ip channel (non-standard).
1877    #[allow(unused_variables)]
1878    fn server_channel_open_direct_tcpip(
1879        &mut self,
1880        channel: Channel<Msg>,
1881        host_to_connect: &str,
1882        port_to_connect: u32,
1883        originator_address: &str,
1884        originator_port: u32,
1885        session: &mut Session,
1886    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1887        async { Ok(()) }
1888    }
1889
1890    /// Called when the server opens a direct-streamlocal channel (non-standard).
1891    #[allow(unused_variables)]
1892    fn server_channel_open_direct_streamlocal(
1893        &mut self,
1894        channel: Channel<Msg>,
1895        socket_path: &str,
1896        session: &mut Session,
1897    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1898        async { Ok(()) }
1899    }
1900
1901    /// Called when the server opens an X11 channel.
1902    #[allow(unused_variables)]
1903    fn server_channel_open_x11(
1904        &mut self,
1905        channel: Channel<Msg>,
1906        originator_address: &str,
1907        originator_port: u32,
1908        session: &mut Session,
1909    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1910        async { Ok(()) }
1911    }
1912
1913    /// Called when the server sends us data. The `extended_code`
1914    /// parameter is a stream identifier, `None` is usually the
1915    /// standard output, and `Some(1)` is the standard error. See
1916    /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-5.2).
1917    #[allow(unused_variables)]
1918    fn data(
1919        &mut self,
1920        channel: ChannelId,
1921        data: &[u8],
1922        session: &mut Session,
1923    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1924        async { Ok(()) }
1925    }
1926
1927    /// Called when the server sends us data. The `extended_code`
1928    /// parameter is a stream identifier, `None` is usually the
1929    /// standard output, and `Some(1)` is the standard error. See
1930    /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-5.2).
1931    #[allow(unused_variables)]
1932    fn extended_data(
1933        &mut self,
1934        channel: ChannelId,
1935        ext: u32,
1936        data: &[u8],
1937        session: &mut Session,
1938    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1939        async { Ok(()) }
1940    }
1941
1942    /// The server informs this client of whether the client may
1943    /// perform control-S/control-Q flow control. See
1944    /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-6.8).
1945    #[allow(unused_variables)]
1946    fn xon_xoff(
1947        &mut self,
1948        channel: ChannelId,
1949        client_can_do: bool,
1950        session: &mut Session,
1951    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1952        async { Ok(()) }
1953    }
1954
1955    /// The remote process has exited, with the given exit status.
1956    #[allow(unused_variables)]
1957    fn exit_status(
1958        &mut self,
1959        channel: ChannelId,
1960        exit_status: u32,
1961        session: &mut Session,
1962    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1963        async { Ok(()) }
1964    }
1965
1966    /// The remote process exited upon receiving a signal.
1967    #[allow(unused_variables)]
1968    fn exit_signal(
1969        &mut self,
1970        channel: ChannelId,
1971        signal_name: Sig,
1972        core_dumped: bool,
1973        error_message: &str,
1974        lang_tag: &str,
1975        session: &mut Session,
1976    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1977        async { Ok(()) }
1978    }
1979
1980    /// Called when the network window is adjusted, meaning that we
1981    /// can send more bytes. This is useful if this client wants to
1982    /// send huge amounts of data, for instance if we have called
1983    /// `Session::data` before, and it returned less than the
1984    /// full amount of data.
1985    #[allow(unused_variables)]
1986    fn window_adjusted(
1987        &mut self,
1988        channel: ChannelId,
1989        new_size: u32,
1990        session: &mut Session,
1991    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
1992        async { Ok(()) }
1993    }
1994
1995    /// Called when this client adjusts the network window. Return the
1996    /// next target window and maximum packet size.
1997    #[allow(unused_variables)]
1998    fn adjust_window(&mut self, channel: ChannelId, window: u32) -> u32 {
1999        window
2000    }
2001
2002    /// Called when the server signals success.
2003    #[allow(unused_variables)]
2004    fn openssh_ext_host_keys_announced(
2005        &mut self,
2006        keys: Vec<PublicKey>,
2007        session: &mut Session,
2008    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
2009        async move {
2010            debug!("openssh_ext_hostkeys_announced: {:?}", keys);
2011            Ok(())
2012        }
2013    }
2014
2015    /// Called when the server sent a disconnect message
2016    ///
2017    /// If reason is an Error, this function should re-return the error so the join can also evaluate it
2018    #[allow(unused_variables)]
2019    fn disconnected(
2020        &mut self,
2021        reason: DisconnectReason<Self::Error>,
2022    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
2023        async {
2024            debug!("disconnected: {:?}", reason);
2025            match reason {
2026                DisconnectReason::ReceivedDisconnect(_) => Ok(()),
2027                DisconnectReason::Error(e) => Err(e),
2028            }
2029        }
2030    }
2031}