socks5_protocol/
sync.rs

1use crate::{
2    Address, AuthRequest, AuthResponse, Command, CommandReply, CommandRequest, CommandResponse,
3    Error, Result, Version,
4};
5use std::{io, net::SocketAddr};
6
/// Encoding and decoding of a protocol message over synchronous `std::io` streams.
///
/// Despite the name, the trait covers both directions: `read_from` decodes a value from an
/// [`io::Read`] and `write_to` encodes one onto an [`io::Write`].
pub trait FromIO {
    /// Decodes one `Self` from `reader`.
    ///
    /// # Errors
    /// Returns the crate `Error` for I/O failures (including a stream that ends early) and for
    /// bytes that do not form a valid message. Bytes read before the problem was detected are
    /// not pushed back, so on error the reader may be left in the middle of a message.
    fn read_from(reader: &mut impl io::Read) -> Result<Self>
    where
        // `Result<Self>` can only hold a sized `Self`.
        Self: Sized;

    /// Encodes `self` onto `writer`.
    ///
    /// # Errors
    /// Returns the crate `Error` if `self` cannot be encoded or if writing fails; after an
    /// I/O error, part of the message may already have been written.
    fn write_to(&self, writer: &mut impl io::Write) -> Result<()>;
}
17
18impl FromIO for Version {
19    fn read_from(reader: &mut impl io::Read) -> Result<Self>
20    where
21        Self: Sized,
22    {
23        let version = &mut [0u8];
24        reader.read_exact(version)?;
25        match version[0] {
26            5 => Ok(Version::V5),
27            other => Err(Error::InvalidVersion(other)),
28        }
29    }
30
31    fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
32        let v = match self {
33            Version::V5 => 5u8,
34        };
35        writer.write_all(&[v])?;
36        Ok(())
37    }
38}
39
40impl FromIO for AuthRequest {
41    fn read_from(reader: &mut impl io::Read) -> Result<Self>
42    where
43        Self: Sized,
44    {
45        let count = &mut [0u8];
46        reader.read_exact(count)?;
47        let mut methods = vec![0u8; count[0] as usize];
48        reader.read_exact(&mut methods)?;
49
50        Ok(AuthRequest(methods.into_iter().map(Into::into).collect()))
51    }
52
53    fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
54        let count = self.0.len();
55        if count > 255 {
56            return Err(Error::TooManyMethods);
57        }
58
59        writer.write_all(&[count as u8])?;
60        writer.write_all(
61            &self
62                .0
63                .iter()
64                .map(|i| Into::<u8>::into(*i))
65                .collect::<Vec<_>>(),
66        )?;
67
68        Ok(())
69    }
70}
71
72impl FromIO for AuthResponse {
73    fn read_from(reader: &mut impl io::Read) -> Result<Self>
74    where
75        Self: Sized,
76    {
77        let method = &mut [0u8];
78        reader.read_exact(method)?;
79        Ok(AuthResponse(method[0].into()))
80    }
81
82    fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
83        writer.write_all(&[self.0.into()])?;
84        Ok(())
85    }
86}
87
88impl FromIO for CommandRequest {
89    fn read_from(reader: &mut impl io::Read) -> Result<Self>
90    where
91        Self: Sized,
92    {
93        let buf = &mut [0u8; 3];
94        reader.read_exact(buf)?;
95        if buf[0] != 5 {
96            return Err(Error::InvalidVersion(buf[0]));
97        }
98        if buf[2] != 0 {
99            return Err(Error::InvalidHandshake);
100        }
101        let cmd = match buf[1] {
102            1 => Command::Connect,
103            2 => Command::Bind,
104            3 => Command::UdpAssociate,
105            _ => return Err(Error::InvalidCommand(buf[1])),
106        };
107
108        let address = Address::read_from(reader)?;
109
110        Ok(CommandRequest {
111            command: cmd,
112            address,
113        })
114    }
115
116    fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
117        let cmd = match self.command {
118            Command::Connect => 1u8,
119            Command::Bind => 2,
120            Command::UdpAssociate => 3,
121        };
122        writer.write_all(&[0x05, cmd, 0x00])?;
123        self.address.write_to(writer)?;
124        Ok(())
125    }
126}
127
128impl FromIO for CommandResponse {
129    fn read_from(reader: &mut impl io::Read) -> Result<Self>
130    where
131        Self: Sized,
132    {
133        let buf = &mut [0u8; 3];
134        reader.read_exact(buf)?;
135        if buf[0] != 5 {
136            return Err(Error::InvalidVersion(buf[0]));
137        }
138        if buf[2] != 0 {
139            return Err(Error::InvalidHandshake);
140        }
141        let reply = CommandReply::from_u8(buf[1])?;
142
143        let address = Address::read_from(reader)?;
144
145        if reply != CommandReply::Succeeded {
146            return Err(Error::CommandReply(reply));
147        }
148
149        Ok(CommandResponse { reply, address })
150    }
151
152    fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
153        writer.write_all(&[0x05, self.reply.to_u8(), 0x00])?;
154        self.address.write_to(writer)?;
155        Ok(())
156    }
157}
158
159impl Address {
160    fn read_port_from(reader: &mut impl io::Read) -> Result<u16> {
161        let mut buf = [0u8; 2];
162        reader.read_exact(&mut buf)?;
163        let port = u16::from_be_bytes(buf);
164        Ok(port)
165    }
166    fn write_port_to(writer: &mut impl io::Write, port: u16) -> Result<()> {
167        writer.write_all(&port.to_be_bytes())?;
168        Ok(())
169    }
170}
171
172impl FromIO for Address {
173    fn read_from(reader: &mut impl io::Read) -> Result<Self>
174    where
175        Self: Sized,
176    {
177        let mut atyp = [0u8; 1];
178        reader.read_exact(&mut atyp)?;
179
180        Ok(match atyp[0] {
181            1 => {
182                let mut ip = [0u8; 4];
183                reader.read_exact(&mut ip)?;
184                Address::SocketAddr(SocketAddr::new(ip.into(), Self::read_port_from(reader)?))
185            }
186            3 => {
187                let mut len = [0u8; 1];
188                reader.read_exact(&mut len)?;
189                let len = len[0] as usize;
190                let mut domain = vec![0u8; len];
191                reader.read_exact(&mut domain)?;
192
193                let domain =
194                    String::from_utf8(domain).map_err(|e| Error::InvalidDomain(e.into_bytes()))?;
195
196                Address::Domain(domain, Self::read_port_from(reader)?)
197            }
198            4 => {
199                let mut ip = [0u8; 16];
200                reader.read_exact(&mut ip)?;
201                Address::SocketAddr(SocketAddr::new(ip.into(), Self::read_port_from(reader)?))
202            }
203            _ => return Err(Error::InvalidAddressType(atyp[0])),
204        })
205    }
206
207    fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
208        match self {
209            Address::SocketAddr(SocketAddr::V4(addr)) => {
210                writer.write_all(&[0x01])?;
211                writer.write_all(&addr.ip().octets())?;
212                Self::write_port_to(writer, addr.port())?;
213            }
214            Address::SocketAddr(SocketAddr::V6(addr)) => {
215                writer.write_all(&[0x04])?;
216                writer.write_all(&addr.ip().octets())?;
217                Self::write_port_to(writer, addr.port())?;
218            }
219            Address::Domain(domain, port) => {
220                if domain.len() >= 256 {
221                    return Err(Error::DomainTooLong(domain.len()));
222                }
223                let header = [0x03, domain.len() as u8];
224                writer.write_all(&header)?;
225                writer.write_all(domain.as_bytes())?;
226                Self::write_port_to(writer, *port)?;
227            }
228        };
229        Ok(())
230    }
231}