1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
//! Types related to the HTTPS webhook server.

use super::handle;
use crate::{errors, event_loop::Webhook};
use hyper::{server::conn::Http, service::service_fn};
use hyper::{Body, Request};
use tracing::instrument;

#[cfg(feature = "tls")]
pub use native_tls::Identity;
#[cfg(feature = "tls")]
use native_tls::TlsAcceptor;
use std::{convert::Infallible, net::SocketAddr, sync::Arc};
use tokio::net::TcpListener;
use tokio::time::timeout;
#[cfg(feature = "rustls")]
pub use tokio_rustls::rustls::ServerConfig;
#[cfg(feature = "rustls")]
use tokio_rustls::TlsAcceptor;

/// Configures the HTTPS webhook server.
///
/// Holds the webhook settings together with the TLS material for whichever
/// TLS backend is enabled through the `tls` or `rustls` cargo feature.
/// Nothing runs until `start` is called.
#[must_use = "webhook server needs to be `start`ed to run the event loop"]
pub struct Https<'a> {
    /// Webhook settings and the event loop; `start` destructures this value.
    webhook: Webhook<'a>,
    /// Identity passed to the `native_tls` acceptor builder in `start`
    /// (only with the `tls` feature).
    #[cfg(feature = "tls")]
    identity: Identity,
    /// Server configuration wrapped into the `tokio_rustls` acceptor in
    /// `start` (only with the `rustls` feature).
    #[cfg(feature = "rustls")]
    config: ServerConfig,
}

impl<'a> Https<'a> {
    /// Bundles the webhook settings with the TLS material of the enabled
    /// backend.
    ///
    /// The second parameter depends on the cargo feature: an `Identity` with
    /// `tls`, a `ServerConfig` with `rustls`. The values are only stored
    /// here; they are consumed later by `start`.
    pub(crate) const fn new(
        webhook: Webhook<'a>,
        #[cfg(feature = "tls")] identity: Identity,
        #[cfg(feature = "rustls")] config: ServerConfig,
    ) -> Self {
        Self {
            webhook,
            #[cfg(feature = "tls")]
            identity,
            #[cfg(feature = "rustls")]
            config,
        }
    }
}

impl<'a> Https<'a> {
    /// Starts the event loop.
    #[instrument(name = "https_webhook", skip(self))]
    pub async fn start(self) -> Result<Infallible, errors::HttpsWebhook> {
        let Webhook {
            event_loop,
            bind_to,
            port,
            updates_url,
            url,
            ip_address,
            certificate,
            max_connections,
            allowed_updates,
            request_timeout,
            drop_pending_updates,
        } = self.webhook;

        let set_webhook = event_loop
            .bot
            .set_webhook(
                url,
                ip_address,
                certificate,
                max_connections,
                allowed_updates,
                drop_pending_updates,
            )
            .call();

        timeout(request_timeout, set_webhook).await??;

        let set_commands = event_loop.set_commands_descriptions();
        match timeout(request_timeout, set_commands).await {
            Ok(Err(method)) => {
                return Err(errors::HttpsWebhook::SetMyCommands(method))
            }
            Err(timeout) => {
                return Err(errors::HttpsWebhook::SetMyCommandsTimeout(timeout))
            }
            Ok(_) => (),
        };

        let event_loop = Arc::new(event_loop);
        let addr = SocketAddr::new(bind_to, port);
        let updates_url = Arc::new(updates_url);

        #[cfg(feature = "tls")]
        let tls_acceptor = {
            let tls_acceptor = TlsAcceptor::builder(self.identity).build()?;
            tokio_native_tls::TlsAcceptor::from(tls_acceptor)
        };
        #[cfg(feature = "rustls")]
        let tls_acceptor = TlsAcceptor::from(Arc::new(self.config));

        let server = TcpListener::bind(&addr).await?;

        let http_proto = Http::new();

        loop {
            let (tcp_stream, _) = server.accept().await?;
            let tls_stream = tls_acceptor.accept(tcp_stream).await?;

            let event_loop = Arc::clone(&event_loop);
            let updates_url = Arc::clone(&updates_url);

            let service = service_fn(move |request: Request<Body>| {
                handle(
                    Arc::clone(&event_loop),
                    request,
                    Arc::clone(&updates_url),
                )
            });

            let conn = http_proto.serve_connection(tls_stream, service);

            conn.await?;
        }
    }
}