use anyhow::Result;
use encoding_rs::{CoderResult, UTF_8};
use sshx_core::proto::{client_update::ClientMessage, TerminalData};
use sshx_core::Sid;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::mpsc,
};
use crate::encrypt::Encrypt;
use crate::terminal::Terminal;
/// Maximum number of bytes of terminal output sent in a single chunk (64 KiB).
const CONTENT_CHUNK_SIZE: usize = 1 << 16;

/// Number of bytes of already-sent output kept in memory behind the send position
/// when the content buffer is pruned (8 MiB).
const CONTENT_ROLLING_BYTES: usize = 8 << 20;

/// Content buffer length above which pruning of old output is considered (12 MiB).
const CONTENT_PRUNE_BYTES: usize = 12 << 20;

/// Variants of terminal behavior that can be driven by [`Runner::run`].
#[derive(Debug, Clone)]
pub enum Runner {
    /// Runs a real terminal; the string is handed to `Terminal::new` as the shell to start.
    Shell(String),

    /// Sends every piece of input data straight back as output, without a terminal.
    Echo,
}
/// Messages delivered to a running shell task over its input channel.
pub enum ShellData {
    /// Raw input bytes: written to the terminal by the shell runner, echoed by the echo runner.
    Data(Vec<u8>),

    /// A sequence number (byte offset) reported by the peer. The shell runner compares it
    /// with the offset it has sent so far and may rewind to resend output.
    /// NOTE(review): presumably this is the server's acknowledged offset; confirm with the sender.
    Sync(u64),

    /// New terminal size as `(rows, cols)`.
    Size(u32, u32),
}
impl Runner {
pub async fn run(
&self,
id: Sid,
encrypt: Encrypt,
shell_rx: mpsc::Receiver<ShellData>,
output_tx: mpsc::Sender<ClientMessage>,
) -> Result<()> {
match self {
Self::Shell(shell) => shell_task(id, encrypt, shell, shell_rx, output_tx).await,
Self::Echo => echo_task(id, encrypt, shell_rx, output_tx).await,
}
}
}
/// Asynchronous task handling a single shell within the session.
///
/// Starts a [`Terminal`] running `shell`, forwards input arriving on `shell_rx` to it, and
/// streams the terminal's UTF-8 decoded output to `output_tx` as encrypted [`TerminalData`]
/// chunks of at most [`CONTENT_CHUNK_SIZE`] bytes. The task ends when the terminal reaches
/// end-of-file or `shell_rx` is closed.
///
/// # Errors
///
/// Returns an error if the terminal cannot be created, resized, read from or written to,
/// or if sending on `output_tx` fails.
async fn shell_task(
    id: Sid,
    encrypt: Encrypt,
    shell: &str,
    mut shell_rx: mpsc::Receiver<ShellData>,
    output_tx: mpsc::Sender<ClientMessage>,
) -> Result<()> {
    let mut term = Terminal::new(shell).await?;
    term.set_winsize(24, 80)?;

    let mut content = String::new(); // decoded terminal output that has not been pruned
    let mut content_offset: usize = 0; // bytes already pruned from the front of `content`
    let mut decoder = UTF_8.new_decoder(); // streaming UTF-8 decoder
    let mut seq: usize = 0; // absolute byte offset up to which output has been sent
    let mut seq_outdated = 0; // syncs reporting an offset behind `seq` since the last send
    let mut buf = [0u8; 4096]; // scratch buffer for terminal reads
    let mut finished = false; // set once the terminal or the input channel is done

    while !finished {
        tokio::select! {
            result = term.read(&mut buf) => {
                let n = result?;
                if n == 0 {
                    finished = true;
                } else {
                    content.reserve(decoder.max_utf8_buffer_length(n).unwrap());
                    let (result, _, _) = decoder.decode_to_string(&buf[..n], &mut content, false);
                    debug_assert!(result == CoderResult::InputEmpty);
                }
            }
            item = shell_rx.recv() => {
                match item {
                    Some(ShellData::Data(data)) => {
                        term.write_all(&data).await?;
                    }
                    Some(ShellData::Sync(seq2)) => {
                        // Only rewind after several reports of a lagging offset.
                        if seq2 < seq as u64 {
                            seq_outdated += 1;
                            if seq_outdated >= 3 {
                                seq = seq2 as usize;
                            }
                        }
                    }
                    Some(ShellData::Size(rows, cols)) => {
                        term.set_winsize(rows as u16, cols as u16)?;
                    }
                    None => finished = true, // the input channel was closed
                }
            }
        }

        if finished {
            // Flush the decoder so that any trailing partial sequence is emitted.
            content.reserve(decoder.max_utf8_buffer_length(0).unwrap());
            let (result, _, _) = decoder.decode_to_string(&[], &mut content, true);
            debug_assert!(result == CoderResult::InputEmpty);
        }

        // Send the next chunk if there is output past the sent position.
        if content_offset + content.len() > seq {
            // A rewind may point before the pruned prefix. That data is gone, so resume from
            // the oldest retained byte instead of underflowing the subtraction.
            let start = prev_char_boundary(&content, seq.saturating_sub(content_offset));
            let end = prev_char_boundary(&content, (start + CONTENT_CHUNK_SIZE).min(content.len()));
            let data = encrypt.segment(
                0x100000000 | id.0 as u64, // stream number
                (content_offset + start) as u64,
                content[start..end].as_bytes(),
            );
            let data = TerminalData {
                id: id.0,
                data: data.into(),
                seq: (content_offset + start) as u64,
            };
            output_tx.send(ClientMessage::Data(data)).await?;
            seq = content_offset + end;
            seq_outdated = 0;
        }

        // Drop old output, keeping `CONTENT_ROLLING_BYTES` behind the sent position. The
        // saturating subtraction keeps this a no-op while `seq` is below that window.
        let keep_from = seq.saturating_sub(CONTENT_ROLLING_BYTES);
        if content.len() > CONTENT_PRUNE_BYTES && keep_from > content_offset {
            let pruned = prev_char_boundary(&content, keep_from - content_offset);
            content_offset += pruned;
            content.drain(..pruned);
        }
    }
    Ok(())
}
/// Returns the largest char boundary of `s` that is less than or equal to `i`.
///
/// Indices past the end of the string yield `s.len()`. The search starts at
/// `min(i, s.len())`, so it inspects at most four positions regardless of how large `i` is.
fn prev_char_boundary(s: &str, i: usize) -> usize {
    // Nothing beyond `s.len()` is a char boundary, so starting there skips a pointless scan.
    (0..=i.min(s.len()))
        .rev()
        .find(|&j| s.is_char_boundary(j))
        // Index 0 is always a char boundary, so the search cannot come up empty.
        .expect("no previous char boundary")
}
/// Asynchronous task for the echo runner: every input payload is sent back as output.
///
/// Input bytes are converted to text with `String::from_utf8_lossy`, encrypted, and sent on
/// `output_tx` tagged with the running byte offset of everything echoed so far. `Sync` and
/// `Size` messages are ignored. The task ends when `shell_rx` is closed.
///
/// # Errors
///
/// Returns an error if sending on `output_tx` fails.
async fn echo_task(
    id: Sid,
    encrypt: Encrypt,
    mut shell_rx: mpsc::Receiver<ShellData>,
    output_tx: mpsc::Sender<ClientMessage>,
) -> Result<()> {
    let stream_num = 0x100000000 | id.0 as u64;
    let mut offset: u64 = 0; // total bytes echoed so far

    loop {
        let bytes = match shell_rx.recv().await {
            Some(ShellData::Data(bytes)) => bytes,
            Some(ShellData::Sync(_)) | Some(ShellData::Size(_, _)) => continue,
            None => break,
        };

        let text = String::from_utf8_lossy(&bytes);
        let payload = encrypt.segment(stream_num, offset, text.as_bytes());
        let packet = TerminalData {
            id: id.0,
            data: payload.into(),
            seq: offset,
        };
        output_tx.send(ClientMessage::Data(packet)).await?;
        offset += text.len() as u64;
    }

    Ok(())
}