use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
};

use futures_util::{StreamExt, stream::SelectAll};
use hyper::{Request, Response};
use hyper_util::{
    rt::{TokioExecutor, TokioIo},
    server::conn::auto::Connection,
    service::TowerToHyperService,
};
use mas_context::LogContext;
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio_rustls::rustls::ServerConfig;
use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
use tower::Service;
use tower_http::add_extension::AddExtension;
use tracing::Instrument;

use crate::{
    ConnectionInfo,
    maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo},
    proxy_protocol::{MaybeProxyAcceptor, ProxyAcceptError},
    rewind::Rewind,
    unix_or_tcp::{SocketAddr, UnixOrTcpConnection, UnixOrTcpListener},
};

const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5);

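/// A pending HTTP(S) server: a listener bundled with the Tower service that
/// will handle its requests, plus optional TLS and PROXY protocol support.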
pub struct Server<S> {
    tls: Option<Arc<ServerConfig>>,
    proxy: bool,
    listener: UnixOrTcpListener,
    service: S,
}

impl<S> Server<S> {
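    /// Create a new [`Server`] from anything that can be converted into a
    /// [`UnixOrTcpListener`].
    ///
    /// # Errors
    ///
    /// Returns an error if the listener conversion fails.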
    pub fn try_new<L>(listener: L, service: S) -> Result<Self, L::Error>
    where
        L: TryInto<UnixOrTcpListener>,
    {
        Ok(Self {
            tls: None,
            proxy: false,
            listener: listener.try_into()?,
            service,
        })
    }

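    /// Create a new [`Server`] from a listener that converts infallibly into a
    /// [`UnixOrTcpListener`].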
    #[must_use]
    pub fn new(listener: impl Into<UnixOrTcpListener>, service: S) -> Self {
        Self {
            tls: None,
            proxy: false,
            listener: listener.into(),
            service,
        }
    }

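    /// Accept the PROXY protocol header on incoming connections.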
    #[must_use]
    pub const fn with_proxy(mut self) -> Self {
        self.proxy = true;
        self
    }

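    /// Serve connections over TLS using the given rustls configuration.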
    #[must_use]
    pub fn with_tls(mut self, config: Arc<ServerConfig>) -> Self {
        self.tls = Some(config);
        self
    }

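    /// Run this server until the soft shutdown token is cancelled, forcing
    /// remaining connections to close once the hard shutdown token fires.
    /// This is a convenience wrapper around [`run_servers`] with a single
    /// server.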
    pub async fn run<B>(
        self,
        soft_shutdown_token: CancellationToken,
        hard_shutdown_token: CancellationToken,
    ) where
        S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
        S::Future: Send + 'static,
        S::Error: std::error::Error + Send + Sync + 'static,
        B: http_body::Body + Send + 'static,
        B::Data: Send,
        B::Error: std::error::Error + Send + Sync + 'static,
    {
        run_servers(
            std::iter::once(self),
            soft_shutdown_token,
            hard_shutdown_token,
        )
        .await;
    }
}

#[derive(Debug, Error)]
#[non_exhaustive]
enum AcceptError {
    #[error("failed to complete the TLS handshake")]
    TlsHandshake {
        #[source]
        source: std::io::Error,
    },

    #[error("failed to complete the proxy protocol handshake")]
    ProxyHandshake {
        #[source]
        source: ProxyAcceptError,
    },

    #[error("connection handshake timed out")]
    HandshakeTimeout {
        #[source]
        source: tokio::time::error::Elapsed,
    },
}

impl AcceptError {
    fn tls_handshake(source: std::io::Error) -> Self {
        Self::TlsHandshake { source }
    }

    fn proxy_handshake(source: ProxyAcceptError) -> Self {
        Self::ProxyHandshake { source }
    }

    fn handshake_timeout(source: tokio::time::error::Elapsed) -> Self {
        Self::HandshakeTimeout { source }
    }
}

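/// Accept a single connection: perform the optional PROXY protocol and TLS
/// handshakes within [`HANDSHAKE_TIMEOUT`], then build a hyper connection with
/// the gathered [`ConnectionInfo`] attached as a request extension.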
#[allow(clippy::type_complexity)]
#[tracing::instrument(
    name = "accept",
    skip_all,
    fields(
        network.protocol.name = "http",
        network.peer.address,
        network.peer.port,
    ),
)]
async fn accept<S, B>(
    maybe_proxy_acceptor: &MaybeProxyAcceptor,
    maybe_tls_acceptor: &MaybeTlsAcceptor,
    peer_addr: SocketAddr,
    stream: UnixOrTcpConnection,
    service: S,
) -> Result<
    Connection<
        'static,
        TokioIo<MaybeTlsStream<Rewind<UnixOrTcpConnection>>>,
        TowerToHyperService<AddExtension<S, ConnectionInfo>>,
        TokioExecutor,
    >,
    AcceptError,
>
where
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Send + Clone + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
    S::Future: Send + 'static,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    let span = tracing::Span::current();

    match peer_addr {
        SocketAddr::Net(addr) => {
            span.record("network.peer.address", tracing::field::display(addr.ip()));
            span.record("network.peer.port", addr.port());
        }
        SocketAddr::Unix(ref addr) => {
            span.record("network.peer.address", tracing::field::debug(addr));
        }
    }

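    // Bound the whole handshake sequence by a timeout so that a stalled client
    // cannot keep the accept task alive forever.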
    tokio::time::timeout(HANDSHAKE_TIMEOUT, async move {
        let (proxy, stream) = maybe_proxy_acceptor
            .accept(stream)
            .await
            .map_err(AcceptError::proxy_handshake)?;

        let stream = maybe_tls_acceptor
            .accept(stream)
            .await
            .map_err(AcceptError::tls_handshake)?;

        let tls = stream.tls_info();

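        // If the client negotiated HTTP/2 through ALPN, serve HTTP/2 only;
        // otherwise let hyper auto-detect the protocol version.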
        let is_h2 = tls.as_ref().is_some_and(TlsStreamInfo::is_alpn_h2);

        let info = ConnectionInfo {
            tls,
            proxy,
            net_peer_addr: peer_addr.into_net(),
        };

        let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
        if is_h2 {
            builder = builder.http2_only();
        }
        builder.http1().keep_alive(true);

        let service = TowerToHyperService::new(AddExtension::new(service, info));

        let conn = builder
            .serve_connection(TokioIo::new(stream), service)
            .into_owned();

        Ok(conn)
    })
    .instrument(span)
    .await
    .map_err(AcceptError::handshake_timeout)?
}

pin_project! {
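    /// A future driving a hyper [`Connection`], which starts a graceful
    /// shutdown of that connection once the given cancellation future
    /// resolves.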
    struct AbortableConnection<C> {
        #[pin]
        connection: C,
        #[pin]
        cancellation_future: WaitForCancellationFutureOwned,
        did_start_shutdown: bool,
    }
}

impl<C> AbortableConnection<C> {
    fn new(connection: C, cancellation_token: CancellationToken) -> Self {
        Self {
            connection,
            cancellation_future: cancellation_token.cancelled_owned(),
            did_start_shutdown: false,
        }
    }
}

impl<T, S, B> Future
    for AbortableConnection<Connection<'static, T, TowerToHyperService<S>, TokioExecutor>>
where
    Connection<'static, T, TowerToHyperService<S>, TokioExecutor>: Future,
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Send + Clone + 'static,
    S::Future: Send + 'static,
    S::Error: std::error::Error + Send + Sync,
    T: hyper::rt::Read + hyper::rt::Write + Unpin,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    type Output = <Connection<'static, T, TowerToHyperService<S>, TokioExecutor> as Future>::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        if let Poll::Ready(()) = this.cancellation_future.poll(cx) {
            if !*this.did_start_shutdown {
                *this.did_start_shutdown = true;
                this.connection.as_mut().graceful_shutdown();
            }
        }

        this.connection.poll(cx)
    }
}

#[allow(clippy::too_many_lines)]
pub async fn run_servers<S, B>(
    listeners: impl IntoIterator<Item = Server<S>>,
    soft_shutdown_token: CancellationToken,
    hard_shutdown_token: CancellationToken,
) where
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    let _guard = soft_shutdown_token.clone().drop_guard();

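    // Turn each listener into a stream of accepted sockets, bundled with the
    // acceptors and service needed to handle them, then merge all the streams.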
309
310 let mut accept_stream: SelectAll<_> = listeners
312 .into_iter()
313 .map(|server| {
314 let maybe_proxy_acceptor = MaybeProxyAcceptor::new(server.proxy);
315 let maybe_tls_acceptor = MaybeTlsAcceptor::new(server.tls);
316 futures_util::stream::poll_fn(move |cx| {
317 let res =
318 std::task::ready!(server.listener.poll_accept(cx)).map(|(addr, stream)| {
319 (
320 maybe_proxy_acceptor,
321 maybe_tls_acceptor.clone(),
322 server.service.clone(),
323 addr,
324 stream,
325 )
326 });
327 Poll::Ready(Some(res))
328 })
329 })
330 .collect();
331
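    // Tasks driving the per-connection handshakes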
    let mut accept_tasks = tokio::task::JoinSet::new();

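    // Tasks serving the accepted connections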
    let mut connection_tasks = tokio::task::JoinSet::new();

    loop {
        tokio::select! {
            biased;

            () = soft_shutdown_token.cancelled() => {
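                // Stop polling the listeners; connections already being served
                // keep running and are drained below.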
                tracing::debug!("Shutting down listeners");
                break;
            },

            res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
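                // A handshake task finished; if it yielded a connection, spawn
                // a task to serve it until the soft shutdown token fires.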
                match res {
                    Some(Ok(Some(connection))) => {
                        let token = soft_shutdown_token.child_token();
                        connection_tasks.spawn(LogContext::new("http-serve").run(async move || {
                            tracing::debug!("Accepted connection");
                            if let Err(e) = AbortableConnection::new(connection, token).await {
                                tracing::warn!(error = &*e as &dyn std::error::Error, "Failed to serve connection");
                            }
                        }));
                    },
                    Some(Ok(None)) => { },
                    Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                    None => tracing::error!("Join set was polled even though it was empty"),
                }
            },

            res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
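                // A connection finished being served; surface any join errors.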
                match res {
                    Some(Ok(())) => { },
                    Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                    None => tracing::error!("Join set was polled even though it was empty"),
                }
            },

            res = accept_stream.next() => {
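                // A listener produced a new socket; spawn a task to run its
                // handshake without blocking the accept loop.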
                let Some(res) = res else { continue };

                accept_tasks.spawn(LogContext::new("http-accept").run(async move || {
                    let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = match res {
                        Ok(res) => res,
                        Err(e) => {
                            tracing::warn!(error = &e as &dyn std::error::Error, "Failed to accept connection from the underlying socket");
                            return None;
                        }
                    };

                    match accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await {
                        Ok(connection) => Some(connection),
                        Err(e) => {
                            tracing::warn!(error = &e as &dyn std::error::Error, "Failed to accept connection");
                            None
                        }
                    }
                }));
            },
        };
    }

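    // Drain the remaining handshakes and connections gracefully, unless the
    // hard shutdown token fires first.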
    if !accept_tasks.is_empty() || !connection_tasks.is_empty() {
        tracing::info!(
            "There are {active} active connections ({pending} pending), performing a graceful shutdown. Send the shutdown signal again to force.",
            active = connection_tasks.len(),
            pending = accept_tasks.len(),
        );

        while !accept_tasks.is_empty() || !connection_tasks.is_empty() {
            tokio::select! {
                biased;

                res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
                    match res {
                        Some(Ok(Some(connection))) => {
                            let token = soft_shutdown_token.child_token();
                            connection_tasks.spawn(LogContext::new("http-serve").run(async || {
                                tracing::debug!("Accepted connection");
                                if let Err(e) = AbortableConnection::new(connection, token).await {
                                    tracing::warn!(error = &*e as &dyn std::error::Error, "Failed to serve connection");
                                }
                            }));
                        }
                        Some(Ok(None)) => { },
                        Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                        None => tracing::error!("Join set was polled even though it was empty"),
                    }
                },

                res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
                    match res {
                        Some(Ok(())) => { },
                        Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                        None => tracing::error!("Join set was polled even though it was empty"),
                    }
                },

                () = hard_shutdown_token.cancelled() => {
                    tracing::warn!(
                        "Forcing shutdown ({active} active connections, {pending} pending connections)",
                        active = connection_tasks.len(),
                        pending = accept_tasks.len(),
                    );
                    break;
                },
            }
        }
    }

    accept_tasks.shutdown().await;
    connection_tasks.shutdown().await;
}