SNode.C
Loading...
Searching...
No Matches
AuthorizationServer.cpp
Go to the documentation of this file.
1/*
2 * SNode.C - A Slim Toolkit for Network Communication
3 * Copyright (C) Volker Christian <me@vchrist.at>
4 * 2020, 2021, 2022, 2023, 2024, 2025
5 *
6 * This program is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Lesser General Public License as published
8 * by the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public License
17 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 */
19
20/*
21 * MIT License
22 *
23 * Permission is hereby granted, free of charge, to any person obtaining a copy
24 * of this software and associated documentation files (the "Software"), to deal
25 * in the Software without restriction, including without limitation the rights
26 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
27 * copies of the Software, and to permit persons to whom the Software is
28 * furnished to do so, subject to the following conditions:
29 *
30 * The above copyright notice and this permission notice shall be included in
31 * all copies or substantial portions of the Software.
32 *
33 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
34 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
35 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
36 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
37 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
38 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
39 * THE SOFTWARE.
40 */
41
42#include "database/mariadb/MariaDBClient.h"
43#include "database/mariadb/MariaDBCommandSequence.h"
44#include "express/legacy/in/WebApp.h"
45#include "express/middleware/JsonMiddleware.h"
46#include "express/middleware/StaticMiddleware.h"
47#include "log/Logger.h"
48#include "utils/sha1.h"
49
50#include <chrono>
51#include <ctime>
52#include <fstream>
53#include <iomanip>
54#include <mysql.h>
55#include <nlohmann/json.hpp>
56#include <sstream>
57#include <string>
58
59// IWYU pragma: no_include <nlohmann/json_fwd.hpp>
60// IWYU pragma: no_include <nlohmann/detail/json_ref.hpp>
61
/**
 * Percent-encode a string for use as a URI query component.
 *
 * Characters from the RFC 3986 "unreserved" set (ALPHA / DIGIT / '-' / '.' / '_' / '~')
 * are copied verbatim; every other byte is emitted as "%XX" with uppercase hex digits.
 *
 * @param text  Raw (already decoded) text.
 * @return      The encoded text.
 */
static std::string percentEncode(const std::string& text) {
    static constexpr char hexDigits[] = "0123456789ABCDEF";

    std::string encoded;
    encoded.reserve(text.size());
    for (const char c : text) {
        const auto byte = static_cast<unsigned char>(c);
        const bool unreserved = (byte >= 'A' && byte <= 'Z') || (byte >= 'a' && byte <= 'z') || (byte >= '0' && byte <= '9') ||
                                byte == '-' || byte == '.' || byte == '_' || byte == '~';
        if (unreserved) {
            encoded += c;
        } else {
            encoded += '%';
            encoded += hexDigits[byte >> 4];
            encoded += hexDigits[byte & 0x0F];
        }
    }
    return encoded;
}

/**
 * Append "name=value" to a URI as a query parameter.
 *
 * A '?' separator is used when the URI has no '?' yet, otherwise '&'.
 * Name and value are percent-encoded so that characters such as '&', '=', '#'
 * or spaces inside a value cannot break or extend the query string. Values made
 * only of unreserved characters (e.g. UUIDs) are appended unchanged.
 *
 * NOTE(review): this expects decoded input; passing text that is already
 * percent-encoded would encode it a second time.
 *
 * @param uri              URI to extend in place.
 * @param queryParamName   Parameter name (raw, not encoded).
 * @param queryParamValue  Parameter value (raw, not encoded).
 */
void addQueryParamToUri(std::string& uri, const std::string& queryParamName, const std::string& queryParamValue) {
    uri += (uri.find('?') == std::string::npos) ? '?' : '&';
    uri += percentEncode(queryParamName);
    uri += '=';
    uri += percentEncode(queryParamValue);
}
70
/**
 * Format a system_clock time point as "YYYY-MM-DD HH:MM:SS" in local time.
 *
 * The result is embedded into SQL statements as a datetime literal, so the layout
 * is spelled out explicitly (%H:%M:%S) instead of relying on the locale-dependent
 * "%X" conversion.
 *
 * @param time  Point in time to format.
 * @return      The formatted local time, or an empty string if the time cannot be
 *              broken down or formatted.
 */
std::string timeToString(std::chrono::time_point<std::chrono::system_clock> time) {
    const std::time_t seconds = std::chrono::system_clock::to_time_t(time);

    // localtime_r writes into a caller-provided struct instead of the shared
    // static buffer used by std::localtime.
    std::tm localTime{};
    if (localtime_r(&seconds, &localTime) == nullptr) {
        return {};
    }

    char buffer[32];
    const std::size_t written = std::strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", &localTime);
    return std::string(buffer, written);
}
77
/**
 * Obtain a fresh UUID from the kernel.
 *
 * Reads the first line of /proc/sys/kernel/random/uuid. The line is read into a
 * std::string so that the complete value is returned; the earlier fixed buffer of
 * 36 bytes left room for only 35 characters plus the terminator, which cut off
 * the last character of the 36-character UUID.
 *
 * @return The UUID text, or an empty string if the file cannot be opened or read.
 */
std::string getNewUUID() {
    std::string uuid;
    std::ifstream uuidSource("/proc/sys/kernel/random/uuid");
    if (!uuidSource || !std::getline(uuidSource, uuid)) {
        uuid.clear();
    }
    return uuid;
}
86
87std::string hashSha1(const std::string& str) {
88 utils::SHA1 checksum;
89 checksum.update(str);
90 return checksum.final();
91}
92
93int main(int argc, char* argv[]) {
94 express::WebApp::init(argc, argv);
95 const express::legacy::in::WebApp app("OAuth2AuthorizationServer");
96
97 const database::mariadb::MariaDBConnectionDetails details{
98 .connectionName = "authorization",
99 .hostname = "localhost",
100 .username = "rathalin",
101 .password = "rathalin",
102 .database = "oauth2",
103 .port = 3306,
104 .socket = "/run/mysqld/mysqld.sock",
105 .flags = 0,
106 };
107 database::mariadb::MariaDBClient db{details, [](const database::mariadb::MariaDBState& state) {
108 if (state.error != 0) {
109 VLOG(0) << "MySQL error: " << state.errorMessage << " [" << state.error << "]";
110 } else if (state.connected) {
111 VLOG(0) << "MySQL connected";
112 } else {
113 VLOG(0) << "MySQL disconnected";
114 }
115 }};
116
117 app.use(express::middleware::JsonMiddleware());
118
119 const express::Router router{};
120
121 // Middleware to catch requests without a valid client_id
122 router.use([&db] MIDDLEWARE(req, res, next) {
123 const std::string queryClientId{req->query("client_id")};
124 if (queryClientId.length() > 0) {
125 db.query(
126 "select count(*) from client where uuid = '" + queryClientId + "'",
127 [req, res, next, queryClientId](const MYSQL_ROW row) {
128 if (row != nullptr) {
129 if (std::stoi(row[0]) > 0) {
130 VLOG(1) << "Valid client id '" << queryClientId << "'";
131 VLOG(1) << "Next with " << req->httpVersion << " " << req->method << " " << req->url;
132 next();
133 } else {
134 VLOG(1) << "Invalid client id '" << queryClientId << "'";
135 res->sendStatus(401);
136 }
137 }
138 },
139 [res](const std::string& errorString, unsigned int errorNumber) {
140 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
141 res->sendStatus(500);
142 });
143 } else {
144 res->status(401).send("Invalid client_id");
145 }
146 });
147
148 router.get("/authorize", [&db] APPLICATION(req, res) {
149 // REQUIRED: response_type, client_id
150 // OPTIONAL: redirect_uri, scope
151 // RECOMMENDED: state
152 const std::string paramResponseType{req->query("response_type")};
153 const std::string paramClientId{req->query("client_id")};
154 const std::string paramRedirectUri{req->query("redirect_uri")};
155 const std::string paramScope{req->query("scope")};
156 const std::string paramState{req->query("state")};
157
158 VLOG(1) << "Query params: "
159 << "response_type=" << req->query("response_type") << ", "
160 << "redirect_uri=" << req->query("redirect_uri") << ", "
161 << "scope=" << req->query("scope") << ", "
162 << "state=" << req->query("state") << "\n";
163
164 if (paramResponseType != "code") {
165 VLOG(1) << "Auth invalid, sending Bad Request";
166 res->sendStatus(400);
167 return;
168 }
169
170 if (!paramRedirectUri.empty()) {
171 db.exec(
172 "update client set redirect_uri = '" + paramRedirectUri + "' where uuid = '" + paramClientId + "'",
173 [paramRedirectUri]() {
174 VLOG(1) << "Database: Set redirect_uri to " << paramRedirectUri;
175 },
176 [](const std::string& errorString, unsigned int errorNumber) {
177 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
178 });
179 }
180
181 if (!paramScope.empty()) {
182 db.exec(
183 "update client set scope = '" + paramScope + "' where uuid = '" + paramClientId + "'",
184 [paramScope]() {
185 VLOG(1) << "Database: Set scope to " << paramScope;
186 },
187 [](const std::string& errorString, unsigned int errorNumber) {
188 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
189 });
190 }
191
192 if (!paramState.empty()) {
193 db.exec(
194 "update client set state = '" + paramState + "' where uuid = '" + paramClientId + "'",
195 [paramState]() {
196 VLOG(1) << "Database: Set state to " << paramState;
197 },
198 [](const std::string& errorString, unsigned int errorNumber) {
199 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
200 });
201 }
202
203 VLOG(1) << "Auth request valid, redirecting to login";
204 std::string loginUri{"/oauth2/login"};
205 addQueryParamToUri(loginUri, "client_id", paramClientId);
206 res->redirect(loginUri);
207 });
208
209 router.get("/login", [] APPLICATION(req, res) {
210 res->sendFile("/home/rathalin/projects/snode.c/src/oauth2/authorization_server/vue-frontend-oauth2-auth-server/dist/index.html",
211 [req](int ret) {
212 if (ret != 0) {
213 PLOG(ERROR) << req->url;
214 }
215 });
216 });
217
218 router.post("/login", [&db] APPLICATION(req, res) {
219 req->getAttribute<nlohmann::json>(
220 [req, res, &db](nlohmann::json& body) {
221 db.query(
222 "select email, password_hash, password_salt, redirect_uri, state "
223 "from client "
224 "where uuid = '" +
225 req->query("client_id") + "'",
226 [req, res, &db, &body](const MYSQL_ROW row) {
227 if (row != nullptr) {
228 const std::string dbEmail{row[0]};
229 const std::string dbPasswordHash{row[1]};
230 const std::string dbPasswordSalt{row[2]};
231 const std::string dbRedirectUri{row[3]};
232 const std::string dbState{row[4]};
233 const std::string queryEmail{body["email"]};
234 const std::string queryPassword{body["password"]};
235 // Check email and password
236 if (dbEmail != queryEmail) {
237 res->status(401).send("Invalid email address");
238 } else if (dbPasswordHash != hashSha1(dbPasswordSalt + queryPassword)) {
239 res->status(401).send("Invalid password");
240 } else {
241 // Generate auth code which expires after 10 minutes
242 const unsigned int expireMinutes{10};
243 const std::string authCode{getNewUUID()};
244 db.exec(
245 "insert into token(uuid, expire_datetime) "
246 "values('" +
247 authCode + "', '" +
248 timeToString(std::chrono::system_clock::now() + std::chrono::minutes(expireMinutes)) + "')",
249 []() {
250 },
251 [res](const std::string& errorString, unsigned int errorNumber) {
252 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
253 res->sendStatus(500);
254 })
255 .query(
256 "select last_insert_id()",
257 [req, res, &db, dbState, dbRedirectUri, authCode](const MYSQL_ROW row) {
258 if (row != nullptr) {
259 db.exec(
260 "update client "
261 "set auth_code_id = '" +
262 std::string{row[0]} +
263 "' "
264 "where uuid = '" +
265 req->query("client_id") + "'",
266 [res, dbState, dbRedirectUri, authCode]() {
267 // Redirect back to the client app
268 std::string clientRedirectUri{dbRedirectUri};
269 addQueryParamToUri(clientRedirectUri, "code", authCode);
270 if (!dbState.empty()) {
271 addQueryParamToUri(clientRedirectUri, "state", dbState);
272 }
273 // Set CORS header
274 res->set("Access-Control-Allow-Origin", "*");
275 const nlohmann::json responseJson = {{"redirect_uri", clientRedirectUri}};
276 const std::string responseJsonString{responseJson.dump(4)};
277 VLOG(1) << "Sending json reponse: " << responseJsonString;
278 res->send(responseJsonString);
279 },
280 [res](const std::string& errorString, unsigned int errorNumber) {
281 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
282 res->sendStatus(500);
283 });
284 }
285 },
286 [res](const std::string& errorString, unsigned int errorNumber) {
287 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
288 res->sendStatus(500);
289 });
290 }
291 }
292 },
293 [res](const std::string& errorString, unsigned int errorNumber) {
294 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
295 res->sendStatus(500);
296 });
297 },
298 [res]([[maybe_unused]] const std::string& key) {
299 res->sendStatus(500);
300 });
301 });
302
303 router.get("/token", [&db] APPLICATION(req, res) {
304 res->set("Access-Control-Allow-Origin", "*");
305 auto queryGrantType = req->query("grant_type");
306 VLOG(1) << "GrandType: " << queryGrantType;
307 auto queryCode = req->query("code");
308 VLOG(1) << "Code: " << queryCode;
309 auto queryRedirectUri = req->query("redirect_uri");
310 VLOG(1) << "RedirectUri: " << queryRedirectUri;
311 if (queryGrantType != "authorization_code") {
312 res->status(400).send("Invalid query parameter 'grant_type', value must be 'authorization_code'");
313 return;
314 }
315 if (queryCode.length() == 0) {
316 res->status(400).send("Missing query parameter 'code'");
317 return;
318 }
319 if (queryRedirectUri.length() == 0) {
320 res->status(400).send("Missing query parameter 'redirect_uri'");
321 return;
322 }
323 db.query(
324 "select count(*) "
325 "from client "
326 "where uuid = '" +
327 req->query("client_id") +
328 "' "
329 "and redirect_uri = '" +
330 queryRedirectUri + "'",
331 [req, res, &db](const MYSQL_ROW row) {
332 if (row != nullptr) {
333 if (std::stoi(row[0]) == 0) {
334 res->status(400).send("Query param 'redirect_uri' must be the same as in the initial request");
335 } else {
336 db.query(
337 "select count(*) "
338 "from client c "
339 "join token a "
340 "on c.auth_code_id = a.id "
341 "where c.uuid = '" +
342 req->query("client_id") +
343 "' "
344 "and a.uuid = '" +
345 req->query("code") +
346 "' "
347 "and timestampdiff(second, current_timestamp(), a.expire_datetime) > 0",
348 [req, res, &db](const MYSQL_ROW row) {
349 if (row != nullptr) {
350 if (std::stoi(row[0]) == 0) {
351 res->status(401).send("Invalid auth token");
352 return;
353 }
354 // Generate access and refresh token
355 const std::string accessToken{getNewUUID()};
356 const unsigned int accessTokenExpireSeconds{60 * 60}; // 1 hour
357 const std::string refreshToken{getNewUUID()};
358 const unsigned int refreshTokenExpireSeconds{60 * 60 * 24}; // 24 hours
359 db.exec(
360 "insert into token(uuid, expire_datetime) "
361 "values('" +
362 accessToken + "', '" +
363 timeToString(std::chrono::system_clock::now() +
364 std::chrono::seconds(accessTokenExpireSeconds)) +
365 "')",
366 []() {
367 },
368 [res](const std::string& errorString, unsigned int errorNumber) {
369 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
370 res->sendStatus(500);
371 })
372 .query(
373 "select last_insert_id()",
374 [req, res, &db](const MYSQL_ROW row) {
375 if (row != nullptr) {
376 db.exec(
377 "update client "
378 "set access_token_id = '" +
379 std::string{row[0]} +
380 "' "
381 "where uuid = '" +
382 req->query("client_id") + "'",
383 []() {
384 },
385 [res](const std::string& errorString, unsigned int errorNumber) {
386 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
387 res->sendStatus(500);
388 });
389 }
390 },
391 [res](const std::string& errorString, unsigned int errorNumber) {
392 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
393 res->sendStatus(500);
394 })
395 .exec(
396 "insert into token(uuid, expire_datetime) "
397 "values('" +
398 refreshToken + "', '" +
399 timeToString(std::chrono::system_clock::now() +
400 std::chrono::seconds(refreshTokenExpireSeconds)) +
401 "')",
402 []() {
403 },
404 [res](const std::string& errorString, unsigned int errorNumber) {
405 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
406 res->sendStatus(500);
407 })
408 .query(
409 "select last_insert_id()",
410 [req, res, &db, accessToken, accessTokenExpireSeconds, refreshToken](const MYSQL_ROW row) {
411 if (row != nullptr) {
412 db.exec(
413 "update client "
414 "set refresh_token_id = '" +
415 std::string{row[0]} +
416 "' "
417 "where uuid = '" +
418 req->query("client_id") + "'",
419 [res, accessToken, accessTokenExpireSeconds, refreshToken]() {
420 // Send auth token and refresh token
421 const nlohmann::json jsonResponse = {{"access_token", accessToken},
422 {"expires_in", accessTokenExpireSeconds},
423 {"refresh_token", refreshToken}};
424 const std::string jsonResponseString{jsonResponse.dump(4)};
425 res->send(jsonResponseString);
426 },
427 [res](const std::string& errorString, unsigned int errorNumber) {
428 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
429 res->sendStatus(500);
430 });
431 }
432 },
433 [res](const std::string& errorString, unsigned int errorNumber) {
434 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
435 res->sendStatus(500);
436 });
437 }
438 },
439 [res](const std::string& errorString, unsigned int errorNumber) {
440 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
441 res->sendStatus(500);
442 });
443 }
444 }
445 },
446 [res](const std::string& errorString, unsigned int errorNumber) {
447 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
448 res->sendStatus(500);
449 });
450 });
451
452 router.post("/token/refresh", [&db] APPLICATION(req, res) {
453 res->set("Access-Control-Allow-Origin", "*");
454 auto queryClientId = req->query("client_id");
455 VLOG(1) << "ClientId: " << queryClientId;
456 auto queryGrantType = req->query("grant_type");
457 VLOG(1) << "GrandType: " << queryGrantType;
458 auto queryRefreshToken = req->query("refresh_token");
459 VLOG(1) << "RefreshToken: " << queryRefreshToken;
460 auto queryState = req->query("state");
461 VLOG(1) << "State: " << queryState;
462 if (queryGrantType.length() == 0) {
463 res->status(400).send("Missing query parameter 'grant_type'");
464 return;
465 }
466 if (queryGrantType != "refresh_token") {
467 res->status(400).send("Invalid query parameter 'grant_type', value must be 'refresh_token'");
468 return;
469 }
470 if (queryRefreshToken.empty()) {
471 res->status(400).send("Missing query parameter 'refresh_token'");
472 }
473 db.query(
474 "select count(*) "
475 "from client c "
476 "join token r "
477 "on c.refresh_token_id = r.id "
478 "where c.uuid = '" +
479 req->query("client_id") +
480 "' "
481 "and r.uuid = '" +
482 req->query("refresh_token") +
483 "' "
484 "and timestampdiff(second, current_timestamp(), r.expire_datetime) > 0",
485 [req, res, &db](const MYSQL_ROW row) {
486 if (row != nullptr) {
487 if (std::stoi(row[0]) == 0) {
488 res->status(401).send("Invalid refresh token");
489 return;
490 }
491 // Generate access token
492 std::string accessToken{getNewUUID()};
493 unsigned int accessTokenExpireSeconds{60 * 60}; // 1 hour
494 db.exec(
495 "insert into token(uuid, expire_datetime) "
496 "values('" +
497 accessToken + "', '" +
498 timeToString(std::chrono::system_clock::now() + std::chrono::seconds(accessTokenExpireSeconds)) + "')",
499 []() {
500 },
501 [res](const std::string& errorString, unsigned int errorNumber) {
502 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
503 res->sendStatus(500);
504 })
505 .query(
506 "select last_insert_id()",
507 [req, res, &db, accessToken, accessTokenExpireSeconds](const MYSQL_ROW row) {
508 if (row != nullptr) {
509 db.exec(
510 "update client "
511 "set access_token_id = '" +
512 std::string{row[0]} +
513 "' "
514 "where uuid = '" +
515 req->query("client_id") + "'",
516 [res, accessToken, accessTokenExpireSeconds]() {
517 const nlohmann::json responseJson = {{"access_token", accessToken},
518 {"expires_in", accessTokenExpireSeconds}};
519 res->send(responseJson.dump(4));
520 },
521 [res](const std::string& errorString, unsigned int errorNumber) {
522 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
523 res->sendStatus(500);
524 });
525 }
526 },
527 [res](const std::string& errorString, unsigned int errorNumber) {
528 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
529 res->sendStatus(500);
530 });
531 }
532 },
533 [res](const std::string& errorString, unsigned int errorNumber) {
534 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
535 res->sendStatus(500);
536 });
537 });
538
539 router.post("/token/validate", [&db] APPLICATION(req, res) {
540 VLOG(1) << "POST /token/validate";
541 req->getAttribute<nlohmann::json>([res, &db](nlohmann::json& jsonBody) {
542 if (!jsonBody.contains("access_token")) {
543 VLOG(1) << "Missing 'access_token' in json";
544 res->status(500).send("Missing 'access_token' in json");
545 return;
546 }
547 const std::string jsonAccessToken{jsonBody["access_token"]};
548 if (!jsonBody.contains("client_id")) {
549 VLOG(1) << "Missing 'client_id' in json";
550 res->status(500).send("Missing 'client_id' in json");
551 return;
552 }
553 const std::string jsonClientId{jsonBody["client_id"]};
554 db.query(
555 "select count(*) "
556 "from client c "
557 "join token a "
558 "on c.access_token_id = a.id "
559 "where c.uuid = '" +
560 jsonClientId +
561 "' "
562 "and a.uuid = '" +
563 jsonAccessToken + "'",
564 [res, jsonClientId, jsonAccessToken](const MYSQL_ROW row) {
565 if (row != nullptr) {
566 if (std::stoi(row[0]) == 0) {
567 const nlohmann::json errorJson = {{"error", "Invalid access token"}};
568 VLOG(1) << "Sending 401: Invalid access token '" << jsonAccessToken << "'";
569 res->status(401).send(errorJson.dump(4));
570 } else {
571 VLOG(1) << "Sending 200: Valid access token '" << jsonAccessToken << "";
572 const nlohmann::json successJson = {{"success", "Valid access token"}};
573 res->status(200).send(successJson.dump(4));
574 }
575 }
576 },
577 [res](const std::string& errorString, unsigned int errorNumber) {
578 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
579 res->sendStatus(500);
580 });
581 });
582 });
583
584 app.use("/oauth2", router);
585 app.use(express::middleware::StaticMiddleware(
586 "/home/rathalin/projects/snode.c/src/oauth2/authorization_server/vue-frontend-oauth2-auth-server/dist/"));
587
588 app.listen(8082, [](const express::legacy::in::WebApp::SocketAddress& socketAddress, const core::socket::State& state) {
589 switch (state) {
590 case core::socket::State::OK:
591 VLOG(1) << "OAuth2AuthorizationServer: listening on '" << socketAddress.toString() << "'";
592 break;
593 case core::socket::State::DISABLED:
594 VLOG(1) << "OAuth2AuthorizationServer: disabled";
595 break;
596 case core::socket::State::ERROR:
597 VLOG(1) << "OAuth2AuthorizationServer: error occurred";
598 break;
599 case core::socket::State::FATAL:
600 VLOG(1) << "OAuth2AuthorizationServer: fatal error occurred";
601 break;
602 }
603 });
604
605 return express::WebApp::start();
606}
std::string timeToString(std::chrono::time_point< std::chrono::system_clock > time)
void addQueryParamToUri(std::string &uri, const std::string &queryParamName, const std::string &queryParamValue)
std::string getNewUUID()
std::string hashSha1(const std::string &str)
#define APPLICATION(req, res)
Definition Router.h:68
#define MIDDLEWARE(req, res, next)
Definition Router.h:63
MariaDBCommandSequence & exec(const std::string &sql, const std::function< void(void)> &onExec, const std::function< void(const std::string &, unsigned int)> &onError)
Definition Config.h:59
int main(int argc, char *argv[])