2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
42#include "core/socket/stream/SocketConnection.hpp"
43#include "core/socket/stream/tls/SocketConnection.h"
44#include "core/socket/stream/tls/TLSHandshake.h"
45#include "core/socket/stream/tls/TLSShutdown.h"
47#ifndef DOXYGEN_SHOULD_SKIP_THIS
49#include "core/socket/stream/tls/ssl_utils.h"
50#include "log/Logger.h"
52#include <openssl/ssl.h>
60 template <
typename PhysicalSocket>
62 PhysicalSocket&& physicalSocket,
63 const std::function<
void(SocketConnection*)>& onDisconnect,
64 const std::string& configuredServer,
65 const SocketAddress& localAddress,
66 const SocketAddress& remoteAddress,
67 const utils::
Timeval& readTimeout,
68 const utils::
Timeval& writeTimeout,
69 std::size_t readBlockSize,
70 std::size_t writeBlockSize,
71 const utils::
Timeval& terminateTimeout)
74 std::move(physicalSocket),
75 [onDisconnect,
this]() {
88 template <
typename PhysicalSocket>
89 SSL* SocketConnection<PhysicalSocket>::
getSSL()
const {
93 template <
typename PhysicalSocket>
94 SSL* SocketConnection<PhysicalSocket>::
startSSL(
95 int fd, SSL_CTX* ctx,
const utils::
Timeval& sslInitTimeout,
const utils::
Timeval& sslShutdownTimeout,
bool closeNotifyIsEOF) {
101 if (
ssl !=
nullptr) {
102 SSL_set_ex_data(
ssl, 0,
const_cast<std::string*>(&Super::getConnectionName()));
104 if (SSL_set_fd(
ssl, fd) == 1) {
105 SocketReader::ssl =
ssl;
106 SocketWriter::ssl =
ssl;
107 SocketReader::closeNotifyIsEOF = closeNotifyIsEOF;
108 SocketWriter::closeNotifyIsEOF = closeNotifyIsEOF;
119 template <
typename PhysicalSocket>
120 void SocketConnection<PhysicalSocket>::
stopSSL() {
121 if (
ssl !=
nullptr) {
125 SocketReader::ssl =
nullptr;
126 SocketWriter::ssl =
nullptr;
130 template <
typename PhysicalSocket>
131 bool SocketConnection<PhysicalSocket>::
doSSLHandshake(
const std::function<
void()>& onSuccess,
132 const std::function<
void()>& onTimeout,
133 const std::function<
void(
int)>& onStatus) {
134 if (
ssl !=
nullptr) {
135 if (!SocketReader::isSuspended()) {
136 SocketReader::suspend();
138 if (!SocketWriter::isSuspended()) {
139 SocketWriter::suspend();
143 Super::getConnectionName()
,
145 [onSuccess,
this]() {
146 SocketReader::span();
152 [onStatus](
int sslErr) {
158 return ssl !=
nullptr;
161 template <
typename PhysicalSocket>
163 bool resumeSocketReader =
false;
164 bool resumeSocketWriter =
false;
166 if (!SocketReader::isSuspended()) {
167 SocketReader::suspend();
168 resumeSocketReader =
true;
171 if (!SocketWriter::isSuspended()) {
172 SocketWriter::suspend();
173 resumeSocketWriter =
true;
177 Super::getConnectionName()
,
179 [
this, resumeSocketReader, resumeSocketWriter]() {
180 if (resumeSocketReader) {
181 SocketReader::resume();
183 if (resumeSocketWriter) {
184 SocketWriter::resume();
186 if (SSL_get_shutdown(
ssl) == (SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN)) {
187 LOG(DEBUG) << Super::getConnectionName() <<
" SSL/TLS: Close_notify received and sent";
189 LOG(DEBUG) << Super::getConnectionName() <<
" SSL/TLS: Close_notify sent";
191 if (SSL_get_shutdown(
ssl) == SSL_SENT_SHUTDOWN && SocketWriter::closeNotifyIsEOF) {
192 LOG(TRACE) << Super::getConnectionName() <<
" SSL/TLS: Close_notify is EOF: setting sslShutdownTimeout to "
198 [
this, resumeSocketReader, resumeSocketWriter]() {
199 if (resumeSocketReader) {
200 SocketReader::resume();
202 if (resumeSocketWriter) {
203 SocketWriter::resume();
205 LOG(ERROR) << Super::getConnectionName() <<
" SSL/TLS: Shutdown handshake timed out";
206 Super::doWriteShutdown([
this]() {
207 SocketConnection::close();
210 [
this, resumeSocketReader, resumeSocketWriter](
int sslErr) {
211 if (resumeSocketReader) {
212 SocketReader::resume();
214 if (resumeSocketWriter) {
215 SocketWriter::resume();
217 ssl_log(Super::getConnectionName() +
" SSL/TLS: Shutdown handshake failed", sslErr);
218 Super::doWriteShutdown([
this]() {
219 SocketConnection::close();
225 template <
typename PhysicalSocket>
227 if ((SSL_get_shutdown(
ssl) & SSL_RECEIVED_SHUTDOWN) != 0) {
228 if ((SSL_get_shutdown(
ssl) & SSL_SENT_SHUTDOWN) != 0) {
229 LOG(DEBUG) << Super::getConnectionName() <<
" SSL/TLS: Close_notify sent and received";
231 SocketWriter::shutdownInProgress =
false;
233 LOG(DEBUG) << Super::getConnectionName() <<
" SSL/TLS: Close_notify received";
238 LOG(ERROR) << Super::getConnectionName() <<
" SSL/TLS: Unexpected EOF error";
240 SocketWriter::shutdownInProgress =
false;
241 SSL_set_shutdown(
ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
245 template <
typename PhysicalSocket>
246 void SocketConnection<PhysicalSocket>::
doWriteShutdown(
const std::function<
void()>& onShutdown) {
247 if ((SSL_get_shutdown(
ssl) & SSL_SENT_SHUTDOWN) == 0) {
248 LOG(DEBUG) << Super::getConnectionName() <<
" SSL/TLS: Send close_notify";
252 Super::doWriteShutdown(onShutdown);
SocketConnection(const std::string &instanceName, PhysicalSocket &&physicalSocket, const std::function< void(SocketConnection *)> &onDisconnect, const std::string &configuredServer, const SocketAddress &localAddress, const SocketAddress &remoteAddress, const utils::Timeval &readTimeout, const utils::Timeval &writeTimeout, std::size_t readBlockSize, std::size_t writeBlockSize, const utils::Timeval &terminateTimeout)
utils::Timeval sslShutdownTimeout
SSL * startSSL(int fd, SSL_CTX *ctx, const utils::Timeval &sslInitTimeout, const utils::Timeval &sslShutdownTimeout, bool closeNotifyIsEOF)
bool doSSLHandshake(const std::function< void()> &onSuccess, const std::function< void()> &onTimeout, const std::function< void(int)> &onStatus) final
void onReadShutdown() final
void doWriteShutdown(const std::function< void()> &onShutdown) final
utils::Timeval sslInitTimeout
static void doHandshake(const std::string &instanceName, SSL *ssl, const std::function< void(void)> &onSuccess, const std::function< void(void)> &onTimeout, const std::function< void(int)> &onStatus, const utils::Timeval &timeout)
static void doShutdown(const std::string &instanceName, SSL *ssl, const std::function< void(void)> &onSuccess, const std::function< void(void)> &onTimeout, const std::function< void(int)> &onStatus, const utils::Timeval &timeout)
Timeval & operator=(const Timeval &timeVal)