SNode.C
Loading...
Searching...
No Matches
SocketAcceptor.hpp
Go to the documentation of this file.
1/*
2 * SNode.C - a slim toolkit for network communication
3 * Copyright (C) Volker Christian <me@vchrist.at>
4 * 2020, 2021, 2022, 2023, 2024, 2025
5 *
6 * This program is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Lesser General Public License as published
8 * by the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public License
17 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 */
19
20#include "core/socket/stream/SocketAcceptor.hpp"
21#include "core/socket/stream/tls/SocketAcceptor.h"
22
23#ifndef DOXYGEN_SHOULD_SKIP_THIS
24
25#include "core/socket/stream/tls/ssl_utils.h"
26#include "log/Logger.h"
27
28#include <algorithm>
29#include <openssl/ssl.h>
30#include <string>
31
32#endif /* DOXYGEN_SHOULD_SKIP_THIS */
33
34namespace core::socket::stream::tls {
35
36 template <typename PhysicalServerSocket, typename Config>
37 SocketAcceptor<PhysicalServerSocket, Config>::SocketAcceptor(
38 const std::shared_ptr<SocketContextFactory>& socketContextFactory,
39 const std::function<void(SocketConnection*)>& onConnect,
40 const std::function<void(SocketConnection*)>& onConnected,
41 const std::function<void(SocketConnection*)>& onDisconnect,
42 const std::function<void(const SocketAddress&, core::socket::State)>& onStatus,
43 const std::shared_ptr<Config>& config)
44 : Super(
45 socketContextFactory,
46 [onConnect, this](SocketConnection* socketConnection) { // onConnect
47 onConnect(socketConnection);
48
49 SSL* ssl = socketConnection->startSSL(socketConnection->getFd(),
50 Super::config->getSslCtx(),
51 Super::config->getInitTimeout(),
52 Super::config->getShutdownTimeout(),
53 !Super::config->getNoCloseNotifyIsEOF());
54 if (ssl != nullptr) {
55 SSL_set_accept_state(ssl);
56 SSL_set_ex_data(ssl, 1, Super::config.get());
57 }
58 },
59 [socketContextFactory, onConnected](SocketConnection* socketConnection) { // on Connected
60 LOG(TRACE) << socketConnection->getConnectionName() << " SSL/TLS: Start handshake";
61 if (!socketConnection->doSSLHandshake(
62 [socketContextFactory,
63 onConnected,
64 socketConnection]() { // onSuccess
65 LOG(DEBUG) << socketConnection->getConnectionName() << " SSL/TLS: Handshake success";
66
67 onConnected(socketConnection);
68
69 socketConnection->connectSocketContext(socketContextFactory);
70 },
71 [socketConnection]() { // onTimeout
72 LOG(ERROR) << socketConnection->getConnectionName() << "SSL/TLS: Handshake timed out";
73
74 socketConnection->close();
75 },
76 [socketConnection](int sslErr) { //
77 ssl_log(socketConnection->getConnectionName() + " SSL/TLS: Handshake failed", sslErr);
78
79 socketConnection->close();
80 })) {
81 LOG(ERROR) << socketConnection->getConnectionName() + " SSL/TLS: Handshake failed";
82
83 socketConnection->close();
84 }
85 },
86 [onDisconnect, instanceName = config->getInstanceName()](SocketConnection* socketConnection) { // onDisconnect
87 socketConnection->stopSSL();
88 onDisconnect(socketConnection);
89 },
90 onStatus,
91 config) {
92 }
93
// Copy constructor for the TLS acceptor; useNextSocketAddress() invokes it via `new SocketAcceptor(*this)`.
94 template <typename PhysicalSocketServer, typename Config>
95 SocketAcceptor<PhysicalSocketServer, Config>::SocketAcceptor(const SocketAcceptor& socketAcceptor)
// NOTE(review): listing line 96 is missing here. It presumably holds the base-class initializer and the
// opening brace (e.g. `: Super(socketAcceptor) {`). Restore it from the original source; as shown, this
// definition does not compile.
97 }
98
99 template <typename PhysicalClientSocket, typename Config>
100 void SocketAcceptor<PhysicalClientSocket, Config>::useNextSocketAddress() {
101 new SocketAcceptor(*this);
102 }
103
104 template <typename PhysicalSocketServer, typename Config>
105 void SocketAcceptor<PhysicalSocketServer, Config>::init() {
106 if (!config->getDisabled()) {
107 LOG(TRACE) << config->getInstanceName() << " SSL/TLS: SSL_CTX creating ...";
108 SSL_CTX* sslCtx = config->getSslCtx();
109
110 if (sslCtx != nullptr) {
111 LOG(DEBUG) << config->getInstanceName() << " SSL/TLS: SSL_CTX created";
112
113 SSL_CTX_set_client_hello_cb(sslCtx, clientHelloCallback, nullptr);
114
115 Super::init();
116 } else {
117 LOG(ERROR) << config->getInstanceName() << " SSL/TLS: SSL/TLS creation failed";
118
119 Super::onStatus(Super::config->Local::getSocketAddress(), core::socket::STATE_ERROR);
120 Super::destruct();
121 }
122 } else {
123 Super::init();
124 }
125 }
126
127 template <typename PhysicalSocketServer, typename Config>
128 int SocketAcceptor<PhysicalSocketServer, Config>::clientHelloCallback(SSL* ssl, int* al, [[maybe_unused]] void* arg) {
129 int ret = SSL_CLIENT_HELLO_SUCCESS;
130
131 std::string connectionName = *static_cast<std::string*>(SSL_get_ex_data(ssl, 0));
132 Config* config = static_cast<Config*>(SSL_get_ex_data(ssl, 1));
133
134 std::string serverNameIndication = core::socket::stream::tls::ssl_get_servername_from_client_hello(ssl);
135
136 if (!serverNameIndication.empty()) {
137 SSL_CTX* sniSslCtx = config->getSniCtx(serverNameIndication);
138
139 if (sniSslCtx != nullptr) {
140 LOG(DEBUG) << connectionName << " SSL/TLS: Setting sni certificate for '" << serverNameIndication << "'";
141 core::socket::stream::tls::ssl_set_ssl_ctx(ssl, sniSslCtx);
142 } else if (config->getForceSni()) {
143 LOG(ERROR) << connectionName << " SSL/TLS: No sni certificate found for '" << serverNameIndication
144 << "' but forceSni set - terminating";
145 ret = SSL_CLIENT_HELLO_ERROR;
146 *al = SSL_AD_UNRECOGNIZED_NAME;
147 } else {
148 LOG(WARNING) << connectionName << " SSL/TLS: No sni certificate found for '" << serverNameIndication
149 << "'. Still using master certificate";
150 }
151 } else {
152 LOG(DEBUG) << connectionName << " SSL/TLS: No sni certificate requested from client. Still using master certificate";
153 }
154
155 return ret;
156 }
157
158} // namespace core::socket::stream::tls
static int clientHelloCallback(SSL *ssl, int *al, void *arg)
SocketAcceptor(const SocketAcceptor &socketAcceptor)