SNode.C
Loading...
Searching...
No Matches
AuthorizationServer.cpp
Go to the documentation of this file.
1/*
2 * SNode.C - A Slim Toolkit for Network Communication
3 * Copyright (C) Volker Christian <me@vchrist.at>
4 * 2020, 2021, 2022, 2023, 2024, 2025
5 *
6 * This program is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Lesser General Public License as published
8 * by the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public License
17 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 */
19
20/*
21 * MIT License
22 *
23 * Permission is hereby granted, free of charge, to any person obtaining a copy
24 * of this software and associated documentation files (the "Software"), to deal
25 * in the Software without restriction, including without limitation the rights
26 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
27 * copies of the Software, and to permit persons to whom the Software is
28 * furnished to do so, subject to the following conditions:
29 *
30 * The above copyright notice and this permission notice shall be included in
31 * all copies or substantial portions of the Software.
32 *
33 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
34 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
35 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
36 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
37 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
38 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
39 * THE SOFTWARE.
40 */
41
#include "database/mariadb/MariaDBClient.h"
#include "database/mariadb/MariaDBCommandSequence.h"
#include "express/legacy/in/WebApp.h"
#include "express/middleware/JsonMiddleware.h"
#include "express/middleware/StaticMiddleware.h"
#include "log/Logger.h"
#include "utils/sha1.h"

#include <chrono>
#include <cstddef>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <mysql.h>
#include <nlohmann/json.hpp>
#include <random>
#include <sstream>
#include <string>
58
59// IWYU pragma: no_include <nlohmann/json_fwd.hpp>
60// IWYU pragma: no_include <nlohmann/detail/json_ref.hpp>
61
// Appends "queryParamName=queryParamValue" to the query string of `uri`.
//
// The separator is '?' when `uri` has no query yet and '&' otherwise. No
// separator is added when the query already ends in '?' or '&'.
//
// Name and value are percent-encoded: only RFC 3986 unreserved characters
// (ALPHA / DIGIT / '-' / '.' / '_' / '~') are kept verbatim. This stops a
// client-supplied value (e.g. the OAuth2 `state`) from injecting additional
// parameters or a fragment.
//
// If `uri` carries a fragment ("#..."), the parameter is inserted in front of
// it so it remains part of the query.
void addQueryParamToUri(std::string& uri, const std::string& queryParamName, const std::string& queryParamValue) {
    const auto percentEncode = [](const std::string& text) {
        static const char* const hexDigits = "0123456789ABCDEF";
        std::string encoded;
        encoded.reserve(text.size());
        for (const char c : text) {
            const auto byte = static_cast<unsigned char>(c);
            // Explicit ranges instead of std::isalnum, which is locale-dependent.
            const bool unreserved = (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' ||
                                    c == '.' || c == '_' || c == '~';
            if (unreserved) {
                encoded += c;
            } else {
                encoded += '%';
                encoded += hexDigits[byte >> 4];
                encoded += hexDigits[byte & 0x0F];
            }
        }
        return encoded;
    };

    // Detach a trailing fragment so the new parameter lands in the query part.
    std::string fragment;
    const std::string::size_type fragmentPos = uri.find('#');
    if (fragmentPos != std::string::npos) {
        fragment = uri.substr(fragmentPos);
        uri.erase(fragmentPos);
    }

    if (uri.find('?') == std::string::npos) {
        uri += '?';
    } else if (uri.back() != '?' && uri.back() != '&') {
        uri += '&';
    }
    uri += percentEncode(queryParamName) + "=" + percentEncode(queryParamValue);
    uri += fragment;
}
70
// Formats `time` in local time as "YYYY-MM-DD HH:MM:SS". This is the literal
// form written into the token expire_datetime column.
//
// localtime_r() is used instead of std::localtime(). std::localtime returns a
// pointer to shared static storage and may return nullptr. The explicit
// "%H:%M:%S" replaces "%X", whose output depends on the stream's locale.
//
// Returns an empty string if the conversion to local time fails.
std::string timeToString(std::chrono::time_point<std::chrono::system_clock> time) {
    const std::time_t inTimeT = std::chrono::system_clock::to_time_t(time);
    std::tm localTm{};
    if (localtime_r(&inTimeT, &localTm) == nullptr) {
        return {};
    }
    std::ostringstream ss;
    ss << std::put_time(&localTm, "%Y-%m-%d %H:%M:%S");
    return ss.str();
}
77
// Returns a new random UUID in its 36-character textual form
// ("xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"). It is used for auth codes,
// access tokens and refresh tokens.
//
// The kernel generator /proc/sys/kernel/random/uuid is preferred. The line is
// read into a std::string so all 36 characters are kept; a 36-byte
// istream::getline buffer reserves one byte for the terminator and would drop
// the last character.
//
// If the file is unavailable or yields something that is not 36 characters, a
// version-4 UUID is built from std::random_device instead.
// NOTE(review): std::random_device is non-deterministic on glibc/libstdc++ but
// the standard does not guarantee cryptographic quality — confirm for the
// target platform.
std::string getNewUUID() {
    constexpr std::size_t uuidLength{36};

    std::string uuid;
    std::ifstream file("/proc/sys/kernel/random/uuid");
    if (file && std::getline(file, uuid) && uuid.size() == uuidLength) {
        return uuid;
    }

    static const char* const hexDigits = "0123456789abcdef";
    std::random_device randomDevice;
    std::uniform_int_distribution<int> nibble(0, 15);
    uuid.clear();
    for (std::size_t i = 0; i < uuidLength; ++i) {
        if (i == 8 || i == 13 || i == 18 || i == 23) {
            uuid += '-';
        } else if (i == 14) {
            uuid += '4'; // RFC 4122 version 4 (random)
        } else if (i == 19) {
            uuid += hexDigits[8 + nibble(randomDevice) % 4]; // RFC 4122 variant: 10xx
        } else {
            uuid += hexDigits[nibble(randomDevice)];
        }
    }
    return uuid;
}
86
87std::string hashSha1(const std::string& str) {
88 utils::SHA1 checksum;
89 checksum.update(str);
90 return checksum.final();
91}
92
93int main(int argc, char* argv[]) {
94 express::WebApp::init(argc, argv);
95 const express::legacy::in::WebApp app("OAuth2AuthorizationServer");
96
97 const database::mariadb::MariaDBConnectionDetails details{
98 .hostname = "localhost",
99 .username = "rathalin",
100 .password = "rathalin",
101 .database = "oauth2",
102 .port = 3306,
103 .socket = "/run/mysqld/mysqld.sock",
104 .flags = 0,
105 };
106 database::mariadb::MariaDBClient db{details};
107
108 app.use(express::middleware::JsonMiddleware());
109
110 const express::Router router{};
111
112 // Middleware to catch requests without a valid client_id
113 router.use([&db] MIDDLEWARE(req, res, next) {
114 const std::string queryClientId{req->query("client_id")};
115 if (queryClientId.length() > 0) {
116 db.query(
117 "select count(*) from client where uuid = '" + queryClientId + "'",
118 [req, res, next, queryClientId](const MYSQL_ROW row) {
119 if (row != nullptr) {
120 if (std::stoi(row[0]) > 0) {
121 VLOG(1) << "Valid client id '" << queryClientId << "'";
122 VLOG(1) << "Next with " << req->httpVersion << " " << req->method << " " << req->url;
123 next();
124 } else {
125 VLOG(1) << "Invalid client id '" << queryClientId << "'";
126 res->sendStatus(401);
127 }
128 }
129 },
130 [res](const std::string& errorString, unsigned int errorNumber) {
131 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
132 res->sendStatus(500);
133 });
134 } else {
135 res->status(401).send("Invalid client_id");
136 }
137 });
138
139 router.get("/authorize", [&db] APPLICATION(req, res) {
140 // REQUIRED: response_type, client_id
141 // OPTIONAL: redirect_uri, scope
142 // RECOMMENDED: state
143 const std::string paramResponseType{req->query("response_type")};
144 const std::string paramClientId{req->query("client_id")};
145 const std::string paramRedirectUri{req->query("redirect_uri")};
146 const std::string paramScope{req->query("scope")};
147 const std::string paramState{req->query("state")};
148
149 VLOG(1) << "Query params: "
150 << "response_type=" << req->query("response_type") << ", "
151 << "redirect_uri=" << req->query("redirect_uri") << ", "
152 << "scope=" << req->query("scope") << ", "
153 << "state=" << req->query("state") << "\n";
154
155 if (paramResponseType != "code") {
156 VLOG(1) << "Auth invalid, sending Bad Request";
157 res->sendStatus(400);
158 return;
159 }
160
161 if (!paramRedirectUri.empty()) {
162 db.exec(
163 "update client set redirect_uri = '" + paramRedirectUri + "' where uuid = '" + paramClientId + "'",
164 [paramRedirectUri]() {
165 VLOG(1) << "Database: Set redirect_uri to " << paramRedirectUri;
166 },
167 [](const std::string& errorString, unsigned int errorNumber) {
168 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
169 });
170 }
171
172 if (!paramScope.empty()) {
173 db.exec(
174 "update client set scope = '" + paramScope + "' where uuid = '" + paramClientId + "'",
175 [paramScope]() {
176 VLOG(1) << "Database: Set scope to " << paramScope;
177 },
178 [](const std::string& errorString, unsigned int errorNumber) {
179 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
180 });
181 }
182
183 if (!paramState.empty()) {
184 db.exec(
185 "update client set state = '" + paramState + "' where uuid = '" + paramClientId + "'",
186 [paramState]() {
187 VLOG(1) << "Database: Set state to " << paramState;
188 },
189 [](const std::string& errorString, unsigned int errorNumber) {
190 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
191 });
192 }
193
194 VLOG(1) << "Auth request valid, redirecting to login";
195 std::string loginUri{"/oauth2/login"};
196 addQueryParamToUri(loginUri, "client_id", paramClientId);
197 res->redirect(loginUri);
198 });
199
200 router.get("/login", [] APPLICATION(req, res) {
201 res->sendFile("/home/rathalin/projects/snode.c/src/oauth2/authorization_server/vue-frontend-oauth2-auth-server/dist/index.html",
202 [req](int ret) {
203 if (ret != 0) {
204 PLOG(ERROR) << req->url;
205 }
206 });
207 });
208
209 router.post("/login", [&db] APPLICATION(req, res) {
210 req->getAttribute<nlohmann::json>(
211 [req, res, &db](nlohmann::json& body) {
212 db.query(
213 "select email, password_hash, password_salt, redirect_uri, state "
214 "from client "
215 "where uuid = '" +
216 req->query("client_id") + "'",
217 [req, res, &db, &body](const MYSQL_ROW row) {
218 if (row != nullptr) {
219 const std::string dbEmail{row[0]};
220 const std::string dbPasswordHash{row[1]};
221 const std::string dbPasswordSalt{row[2]};
222 const std::string dbRedirectUri{row[3]};
223 const std::string dbState{row[4]};
224 const std::string queryEmail{body["email"]};
225 const std::string queryPassword{body["password"]};
226 // Check email and password
227 if (dbEmail != queryEmail) {
228 res->status(401).send("Invalid email address");
229 } else if (dbPasswordHash != hashSha1(dbPasswordSalt + queryPassword)) {
230 res->status(401).send("Invalid password");
231 } else {
232 // Generate auth code which expires after 10 minutes
233 const unsigned int expireMinutes{10};
234 const std::string authCode{getNewUUID()};
235 db.exec(
236 "insert into token(uuid, expire_datetime) "
237 "values('" +
238 authCode + "', '" +
239 timeToString(std::chrono::system_clock::now() + std::chrono::minutes(expireMinutes)) + "')",
240 []() {
241 },
242 [res](const std::string& errorString, unsigned int errorNumber) {
243 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
244 res->sendStatus(500);
245 })
246 .query(
247 "select last_insert_id()",
248 [req, res, &db, dbState, dbRedirectUri, authCode](const MYSQL_ROW row) {
249 if (row != nullptr) {
250 db.exec(
251 "update client "
252 "set auth_code_id = '" +
253 std::string{row[0]} +
254 "' "
255 "where uuid = '" +
256 req->query("client_id") + "'",
257 [res, dbState, dbRedirectUri, authCode]() {
258 // Redirect back to the client app
259 std::string clientRedirectUri{dbRedirectUri};
260 addQueryParamToUri(clientRedirectUri, "code", authCode);
261 if (!dbState.empty()) {
262 addQueryParamToUri(clientRedirectUri, "state", dbState);
263 }
264 // Set CORS header
265 res->set("Access-Control-Allow-Origin", "*");
266 const nlohmann::json responseJson = {{"redirect_uri", clientRedirectUri}};
267 const std::string responseJsonString{responseJson.dump(4)};
268 VLOG(1) << "Sending json reponse: " << responseJsonString;
269 res->send(responseJsonString);
270 },
271 [res](const std::string& errorString, unsigned int errorNumber) {
272 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
273 res->sendStatus(500);
274 });
275 }
276 },
277 [res](const std::string& errorString, unsigned int errorNumber) {
278 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
279 res->sendStatus(500);
280 });
281 }
282 }
283 },
284 [res](const std::string& errorString, unsigned int errorNumber) {
285 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
286 res->sendStatus(500);
287 });
288 },
289 [res]([[maybe_unused]] const std::string& key) {
290 res->sendStatus(500);
291 });
292 });
293
294 router.get("/token", [&db] APPLICATION(req, res) {
295 res->set("Access-Control-Allow-Origin", "*");
296 auto queryGrantType = req->query("grant_type");
297 VLOG(1) << "GrandType: " << queryGrantType;
298 auto queryCode = req->query("code");
299 VLOG(1) << "Code: " << queryCode;
300 auto queryRedirectUri = req->query("redirect_uri");
301 VLOG(1) << "RedirectUri: " << queryRedirectUri;
302 if (queryGrantType != "authorization_code") {
303 res->status(400).send("Invalid query parameter 'grant_type', value must be 'authorization_code'");
304 return;
305 }
306 if (queryCode.length() == 0) {
307 res->status(400).send("Missing query parameter 'code'");
308 return;
309 }
310 if (queryRedirectUri.length() == 0) {
311 res->status(400).send("Missing query parameter 'redirect_uri'");
312 return;
313 }
314 db.query(
315 "select count(*) "
316 "from client "
317 "where uuid = '" +
318 req->query("client_id") +
319 "' "
320 "and redirect_uri = '" +
321 queryRedirectUri + "'",
322 [req, res, &db](const MYSQL_ROW row) {
323 if (row != nullptr) {
324 if (std::stoi(row[0]) == 0) {
325 res->status(400).send("Query param 'redirect_uri' must be the same as in the initial request");
326 } else {
327 db.query(
328 "select count(*) "
329 "from client c "
330 "join token a "
331 "on c.auth_code_id = a.id "
332 "where c.uuid = '" +
333 req->query("client_id") +
334 "' "
335 "and a.uuid = '" +
336 req->query("code") +
337 "' "
338 "and timestampdiff(second, current_timestamp(), a.expire_datetime) > 0",
339 [req, res, &db](const MYSQL_ROW row) {
340 if (row != nullptr) {
341 if (std::stoi(row[0]) == 0) {
342 res->status(401).send("Invalid auth token");
343 return;
344 }
345 // Generate access and refresh token
346 const std::string accessToken{getNewUUID()};
347 const unsigned int accessTokenExpireSeconds{60 * 60}; // 1 hour
348 const std::string refreshToken{getNewUUID()};
349 const unsigned int refreshTokenExpireSeconds{60 * 60 * 24}; // 24 hours
350 db.exec(
351 "insert into token(uuid, expire_datetime) "
352 "values('" +
353 accessToken + "', '" +
354 timeToString(std::chrono::system_clock::now() +
355 std::chrono::seconds(accessTokenExpireSeconds)) +
356 "')",
357 []() {
358 },
359 [res](const std::string& errorString, unsigned int errorNumber) {
360 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
361 res->sendStatus(500);
362 })
363 .query(
364 "select last_insert_id()",
365 [req, res, &db](const MYSQL_ROW row) {
366 if (row != nullptr) {
367 db.exec(
368 "update client "
369 "set access_token_id = '" +
370 std::string{row[0]} +
371 "' "
372 "where uuid = '" +
373 req->query("client_id") + "'",
374 []() {
375 },
376 [res](const std::string& errorString, unsigned int errorNumber) {
377 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
378 res->sendStatus(500);
379 });
380 }
381 },
382 [res](const std::string& errorString, unsigned int errorNumber) {
383 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
384 res->sendStatus(500);
385 })
386 .exec(
387 "insert into token(uuid, expire_datetime) "
388 "values('" +
389 refreshToken + "', '" +
390 timeToString(std::chrono::system_clock::now() +
391 std::chrono::seconds(refreshTokenExpireSeconds)) +
392 "')",
393 []() {
394 },
395 [res](const std::string& errorString, unsigned int errorNumber) {
396 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
397 res->sendStatus(500);
398 })
399 .query(
400 "select last_insert_id()",
401 [req, res, &db, accessToken, accessTokenExpireSeconds, refreshToken](const MYSQL_ROW row) {
402 if (row != nullptr) {
403 db.exec(
404 "update client "
405 "set refresh_token_id = '" +
406 std::string{row[0]} +
407 "' "
408 "where uuid = '" +
409 req->query("client_id") + "'",
410 [res, accessToken, accessTokenExpireSeconds, refreshToken]() {
411 // Send auth token and refresh token
412 const nlohmann::json jsonResponse = {{"access_token", accessToken},
413 {"expires_in", accessTokenExpireSeconds},
414 {"refresh_token", refreshToken}};
415 const std::string jsonResponseString{jsonResponse.dump(4)};
416 res->send(jsonResponseString);
417 },
418 [res](const std::string& errorString, unsigned int errorNumber) {
419 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
420 res->sendStatus(500);
421 });
422 }
423 },
424 [res](const std::string& errorString, unsigned int errorNumber) {
425 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
426 res->sendStatus(500);
427 });
428 }
429 },
430 [res](const std::string& errorString, unsigned int errorNumber) {
431 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
432 res->sendStatus(500);
433 });
434 }
435 }
436 },
437 [res](const std::string& errorString, unsigned int errorNumber) {
438 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
439 res->sendStatus(500);
440 });
441 });
442
443 router.post("/token/refresh", [&db] APPLICATION(req, res) {
444 res->set("Access-Control-Allow-Origin", "*");
445 auto queryClientId = req->query("client_id");
446 VLOG(1) << "ClientId: " << queryClientId;
447 auto queryGrantType = req->query("grant_type");
448 VLOG(1) << "GrandType: " << queryGrantType;
449 auto queryRefreshToken = req->query("refresh_token");
450 VLOG(1) << "RefreshToken: " << queryRefreshToken;
451 auto queryState = req->query("state");
452 VLOG(1) << "State: " << queryState;
453 if (queryGrantType.length() == 0) {
454 res->status(400).send("Missing query parameter 'grant_type'");
455 return;
456 }
457 if (queryGrantType != "refresh_token") {
458 res->status(400).send("Invalid query parameter 'grant_type', value must be 'refresh_token'");
459 return;
460 }
461 if (queryRefreshToken.empty()) {
462 res->status(400).send("Missing query parameter 'refresh_token'");
463 }
464 db.query(
465 "select count(*) "
466 "from client c "
467 "join token r "
468 "on c.refresh_token_id = r.id "
469 "where c.uuid = '" +
470 req->query("client_id") +
471 "' "
472 "and r.uuid = '" +
473 req->query("refresh_token") +
474 "' "
475 "and timestampdiff(second, current_timestamp(), r.expire_datetime) > 0",
476 [req, res, &db](const MYSQL_ROW row) {
477 if (row != nullptr) {
478 if (std::stoi(row[0]) == 0) {
479 res->status(401).send("Invalid refresh token");
480 return;
481 }
482 // Generate access token
483 std::string accessToken{getNewUUID()};
484 unsigned int accessTokenExpireSeconds{60 * 60}; // 1 hour
485 db.exec(
486 "insert into token(uuid, expire_datetime) "
487 "values('" +
488 accessToken + "', '" +
489 timeToString(std::chrono::system_clock::now() + std::chrono::seconds(accessTokenExpireSeconds)) + "')",
490 []() {
491 },
492 [res](const std::string& errorString, unsigned int errorNumber) {
493 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
494 res->sendStatus(500);
495 })
496 .query(
497 "select last_insert_id()",
498 [req, res, &db, accessToken, accessTokenExpireSeconds](const MYSQL_ROW row) {
499 if (row != nullptr) {
500 db.exec(
501 "update client "
502 "set access_token_id = '" +
503 std::string{row[0]} +
504 "' "
505 "where uuid = '" +
506 req->query("client_id") + "'",
507 [res, accessToken, accessTokenExpireSeconds]() {
508 const nlohmann::json responseJson = {{"access_token", accessToken},
509 {"expires_in", accessTokenExpireSeconds}};
510 res->send(responseJson.dump(4));
511 },
512 [res](const std::string& errorString, unsigned int errorNumber) {
513 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
514 res->sendStatus(500);
515 });
516 }
517 },
518 [res](const std::string& errorString, unsigned int errorNumber) {
519 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
520 res->sendStatus(500);
521 });
522 }
523 },
524 [res](const std::string& errorString, unsigned int errorNumber) {
525 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
526 res->sendStatus(500);
527 });
528 });
529
530 router.post("/token/validate", [&db] APPLICATION(req, res) {
531 VLOG(1) << "POST /token/validate";
532 req->getAttribute<nlohmann::json>([res, &db](nlohmann::json& jsonBody) {
533 if (!jsonBody.contains("access_token")) {
534 VLOG(1) << "Missing 'access_token' in json";
535 res->status(500).send("Missing 'access_token' in json");
536 return;
537 }
538 const std::string jsonAccessToken{jsonBody["access_token"]};
539 if (!jsonBody.contains("client_id")) {
540 VLOG(1) << "Missing 'client_id' in json";
541 res->status(500).send("Missing 'client_id' in json");
542 return;
543 }
544 const std::string jsonClientId{jsonBody["client_id"]};
545 db.query(
546 "select count(*) "
547 "from client c "
548 "join token a "
549 "on c.access_token_id = a.id "
550 "where c.uuid = '" +
551 jsonClientId +
552 "' "
553 "and a.uuid = '" +
554 jsonAccessToken + "'",
555 [res, jsonClientId, jsonAccessToken](const MYSQL_ROW row) {
556 if (row != nullptr) {
557 if (std::stoi(row[0]) == 0) {
558 const nlohmann::json errorJson = {{"error", "Invalid access token"}};
559 VLOG(1) << "Sending 401: Invalid access token '" << jsonAccessToken << "'";
560 res->status(401).send(errorJson.dump(4));
561 } else {
562 VLOG(1) << "Sending 200: Valid access token '" << jsonAccessToken << "";
563 const nlohmann::json successJson = {{"success", "Valid access token"}};
564 res->status(200).send(successJson.dump(4));
565 }
566 }
567 },
568 [res](const std::string& errorString, unsigned int errorNumber) {
569 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
570 res->sendStatus(500);
571 });
572 });
573 });
574
575 app.use("/oauth2", router);
576 app.use(express::middleware::StaticMiddleware(
577 "/home/rathalin/projects/snode.c/src/oauth2/authorization_server/vue-frontend-oauth2-auth-server/dist/"));
578
579 app.listen(8082, [](const express::legacy::in::WebApp::SocketAddress& socketAddress, const core::socket::State& state) {
580 switch (state) {
581 case core::socket::State::OK:
582 VLOG(1) << "OAuth2AuthorizationServer: listening on '" << socketAddress.toString() << "'";
583 break;
584 case core::socket::State::DISABLED:
585 VLOG(1) << "OAuth2AuthorizationServer: disabled";
586 break;
587 case core::socket::State::ERROR:
588 VLOG(1) << "OAuth2AuthorizationServer: error occurred";
589 break;
590 case core::socket::State::FATAL:
591 VLOG(1) << "OAuth2AuthorizationServer: fatal error occurred";
592 break;
593 }
594 });
595
596 return express::WebApp::start();
597}
std::string timeToString(std::chrono::time_point< std::chrono::system_clock > time)
void addQueryParamToUri(std::string &uri, const std::string &queryParamName, const std::string &queryParamValue)
std::string getNewUUID()
std::string hashSha1(const std::string &str)
#define APPLICATION(req, res)
Definition Router.h:68
#define MIDDLEWARE(req, res, next)
Definition Router.h:63
MariaDBCommandSequence & exec(const std::string &sql, const std::function< void(void)> &onExec, const std::function< void(const std::string &, unsigned int)> &onError)
MariaDBClient(const MariaDBConnectionDetails &details)
Definition Config.h:59
int main(int argc, char *argv[])