1use crate::{
2 Address, AuthRequest, AuthResponse, Command, CommandReply, CommandRequest, CommandResponse,
3 Error, Result, Version,
4};
5use std::{io, net::SocketAddr};
6
/// Encoding and decoding of a wire structure over `std::io` streams.
///
/// Despite the name, implementors both parse themselves from an [`io::Read`]
/// and serialize themselves to an [`io::Write`].
pub trait FromIO {
    /// Reads exactly one encoded value from `reader`.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying read fails or if the bytes read are
    /// not a valid encoding for `Self` (the exact error variant depends on the
    /// implementor).
    fn read_from(reader: &mut impl io::Read) -> Result<Self>
    where
        Self: Sized;

    /// Writes the wire encoding of `self` to `writer`.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying write fails or if `self` cannot be
    /// represented on the wire (e.g. a list or name longer than its one-byte
    /// length prefix allows).
    fn write_to(&self, writer: &mut impl io::Write) -> Result<()>;
}
17
18impl FromIO for Version {
19 fn read_from(reader: &mut impl io::Read) -> Result<Self>
20 where
21 Self: Sized,
22 {
23 let version = &mut [0u8];
24 reader.read_exact(version)?;
25 match version[0] {
26 5 => Ok(Version::V5),
27 other => Err(Error::InvalidVersion(other)),
28 }
29 }
30
31 fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
32 let v = match self {
33 Version::V5 => 5u8,
34 };
35 writer.write_all(&[v])?;
36 Ok(())
37 }
38}
39
40impl FromIO for AuthRequest {
41 fn read_from(reader: &mut impl io::Read) -> Result<Self>
42 where
43 Self: Sized,
44 {
45 let count = &mut [0u8];
46 reader.read_exact(count)?;
47 let mut methods = vec![0u8; count[0] as usize];
48 reader.read_exact(&mut methods)?;
49
50 Ok(AuthRequest(methods.into_iter().map(Into::into).collect()))
51 }
52
53 fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
54 let count = self.0.len();
55 if count > 255 {
56 return Err(Error::TooManyMethods);
57 }
58
59 writer.write_all(&[count as u8])?;
60 writer.write_all(
61 &self
62 .0
63 .iter()
64 .map(|i| Into::<u8>::into(*i))
65 .collect::<Vec<_>>(),
66 )?;
67
68 Ok(())
69 }
70}
71
72impl FromIO for AuthResponse {
73 fn read_from(reader: &mut impl io::Read) -> Result<Self>
74 where
75 Self: Sized,
76 {
77 let method = &mut [0u8];
78 reader.read_exact(method)?;
79 Ok(AuthResponse(method[0].into()))
80 }
81
82 fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
83 writer.write_all(&[self.0.into()])?;
84 Ok(())
85 }
86}
87
88impl FromIO for CommandRequest {
89 fn read_from(reader: &mut impl io::Read) -> Result<Self>
90 where
91 Self: Sized,
92 {
93 let buf = &mut [0u8; 3];
94 reader.read_exact(buf)?;
95 if buf[0] != 5 {
96 return Err(Error::InvalidVersion(buf[0]));
97 }
98 if buf[2] != 0 {
99 return Err(Error::InvalidHandshake);
100 }
101 let cmd = match buf[1] {
102 1 => Command::Connect,
103 2 => Command::Bind,
104 3 => Command::UdpAssociate,
105 _ => return Err(Error::InvalidCommand(buf[1])),
106 };
107
108 let address = Address::read_from(reader)?;
109
110 Ok(CommandRequest {
111 command: cmd,
112 address,
113 })
114 }
115
116 fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
117 let cmd = match self.command {
118 Command::Connect => 1u8,
119 Command::Bind => 2,
120 Command::UdpAssociate => 3,
121 };
122 writer.write_all(&[0x05, cmd, 0x00])?;
123 self.address.write_to(writer)?;
124 Ok(())
125 }
126}
127
128impl FromIO for CommandResponse {
129 fn read_from(reader: &mut impl io::Read) -> Result<Self>
130 where
131 Self: Sized,
132 {
133 let buf = &mut [0u8; 3];
134 reader.read_exact(buf)?;
135 if buf[0] != 5 {
136 return Err(Error::InvalidVersion(buf[0]));
137 }
138 if buf[2] != 0 {
139 return Err(Error::InvalidHandshake);
140 }
141 let reply = CommandReply::from_u8(buf[1])?;
142
143 let address = Address::read_from(reader)?;
144
145 if reply != CommandReply::Succeeded {
146 return Err(Error::CommandReply(reply));
147 }
148
149 Ok(CommandResponse { reply, address })
150 }
151
152 fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
153 writer.write_all(&[0x05, self.reply.to_u8(), 0x00])?;
154 self.address.write_to(writer)?;
155 Ok(())
156 }
157}
158
159impl Address {
160 fn read_port_from(reader: &mut impl io::Read) -> Result<u16> {
161 let mut buf = [0u8; 2];
162 reader.read_exact(&mut buf)?;
163 let port = u16::from_be_bytes(buf);
164 Ok(port)
165 }
166 fn write_port_to(writer: &mut impl io::Write, port: u16) -> Result<()> {
167 writer.write_all(&port.to_be_bytes())?;
168 Ok(())
169 }
170}
171
172impl FromIO for Address {
173 fn read_from(reader: &mut impl io::Read) -> Result<Self>
174 where
175 Self: Sized,
176 {
177 let mut atyp = [0u8; 1];
178 reader.read_exact(&mut atyp)?;
179
180 Ok(match atyp[0] {
181 1 => {
182 let mut ip = [0u8; 4];
183 reader.read_exact(&mut ip)?;
184 Address::SocketAddr(SocketAddr::new(ip.into(), Self::read_port_from(reader)?))
185 }
186 3 => {
187 let mut len = [0u8; 1];
188 reader.read_exact(&mut len)?;
189 let len = len[0] as usize;
190 let mut domain = vec![0u8; len];
191 reader.read_exact(&mut domain)?;
192
193 let domain =
194 String::from_utf8(domain).map_err(|e| Error::InvalidDomain(e.into_bytes()))?;
195
196 Address::Domain(domain, Self::read_port_from(reader)?)
197 }
198 4 => {
199 let mut ip = [0u8; 16];
200 reader.read_exact(&mut ip)?;
201 Address::SocketAddr(SocketAddr::new(ip.into(), Self::read_port_from(reader)?))
202 }
203 _ => return Err(Error::InvalidAddressType(atyp[0])),
204 })
205 }
206
207 fn write_to(&self, writer: &mut impl io::Write) -> Result<()> {
208 match self {
209 Address::SocketAddr(SocketAddr::V4(addr)) => {
210 writer.write_all(&[0x01])?;
211 writer.write_all(&addr.ip().octets())?;
212 Self::write_port_to(writer, addr.port())?;
213 }
214 Address::SocketAddr(SocketAddr::V6(addr)) => {
215 writer.write_all(&[0x04])?;
216 writer.write_all(&addr.ip().octets())?;
217 Self::write_port_to(writer, addr.port())?;
218 }
219 Address::Domain(domain, port) => {
220 if domain.len() >= 256 {
221 return Err(Error::DomainTooLong(domain.len()));
222 }
223 let header = [0x03, domain.len() as u8];
224 writer.write_all(&header)?;
225 writer.write_all(domain.as_bytes())?;
226 Self::write_port_to(writer, *port)?;
227 }
228 };
229 Ok(())
230 }
231}