turn_server/
grpc.rs

/// gRPC bindings for the `turn.server` protobuf package.
pub mod proto {
    // Includes the code tonic generated for the `turn.server` package.
    tonic::include_proto!("turn.server");
}
4
5use std::time::{Duration, Instant};
6
7use anyhow::{Result, anyhow};
8use tokio::sync::{
9    Mutex,
10    mpsc::{Sender, channel},
11};
12
13use tonic::{
14    Request, Response, Status,
15    transport::{Channel, Server},
16};
17
18#[cfg(feature = "ssl")]
19use tonic::transport::{Certificate, ClientTlsConfig, Identity, ServerTlsConfig};
20
21use proto::{
22    PasswordAlgorithm as ProtoPasswordAlgorithm, SessionQueryParams, TurnAllocatedEvent,
23    TurnChannelBindEvent, TurnCreatePermissionEvent, TurnDestroyEvent, TurnRefreshEvent,
24    TurnServerInfo, TurnSession, TurnSessionStatistics,
25    turn_hooks_service_client::TurnHooksServiceClient,
26    turn_service_server::{TurnService, TurnServiceServer},
27    GetTurnPasswordRequest,
28};
29
30use crate::{
31    Service,
32    codec::{crypto::Password, message::attributes::PasswordAlgorithm},
33    config::Config,
34    service::session::{Identifier, Session},
35    statistics::Statistics,
36};
37
38impl From<PasswordAlgorithm> for ProtoPasswordAlgorithm {
39    fn from(val: PasswordAlgorithm) -> Self {
40        match val {
41            PasswordAlgorithm::Md5 => Self::Md5,
42            PasswordAlgorithm::Sha256 => Self::Sha256,
43        }
44    }
45}
46
/// Conversion between a value and its textual identifier form.
pub trait IdString {
    /// Error returned when a string cannot be turned back into `Self`.
    type Error;

    /// Rebuilds a value from its string form.
    fn from_string(s: String) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Renders the value as a string.
    fn to_string(&self) -> String;
}
55
56impl IdString for Identifier {
57    type Error = anyhow::Error;
58
59    fn to_string(&self) -> String {
60        format!("{}/{}", self.source, self.interface)
61    }
62
63    fn from_string(s: String) -> Result<Self, Self::Error> {
64        let (source, interface) = s.split_once('/').ok_or(anyhow!("Invalid identifier"))?;
65
66        Ok(Self {
67            source: source.parse()?,
68            interface: interface.parse()?,
69        })
70    }
71}
72
/// State backing the `TurnService` gRPC implementation.
struct RpcService {
    /// Server configuration; read for the external interfaces and port range.
    config: Config,
    /// TURN service handle; its session manager is queried and refreshed.
    service: Service,
    /// Per-session traffic counters, looked up by session identifier.
    statistics: Statistics,
    /// Instant captured when the API server is started; reported as elapsed seconds.
    uptime: Instant,
}
79
80#[tonic::async_trait]
81impl TurnService for RpcService {
82    async fn get_info(&self, _: Request<()>) -> Result<Response<TurnServerInfo>, Status> {
83        Ok(Response::new(TurnServerInfo {
84            software: crate::SOFTWARE.to_string(),
85            uptime: self.uptime.elapsed().as_secs(),
86            interfaces: self
87                .config
88                .server
89                .get_externals()
90                .iter()
91                .map(|addr| addr.to_string())
92                .collect(),
93            port_capacity: self.config.server.port_range.size() as u32,
94            port_allocated: self.service.get_session_manager().allocated() as u32,
95        }))
96    }
97
98    async fn get_session(
99        &self,
100        request: Request<SessionQueryParams>,
101    ) -> Result<Response<TurnSession>, Status> {
102        if let Some(Session::Authenticated {
103            username,
104            allocate_port,
105            allocate_channels,
106            permissions,
107            expires,
108            ..
109        }) = self
110            .service
111            .get_session_manager()
112            .get_session(
113                &Identifier::from_string(request.into_inner().id)
114                    .map_err(|_| Status::invalid_argument("Invalid identifier"))?,
115            )
116            .get_ref()
117        {
118            Ok(Response::new(TurnSession {
119                username: username.to_string(),
120                permissions: permissions.iter().map(|p| *p as i32).collect(),
121                channels: allocate_channels.iter().map(|p| *p as i32).collect(),
122                port: allocate_port.map(|p| p as i32),
123                expires: *expires as i64,
124            }))
125        } else {
126            Err(Status::not_found("Session not found"))
127        }
128    }
129
130    async fn get_session_statistics(
131        &self,
132        request: Request<SessionQueryParams>,
133    ) -> Result<Response<TurnSessionStatistics>, Status> {
134        if let Some(counts) = self.statistics.get(
135            &Identifier::from_string(request.into_inner().id)
136                .map_err(|_| Status::invalid_argument("Invalid identifier"))?,
137        ) {
138            Ok(Response::new(TurnSessionStatistics {
139                received_bytes: counts.received_bytes as u64,
140                send_bytes: counts.send_bytes as u64,
141                received_pkts: counts.received_pkts as u64,
142                send_pkts: counts.send_pkts as u64,
143                error_pkts: counts.error_pkts as u64,
144            }))
145        } else {
146            Err(Status::not_found("Session not found"))
147        }
148    }
149
150    async fn destroy_session(
151        &self,
152        request: Request<SessionQueryParams>,
153    ) -> Result<Response<()>, Status> {
154        if self.service.get_session_manager().refresh(
155            &Identifier::from_string(request.into_inner().id)
156                .map_err(|_| Status::invalid_argument("Invalid identifier"))?,
157            0,
158        ) {
159            Ok(Response::new(()))
160        } else {
161            Err(Status::failed_precondition("Session not found"))
162        }
163    }
164}
165
166pub async fn start_server(config: Config, service: Service, statistics: Statistics) -> Result<()> {
167    if let Some(api) = &config.api {
168        let mut builder = Server::builder();
169
170        builder = builder
171            .timeout(Duration::from_secs(api.timeout as u64))
172            .accept_http1(false);
173
174        #[cfg(feature = "ssl")]
175        if let Some(ssl) = &api.ssl {
176            builder = builder.tls_config(ServerTlsConfig::new().identity(Identity::from_pem(
177                ssl.certificate_chain.clone(),
178                ssl.private_key.clone(),
179            )))?;
180        }
181
182        log::info!("api server listening: listen={}", api.listen);
183
184        builder
185            .add_service(TurnServiceServer::new(RpcService {
186                config: config.clone(),
187                uptime: Instant::now(),
188                statistics,
189                service,
190            }))
191            .serve(api.listen)
192            .await?;
193    } else {
194        std::future::pending().await
195    }
196
197    Ok(())
198}
199
/// Events queued for delivery to the external hooks server.
///
/// Each variant wraps the protobuf message that is sent by the matching
/// `on_*_event` RPC of the hooks client.
pub enum HooksEvent {
    /// Sent through `on_allocated_event`.
    Allocated(TurnAllocatedEvent),
    /// Sent through `on_channel_bind_event`.
    ChannelBind(TurnChannelBindEvent),
    /// Sent through `on_create_permission_event`.
    CreatePermission(TurnCreatePermissionEvent),
    /// Sent through `on_refresh_event`.
    Refresh(TurnRefreshEvent),
    /// Sent through `on_destroy_event`.
    Destroy(TurnDestroyEvent),
}
207
/// Live state of a configured hooks client.
struct RpcHooksServiceInner {
    /// Sending half of the bounded queue drained by the background delivery task.
    event_channel: Sender<HooksEvent>,
    /// gRPC client used for password lookups, guarded by an async mutex.
    client: Mutex<TurnHooksServiceClient<Channel>>,
}
212
/// Handle to the optional external hooks server.
///
/// Holds `None` when the configuration has no hooks section.
pub struct RpcHooksService(Option<RpcHooksServiceInner>);
214
215impl RpcHooksService {
216    pub async fn new(config: &Config) -> Result<Self> {
217        if let Some(hooks) = &config.hooks {
218            let (event_channel, mut rx) = channel(hooks.max_channel_size);
219            let client = {
220                let mut builder = Channel::builder(hooks.endpoint.as_str().try_into()?);
221
222                builder = builder.timeout(Duration::from_secs(hooks.timeout as u64));
223
224                #[cfg(feature = "ssl")]
225                if let Some(ssl) = &hooks.ssl {
226                    builder = builder.tls_config(
227                        ClientTlsConfig::new()
228                            .ca_certificate(Certificate::from_pem(ssl.certificate_chain.clone()))
229                            .domain_name(
230                                url::Url::parse(&hooks.endpoint)?
231                                    .domain()
232                                    .ok_or_else(|| anyhow!("Invalid hooks server domain"))?,
233                            ),
234                    )?;
235                }
236
237                TurnHooksServiceClient::new(
238                    builder
239                        .connect_timeout(Duration::from_secs(5))
240                        .timeout(Duration::from_secs(1))
241                        .connect()
242                        .await?,
243                )
244            };
245
246            {
247                let mut client = client.clone();
248
249                tokio::spawn(async move {
250                    while let Some(event) = rx.recv().await {
251                        if match event {
252                            HooksEvent::Allocated(event) => {
253                                client.on_allocated_event(Request::new(event)).await
254                            }
255                            HooksEvent::ChannelBind(event) => {
256                                client.on_channel_bind_event(Request::new(event)).await
257                            }
258                            HooksEvent::CreatePermission(event) => {
259                                client.on_create_permission_event(Request::new(event)).await
260                            }
261                            HooksEvent::Refresh(event) => {
262                                client.on_refresh_event(Request::new(event)).await
263                            }
264                            HooksEvent::Destroy(event) => {
265                                client.on_destroy_event(Request::new(event)).await
266                            }
267                        }
268                        .is_err()
269                        {
270                            break;
271                        }
272                    }
273                });
274            }
275
276            log::info!("create hooks client, endpoint={}", hooks.endpoint);
277
278            Ok(Self(Some(RpcHooksServiceInner {
279                client: Mutex::new(client),
280                event_channel,
281            })))
282        } else {
283            Ok(Self(None))
284        }
285    }
286
287    pub fn send_event(&self, event: HooksEvent) {
288        if let Some(inner) = &self.0
289            && !inner.event_channel.is_closed()
290            && let Err(e) = inner.event_channel.try_send(event)
291        {
292            log::error!("Failed to send event to hooks server: {}", e);
293        }
294    }
295
296    pub async fn get_password(
297        &self,
298        username: &str,
299        algorithm: PasswordAlgorithm,
300    ) -> Option<Password> {
301        if let Some(inner) = &self.0 {
302            let algorithm: ProtoPasswordAlgorithm = algorithm.into();
303            let password = inner
304                .client
305                .lock()
306                .await
307                .get_password(Request::new(GetTurnPasswordRequest {
308                    username: username.to_string(),
309                    algorithm: algorithm as i32,
310                }))
311                .await
312                .ok()?
313                .into_inner()
314                .password;
315
316            return Some(match algorithm {
317                ProtoPasswordAlgorithm::Md5 => Password::Md5(password.try_into().ok()?),
318                ProtoPasswordAlgorithm::Sha256 => Password::Sha256(password.try_into().ok()?),
319                ProtoPasswordAlgorithm::Unspecified => unreachable!(),
320            });
321        }
322
323        None
324    }
325}