1use std::time::{Duration, Instant};
2
3use anyhow::Result;
4use tokio::sync::{
5 Mutex,
6 mpsc::{Sender, channel},
7};
8
9use tonic::{
10 Request, Response, Status,
11 transport::{Channel, Server},
12};
13
14#[cfg(feature = "ssl")]
15use tonic::transport::{Certificate, ClientTlsConfig, Identity, ServerTlsConfig};
16
17use protos::{
18 GetTurnPasswordRequest, PasswordAlgorithm as ProtoPasswordAlgorithm, SessionQueryParams,
19 TurnAllocatedEvent, TurnChannelBindEvent, TurnCreatePermissionEvent, TurnDestroyEvent,
20 TurnRefreshEvent, TurnServerInfo, TurnSession, TurnSessionStatistics,
21 turn_hooks_service_client::TurnHooksServiceClient,
22 turn_service_server::{TurnService, TurnServiceServer},
23};
24
25use crate::{
26 Service,
27 codec::{crypto::Password, message::attributes::PasswordAlgorithm},
28 config::Config,
29 service::session::{Identifier, Session},
30 statistics::Statistics,
31};
32
33impl From<PasswordAlgorithm> for ProtoPasswordAlgorithm {
34 fn from(val: PasswordAlgorithm) -> Self {
35 match val {
36 PasswordAlgorithm::Md5 => Self::Md5,
37 PasswordAlgorithm::Sha256 => Self::Sha256,
38 }
39 }
40}
41
/// Conversion of an identifier to and from a plain string, e.g. the `id`
/// field carried in RPC query messages.
pub trait IdString {
    /// Parses a value from its textual form (intended to accept the output of
    /// [`IdString::to_string`]).
    fn from_string(s: String) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Renders `self` as a string.
    fn to_string(&self) -> String;

    /// Error returned when [`IdString::from_string`] rejects its input.
    type Error;
}
50
51impl IdString for Identifier {
52 type Error = Status;
53
54 fn to_string(&self) -> String {
55 format!("{}/{}", self.source(), self.interface())
56 }
57
58 fn from_string(s: String) -> Result<Self, Self::Error> {
59 let (source, interface) = s
60 .split_once('/')
61 .ok_or(Status::invalid_argument("Invalid identifier"))?;
62
63 Ok(Self::new(
64 source
65 .parse()
66 .map_err(|_| Status::invalid_argument("Invalid source address"))?,
67 interface
68 .parse()
69 .map_err(|_| Status::invalid_argument("Invalid interface address"))?,
70 ))
71 }
72}
73
/// State shared by the `TurnService` gRPC handlers.
struct RpcService {
    /// Server configuration; `get_info` reads the external addresses and the
    /// port range from `config.server`.
    config: Config,
    /// TURN service handle; its session manager is used to look up, count and
    /// refresh sessions.
    service: Service,
    /// Per-session traffic counters returned by `get_session_statistics`.
    statistics: Statistics,
    /// Creation instant of this service; `get_info` reports the seconds elapsed
    /// since then as the uptime.
    uptime: Instant,
}
80
81#[tonic::async_trait]
82impl TurnService for RpcService {
83 async fn get_info(&self, _: Request<()>) -> Result<Response<TurnServerInfo>, Status> {
84 Ok(Response::new(TurnServerInfo {
85 software: crate::SOFTWARE.to_string(),
86 uptime: self.uptime.elapsed().as_secs(),
87 interfaces: self
88 .config
89 .server
90 .get_external_addresses()
91 .iter()
92 .map(|addr| addr.to_string())
93 .collect(),
94 port_capacity: self.config.server.port_range.size() as u32,
95 port_allocated: self.service.get_session_manager().allocated() as u32,
96 }))
97 }
98
99 async fn get_session(
100 &self,
101 request: Request<SessionQueryParams>,
102 ) -> Result<Response<TurnSession>, Status> {
103 if let Some(Session::Authenticated {
104 username,
105 allocate_port,
106 allocate_channels,
107 permissions,
108 expires,
109 ..
110 }) = self
111 .service
112 .get_session_manager()
113 .get_session(&Identifier::from_string(request.into_inner().id)?)
114 .get_ref()
115 {
116 Ok(Response::new(TurnSession {
117 username: username.to_string(),
118 permissions: permissions.iter().map(|p| *p as i32).collect(),
119 channels: allocate_channels.iter().map(|p| *p as i32).collect(),
120 port: allocate_port.map(|p| p as i32),
121 expires: *expires as i64,
122 }))
123 } else {
124 Err(Status::not_found("Session not found"))
125 }
126 }
127
128 async fn get_session_statistics(
129 &self,
130 request: Request<SessionQueryParams>,
131 ) -> Result<Response<TurnSessionStatistics>, Status> {
132 if let Some(counts) = self
133 .statistics
134 .get(&Identifier::from_string(request.into_inner().id)?)
135 {
136 Ok(Response::new(TurnSessionStatistics {
137 received_bytes: counts.received_bytes as u64,
138 send_bytes: counts.send_bytes as u64,
139 received_pkts: counts.received_pkts as u64,
140 send_pkts: counts.send_pkts as u64,
141 error_pkts: counts.error_pkts as u64,
142 }))
143 } else {
144 Err(Status::not_found("Session not found"))
145 }
146 }
147
148 async fn destroy_session(
149 &self,
150 request: Request<SessionQueryParams>,
151 ) -> Result<Response<()>, Status> {
152 if self
153 .service
154 .get_session_manager()
155 .refresh(&Identifier::from_string(request.into_inner().id)?, 0)
156 {
157 Ok(Response::new(()))
158 } else {
159 Err(Status::failed_precondition("Session not found"))
160 }
161 }
162}
163
/// An event queued for asynchronous delivery to the hooks server.
///
/// Each variant wraps the protobuf message for one hooks RPC; the background
/// task spawned by `RpcHooksService::new` forwards it to the matching
/// `on_*_event` call.
pub enum HooksEvent {
    /// Delivered through `on_allocated_event`.
    Allocated(TurnAllocatedEvent),
    /// Delivered through `on_channel_bind_event`.
    ChannelBind(TurnChannelBindEvent),
    /// Delivered through `on_create_permission_event`.
    CreatePermission(TurnCreatePermissionEvent),
    /// Delivered through `on_refresh_event`.
    Refresh(TurnRefreshEvent),
    /// Delivered through `on_destroy_event`.
    Destroy(TurnDestroyEvent),
}
171
/// State of an enabled hooks integration.
struct RpcHooksServiceInner {
    /// Sending half of the bounded queue drained by the background task that
    /// delivers events to the hooks server.
    event_channel: Sender<HooksEvent>,
    /// gRPC client used for request/response calls such as `get_password`.
    client: Mutex<TurnHooksServiceClient<Channel>>,
}

/// Client for the external hooks server.
///
/// Holds `None` when the configuration has no `hooks` section; in that case
/// `send_event` does nothing and `get_password` returns `None`.
pub struct RpcHooksService(Option<RpcHooksServiceInner>);
178
179impl RpcHooksService {
180 pub async fn new(config: &Config) -> Result<Self> {
181 if let Some(hooks) = &config.hooks {
182 let (event_channel, mut rx) = channel(hooks.max_channel_size);
183 let client = {
184 let mut builder = Channel::builder(hooks.endpoint.as_str().try_into()?);
185
186 builder = builder.timeout(Duration::from_secs(hooks.timeout as u64));
187
188 #[cfg(feature = "ssl")]
189 if let Some(ssl) = &hooks.ssl {
190 builder = builder.tls_config(
191 ClientTlsConfig::new()
192 .ca_certificate(Certificate::from_pem(ssl.certificate_chain.clone()))
193 .domain_name(
194 url::Url::parse(&hooks.endpoint)?.domain().ok_or_else(|| {
195 anyhow::anyhow!("Invalid hooks server domain")
196 })?,
197 ),
198 )?;
199 }
200
201 TurnHooksServiceClient::new(
202 builder
203 .connect_timeout(Duration::from_secs(5))
204 .timeout(Duration::from_secs(1))
205 .connect_lazy(),
206 )
207 };
208
209 {
210 let mut client = client.clone();
211
212 tokio::spawn(async move {
213 while let Some(event) = rx.recv().await {
214 if match event {
215 HooksEvent::Allocated(event) => {
216 client.on_allocated_event(Request::new(event)).await
217 }
218 HooksEvent::ChannelBind(event) => {
219 client.on_channel_bind_event(Request::new(event)).await
220 }
221 HooksEvent::CreatePermission(event) => {
222 client.on_create_permission_event(Request::new(event)).await
223 }
224 HooksEvent::Refresh(event) => {
225 client.on_refresh_event(Request::new(event)).await
226 }
227 HooksEvent::Destroy(event) => {
228 client.on_destroy_event(Request::new(event)).await
229 }
230 }
231 .is_err()
232 {
233 break;
234 }
235 }
236 });
237 }
238
239 log::info!("create hooks client, endpoint={}", hooks.endpoint);
240
241 Ok(Self(Some(RpcHooksServiceInner {
242 client: Mutex::new(client),
243 event_channel,
244 })))
245 } else {
246 Ok(Self(None))
247 }
248 }
249
250 pub fn send_event(&self, event: HooksEvent) {
251 if let Some(inner) = &self.0
252 && !inner.event_channel.is_closed()
253 && let Err(e) = inner.event_channel.try_send(event)
254 {
255 log::error!("Failed to send event to hooks server: {}", e);
256 }
257 }
258
259 pub async fn get_password(
260 &self,
261 realm: &str,
262 username: &str,
263 algorithm: PasswordAlgorithm,
264 ) -> Option<Password> {
265 if let Some(inner) = &self.0 {
266 let algorithm: ProtoPasswordAlgorithm = algorithm.into();
267 let password = inner
268 .client
269 .lock()
270 .await
271 .get_password(Request::new(GetTurnPasswordRequest {
272 realm: realm.to_string(),
273 username: username.to_string(),
274 algorithm: algorithm as i32,
275 }))
276 .await
277 .ok()?
278 .into_inner()
279 .password;
280
281 return Some(match algorithm {
282 ProtoPasswordAlgorithm::Md5 => Password::Md5(password.try_into().ok()?),
283 ProtoPasswordAlgorithm::Sha256 => Password::Sha256(password.try_into().ok()?),
284 ProtoPasswordAlgorithm::Unspecified => unreachable!(),
285 });
286 }
287
288 None
289 }
290}
291
292pub async fn start_server(config: Config, service: Service, statistics: Statistics) -> Result<()> {
293 if let Some(api) = &config.api {
294 let mut builder = Server::builder();
295
296 builder = builder
297 .timeout(Duration::from_secs(api.timeout as u64))
298 .accept_http1(false);
299
300 #[cfg(feature = "ssl")]
301 if let Some(ssl) = &api.ssl {
302 builder = builder.tls_config(ServerTlsConfig::new().identity(Identity::from_pem(
303 ssl.certificate_chain.clone(),
304 ssl.private_key.clone(),
305 )))?;
306 }
307
308 log::info!("api server listening: listen={}", api.listen);
309
310 builder
311 .add_service(TurnServiceServer::new(RpcService {
312 config: config.clone(),
313 uptime: Instant::now(),
314 statistics,
315 service,
316 }))
317 .serve(api.listen)
318 .await?;
319 } else {
320 std::future::pending().await
321 }
322
323 Ok(())
324}