SNode.C
Loading...
Searching...
No Matches
AuthorizationServer.cpp File Reference
#include "database/mariadb/MariaDBClient.h"
#include "database/mariadb/MariaDBCommandSequence.h"
#include "express/legacy/in/WebApp.h"
#include "express/middleware/JsonMiddleware.h"
#include "express/middleware/StaticMiddleware.h"
#include "log/Logger.h"
#include "utils/sha1.h"
#include <chrono>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <mysql.h>
#include <nlohmann/json.hpp>
#include <sstream>
#include <string>
Include dependency graph for AuthorizationServer.cpp:

Go to the source code of this file.

Functions

void addQueryParamToUri (std::string &uri, const std::string &queryParamName, const std::string &queryParamValue)
std::string timeToString (std::chrono::time_point< std::chrono::system_clock > time)
std::string getNewUUID ()
std::string hashSha1 (const std::string &str)
int main (int argc, char *argv[])

Function Documentation

◆ addQueryParamToUri()

void addQueryParamToUri ( std::string & uri,
const std::string & queryParamName,
const std::string & queryParamValue )

Definition at line 62 of file AuthorizationServer.cpp.

/**
 * Percent-encodes one query-string component.
 *
 * Every byte that is not an RFC 3986 "unreserved" character
 * (ALPHA / DIGIT / '-' / '.' / '_' / '~') is replaced by "%XX" with
 * upper-case hex digits. Strings that consist only of unreserved
 * characters (e.g. UUIDs) are returned unchanged.
 */
static std::string percentEncodeQueryComponent(const std::string& text) {
    static const char hexDigits[] = "0123456789ABCDEF";

    std::string encoded;
    encoded.reserve(text.size());
    for (const char ch : text) {
        // Work on the unsigned value so bytes >= 0x80 index the hex table correctly.
        const unsigned char byte = static_cast<unsigned char>(ch);
        const bool unreserved = (byte >= 'A' && byte <= 'Z') || (byte >= 'a' && byte <= 'z') || (byte >= '0' && byte <= '9') ||
                                byte == '-' || byte == '.' || byte == '_' || byte == '~';
        if (unreserved) {
            encoded += ch;
        } else {
            encoded += '%';
            encoded += hexDigits[byte >> 4];
            encoded += hexDigits[byte & 0x0F];
        }
    }
    return encoded;
}

/**
 * Appends "name=value" to `uri` as a query parameter.
 *
 * The pair is introduced with '?' when `uri` contains no '?' yet and with '&'
 * otherwise. Name and value are percent-encoded so that characters such as
 * '&', '=', '#' or spaces inside a value cannot terminate the parameter or
 * smuggle additional parameters into the URI.
 *
 * NOTE(review): callers are assumed to pass *decoded* values; passing an
 * already percent-encoded value would encode its '%' a second time.
 *
 * @param uri             URI to extend in place.
 * @param queryParamName  Parameter name (unencoded).
 * @param queryParamValue Parameter value (unencoded).
 */
void addQueryParamToUri(std::string& uri, const std::string& queryParamName, const std::string& queryParamValue) {
    uri += (uri.find('?') == std::string::npos) ? '?' : '&';
    uri += percentEncodeQueryComponent(queryParamName);
    uri += '=';
    uri += percentEncodeQueryComponent(queryParamValue);
}

◆ getNewUUID()

std::string getNewUUID ( )

Definition at line 78 of file AuthorizationServer.cpp.

/**
 * Returns a freshly generated random UUID in its canonical 36-character
 * textual form (8-4-4-4-12 hex digits separated by '-').
 *
 * The value is read from the Linux kernel's "/proc/sys/kernel/random/uuid",
 * which yields a new UUID on every read.
 *
 * The line is read into a std::string rather than a fixed char buffer: a
 * 36-byte buffer passed to istream::getline() only has room for 35
 * characters plus the terminating NUL, which silently dropped the last
 * character of every UUID.
 *
 * @return The UUID, or an empty string if the file cannot be opened or read.
 */
std::string getNewUUID() {
    std::string uuid;
    std::ifstream file("/proc/sys/kernel/random/uuid");
    if (!file || !std::getline(file, uuid)) {
        return std::string{};
    }
    return uuid; // the stream is closed by its destructor
}

◆ hashSha1()

std::string hashSha1 ( const std::string & str)

Definition at line 87 of file AuthorizationServer.cpp.

87 {
88 utils::SHA1 checksum;
89 checksum.update(str);
90 return checksum.final();
91}
std::string final()
Definition sha1.cpp:79
void update(const std::string &s)
Definition sha1.cpp:57

◆ main()

int main ( int argc,
char * argv[] )

Definition at line 93 of file AuthorizationServer.cpp.

// Entry point of the OAuth2 authorization server example.
//
// Initialises the express WebApp, builds a router with the OAuth2 endpoints
// (/authorize, /login, /token, /token/refresh, /token/validate), mounts it
// under "/oauth2", starts listening on port 8082 and hands control to
// express::WebApp::start(), whose result is returned.
//
// NOTE(review): this listing is incomplete - the embedded line numbers skip
// 97, 107, 117, 585, 590, 593, 596 and 599, so the declarations of the
// connection details, `db`, the static middleware call and the `case` labels
// of the switch are not visible here.
//
// NOTE(review): every SQL statement below is built by concatenating request
// data (query parameters / JSON body fields) into the statement text without
// escaping - SQL injection risk; escape or parameterise these values.
93 {
94 express::WebApp::init(argc, argv);
95 const express::legacy::in::WebApp app("OAuth2AuthorizationServer");
96
// Connection settings for the MariaDB client (the opening of this initializer,
// source line 97, is missing from the listing). Credentials are hard-coded.
98 .connectionName = "authorization",
99 .hostname = "localhost",
100 .username = "rathalin",
101 .password = "rathalin",
102 .database = "oauth2",
103 .port = 3306,
104 .socket = "/run/mysqld/mysqld.sock",
105 .flags = 0,
106 };
// Body of a state callback (its header, source line 107, is missing): logs
// error / connected / disconnected.
108 if (state.error != 0) {
109 VLOG(0) << "MySQL error: " << state.errorMessage << " [" << state.error << "]";
110 } else if (state.connected) {
111 VLOG(0) << "MySQL connected";
112 } else {
113 VLOG(0) << "MySQL disconnected";
114 }
115 }};
116
118
119 const express::Router router{};
120
121 // Middleware to catch requests without a valid client_id
// Requests with an empty client_id get 401 immediately; otherwise the client
// table is queried and next() is called only when at least one row matches.
// NOTE(review): when the row callback receives nullptr nothing is sent and
// next() is not called - presumably nullptr marks the end of the result set;
// confirm against the MariaDB client API.
122 router.use([&db] MIDDLEWARE(req, res, next) {
123 const std::string queryClientId{req->query("client_id")};
124 if (queryClientId.length() > 0) {
125 db.query(
126 "select count(*) from client where uuid = '" + queryClientId + "'",
127 [req, res, next, queryClientId](const MYSQL_ROW row) {
128 if (row != nullptr) {
129 if (std::stoi(row[0]) > 0) {
130 VLOG(1) << "Valid client id '" << queryClientId << "'";
131 VLOG(1) << "Next with " << req->httpVersion << " " << req->method << " " << req->url;
132 next();
133 } else {
134 VLOG(1) << "Invalid client id '" << queryClientId << "'";
135 res->sendStatus(401);
136 }
137 }
138 },
139 [res](const std::string& errorString, unsigned int errorNumber) {
140 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
141 res->sendStatus(500);
142 });
143 } else {
144 res->status(401).send("Invalid client_id");
145 }
146 });
147
// GET /authorize: accepts only response_type == "code" (400 otherwise), stores
// the optional redirect_uri / scope / state on the client row and redirects to
// /oauth2/login with the client_id attached.
148 router.get("/authorize", [&db] APPLICATION(req, res) {
149 // REQUIRED: response_type, client_id
150 // OPTIONAL: redirect_uri, scope
151 // RECOMMENDED: state
152 const std::string paramResponseType{req->query("response_type")};
153 const std::string paramClientId{req->query("client_id")};
154 const std::string paramRedirectUri{req->query("redirect_uri")};
155 const std::string paramScope{req->query("scope")};
156 const std::string paramState{req->query("state")};
157
158 VLOG(1) << "Query params: "
159 << "response_type=" << req->query("response_type") << ", "
160 << "redirect_uri=" << req->query("redirect_uri") << ", "
161 << "scope=" << req->query("scope") << ", "
162 << "state=" << req->query("state") << "\n";
163
164 if (paramResponseType != "code") {
165 VLOG(1) << "Auth invalid, sending Bad Request";
166 res->sendStatus(400);
167 return;
168 }
169
// The three updates below only log on failure; the redirect at the end of the
// handler is issued regardless of their outcome.
170 if (!paramRedirectUri.empty()) {
171 db.exec(
172 "update client set redirect_uri = '" + paramRedirectUri + "' where uuid = '" + paramClientId + "'",
173 [paramRedirectUri]() {
174 VLOG(1) << "Database: Set redirect_uri to " << paramRedirectUri;
175 },
176 [](const std::string& errorString, unsigned int errorNumber) {
177 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
178 });
179 }
180
181 if (!paramScope.empty()) {
182 db.exec(
183 "update client set scope = '" + paramScope + "' where uuid = '" + paramClientId + "'",
184 [paramScope]() {
185 VLOG(1) << "Database: Set scope to " << paramScope;
186 },
187 [](const std::string& errorString, unsigned int errorNumber) {
188 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
189 });
190 }
191
192 if (!paramState.empty()) {
193 db.exec(
194 "update client set state = '" + paramState + "' where uuid = '" + paramClientId + "'",
195 [paramState]() {
196 VLOG(1) << "Database: Set state to " << paramState;
197 },
198 [](const std::string& errorString, unsigned int errorNumber) {
199 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
200 });
201 }
202
203 VLOG(1) << "Auth request valid, redirecting to login";
204 std::string loginUri{"/oauth2/login"};
205 addQueryParamToUri(loginUri, "client_id", paramClientId);
206 res->redirect(loginUri);
207 });
208
// GET /login: serves the login page from a hard-coded absolute path; a
// non-zero result of sendFile is logged together with the request URL.
209 router.get("/login", [] APPLICATION(req, res) {
210 res->sendFile("/home/rathalin/projects/snode.c/src/oauth2/authorization_server/vue-frontend-oauth2-auth-server/dist/index.html",
211 [req](int ret) {
212 if (ret != 0) {
213 PLOG(ERROR) << req->url;
214 }
215 });
216 });
217
// POST /login: reads the JSON body attribute, loads the client row for the
// client_id query parameter, checks email and salted SHA-1 password hash, then
// creates an auth code, links it to the client and answers with a JSON object
// containing the redirect URI (with "code" and, if stored, "state" appended).
218 router.post("/login", [&db] APPLICATION(req, res) {
219 req->getAttribute<nlohmann::json>(
220 [req, res, &db](nlohmann::json& body) {
221 db.query(
222 "select email, password_hash, password_salt, redirect_uri, state "
223 "from client "
224 "where uuid = '" +
225 req->query("client_id") + "'",
// NOTE(review): `body` is captured by reference here; if this row callback
// runs after the enclosing attribute callback has returned, the reference may
// dangle - confirm the lifetime of the JSON attribute.
226 [req, res, &db, &body](const MYSQL_ROW row) {
227 if (row != nullptr) {
// NOTE(review): std::string is constructed from row[n] directly; a NULL
// column (e.g. no redirect_uri or state stored) would be a null pointer here -
// confirm the schema forbids NULL. Missing "email"/"password" keys in the JSON
// body are not checked before the conversions below.
228 const std::string dbEmail{row[0]};
229 const std::string dbPasswordHash{row[1]};
230 const std::string dbPasswordSalt{row[2]};
231 const std::string dbRedirectUri{row[3]};
232 const std::string dbState{row[4]};
233 const std::string queryEmail{body["email"]};
234 const std::string queryPassword{body["password"]};
235 // Check email and password
236 if (dbEmail != queryEmail) {
237 res->status(401).send("Invalid email address");
238 } else if (dbPasswordHash != hashSha1(dbPasswordSalt + queryPassword)) {
239 res->status(401).send("Invalid password");
240 } else {
241 // Generate auth code which expires after 10 minutes
242 const unsigned int expireMinutes{10};
243 const std::string authCode{getNewUUID()};
// The insert and the following "select last_insert_id()" are chained on the
// command sequence returned by exec(); the id read back is then stored in
// client.auth_code_id.
244 db.exec(
245 "insert into token(uuid, expire_datetime) "
246 "values('" +
247 authCode + "', '" +
248 timeToString(std::chrono::system_clock::now() + std::chrono::minutes(expireMinutes)) + "')",
249 []() {
250 },
251 [res](const std::string& errorString, unsigned int errorNumber) {
252 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
253 res->sendStatus(500);
254 })
255 .query(
256 "select last_insert_id()",
257 [req, res, &db, dbState, dbRedirectUri, authCode](const MYSQL_ROW row) {
258 if (row != nullptr) {
259 db.exec(
260 "update client "
261 "set auth_code_id = '" +
262 std::string{row[0]} +
263 "' "
264 "where uuid = '" +
265 req->query("client_id") + "'",
266 [res, dbState, dbRedirectUri, authCode]() {
267 // Redirect back to the client app
268 std::string clientRedirectUri{dbRedirectUri};
269 addQueryParamToUri(clientRedirectUri, "code", authCode);
270 if (!dbState.empty()) {
271 addQueryParamToUri(clientRedirectUri, "state", dbState);
272 }
273 // Set CORS header
274 res->set("Access-Control-Allow-Origin", "*");
275 const nlohmann::json responseJson = {{"redirect_uri", clientRedirectUri}};
276 const std::string responseJsonString{responseJson.dump(4)};
277 VLOG(1) << "Sending json reponse: " << responseJsonString;
278 res->send(responseJsonString);
279 },
280 [res](const std::string& errorString, unsigned int errorNumber) {
281 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
282 res->sendStatus(500);
283 });
284 }
285 },
286 [res](const std::string& errorString, unsigned int errorNumber) {
287 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
288 res->sendStatus(500);
289 });
290 }
291 }
292 },
293 [res](const std::string& errorString, unsigned int errorNumber) {
294 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
295 res->sendStatus(500);
296 });
297 },
// Called when the JSON attribute is not available: answer 500.
298 [res]([[maybe_unused]] const std::string& key) {
299 res->sendStatus(500);
300 });
301 });
302
// GET /token: exchanges an authorization code for an access token and a
// refresh token. Requires grant_type == "authorization_code", a non-empty
// code and a redirect_uri equal to the one stored for the client.
303 router.get("/token", [&db] APPLICATION(req, res) {
304 res->set("Access-Control-Allow-Origin", "*");
305 auto queryGrantType = req->query("grant_type");
306 VLOG(1) << "GrandType: " << queryGrantType;
307 auto queryCode = req->query("code");
308 VLOG(1) << "Code: " << queryCode;
309 auto queryRedirectUri = req->query("redirect_uri");
310 VLOG(1) << "RedirectUri: " << queryRedirectUri;
311 if (queryGrantType != "authorization_code") {
312 res->status(400).send("Invalid query parameter 'grant_type', value must be 'authorization_code'");
313 return;
314 }
315 if (queryCode.length() == 0) {
316 res->status(400).send("Missing query parameter 'code'");
317 return;
318 }
319 if (queryRedirectUri.length() == 0) {
320 res->status(400).send("Missing query parameter 'redirect_uri'");
321 return;
322 }
323 db.query(
324 "select count(*) "
325 "from client "
326 "where uuid = '" +
327 req->query("client_id") +
328 "' "
329 "and redirect_uri = '" +
330 queryRedirectUri + "'",
331 [req, res, &db](const MYSQL_ROW row) {
332 if (row != nullptr) {
333 if (std::stoi(row[0]) == 0) {
334 res->status(400).send("Query param 'redirect_uri' must be the same as in the initial request");
335 } else {
// Second check: the code must be the client's current auth code and its
// expire_datetime must still lie in the future.
336 db.query(
337 "select count(*) "
338 "from client c "
339 "join token a "
340 "on c.auth_code_id = a.id "
341 "where c.uuid = '" +
342 req->query("client_id") +
343 "' "
344 "and a.uuid = '" +
345 req->query("code") +
346 "' "
347 "and timestampdiff(second, current_timestamp(), a.expire_datetime) > 0",
348 [req, res, &db](const MYSQL_ROW row) {
349 if (row != nullptr) {
350 if (std::stoi(row[0]) == 0) {
351 res->status(401).send("Invalid auth token");
352 return;
353 }
354 // Generate access and refresh token
355 const std::string accessToken{getNewUUID()};
356 const unsigned int accessTokenExpireSeconds{60 * 60}; // 1 hour
357 const std::string refreshToken{getNewUUID()};
358 const unsigned int refreshTokenExpireSeconds{60 * 60 * 24}; // 24 hours
// One chained command sequence: insert access token -> read its id -> insert
// refresh token -> read its id. The response is sent only from the success
// callback of the final "update client set refresh_token_id" statement.
// NOTE(review): each error callback in the chain sends its own 500; whether
// later commands still run after an earlier failure is not visible here.
359 db.exec(
360 "insert into token(uuid, expire_datetime) "
361 "values('" +
362 accessToken + "', '" +
363 timeToString(std::chrono::system_clock::now() +
364 std::chrono::seconds(accessTokenExpireSeconds)) +
365 "')",
366 []() {
367 },
368 [res](const std::string& errorString, unsigned int errorNumber) {
369 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
370 res->sendStatus(500);
371 })
372 .query(
373 "select last_insert_id()",
374 [req, res, &db](const MYSQL_ROW row) {
375 if (row != nullptr) {
376 db.exec(
377 "update client "
378 "set access_token_id = '" +
379 std::string{row[0]} +
380 "' "
381 "where uuid = '" +
382 req->query("client_id") + "'",
383 []() {
384 },
385 [res](const std::string& errorString, unsigned int errorNumber) {
386 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
387 res->sendStatus(500);
388 });
389 }
390 },
391 [res](const std::string& errorString, unsigned int errorNumber) {
392 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
393 res->sendStatus(500);
394 })
395 .exec(
396 "insert into token(uuid, expire_datetime) "
397 "values('" +
398 refreshToken + "', '" +
399 timeToString(std::chrono::system_clock::now() +
400 std::chrono::seconds(refreshTokenExpireSeconds)) +
401 "')",
402 []() {
403 },
404 [res](const std::string& errorString, unsigned int errorNumber) {
405 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
406 res->sendStatus(500);
407 })
408 .query(
409 "select last_insert_id()",
410 [req, res, &db, accessToken, accessTokenExpireSeconds, refreshToken](const MYSQL_ROW row) {
411 if (row != nullptr) {
412 db.exec(
413 "update client "
414 "set refresh_token_id = '" +
415 std::string{row[0]} +
416 "' "
417 "where uuid = '" +
418 req->query("client_id") + "'",
419 [res, accessToken, accessTokenExpireSeconds, refreshToken]() {
420 // Send auth token and refresh token
421 const nlohmann::json jsonResponse = {{"access_token", accessToken},
422 {"expires_in", accessTokenExpireSeconds},
423 {"refresh_token", refreshToken}};
424 const std::string jsonResponseString{jsonResponse.dump(4)};
425 res->send(jsonResponseString);
426 },
427 [res](const std::string& errorString, unsigned int errorNumber) {
428 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
429 res->sendStatus(500);
430 });
431 }
432 },
433 [res](const std::string& errorString, unsigned int errorNumber) {
434 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
435 res->sendStatus(500);
436 });
437 }
438 },
439 [res](const std::string& errorString, unsigned int errorNumber) {
440 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
441 res->sendStatus(500);
442 });
443 }
444 }
445 },
446 [res](const std::string& errorString, unsigned int errorNumber) {
447 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
448 res->sendStatus(500);
449 });
450 });
451
// POST /token/refresh: issues a new access token for a valid, unexpired
// refresh token that belongs to the client. The "state" query parameter is
// read and logged but not used otherwise.
452 router.post("/token/refresh", [&db] APPLICATION(req, res) {
453 res->set("Access-Control-Allow-Origin", "*");
454 auto queryClientId = req->query("client_id");
455 VLOG(1) << "ClientId: " << queryClientId;
456 auto queryGrantType = req->query("grant_type");
457 VLOG(1) << "GrandType: " << queryGrantType;
458 auto queryRefreshToken = req->query("refresh_token");
459 VLOG(1) << "RefreshToken: " << queryRefreshToken;
460 auto queryState = req->query("state");
461 VLOG(1) << "State: " << queryState;
462 if (queryGrantType.length() == 0) {
463 res->status(400).send("Missing query parameter 'grant_type'");
464 return;
465 }
466 if (queryGrantType != "refresh_token") {
467 res->status(400).send("Invalid query parameter 'grant_type', value must be 'refresh_token'");
468 return;
469 }
// NOTE(review): unlike the two checks above, this branch has no `return`, so
// after the 400 has been sent execution falls through to the database query
// below, whose callbacks may send a second response.
470 if (queryRefreshToken.empty()) {
471 res->status(400).send("Missing query parameter 'refresh_token'");
472 }
473 db.query(
474 "select count(*) "
475 "from client c "
476 "join token r "
477 "on c.refresh_token_id = r.id "
478 "where c.uuid = '" +
479 req->query("client_id") +
480 "' "
481 "and r.uuid = '" +
482 req->query("refresh_token") +
483 "' "
484 "and timestampdiff(second, current_timestamp(), r.expire_datetime) > 0",
485 [req, res, &db](const MYSQL_ROW row) {
486 if (row != nullptr) {
487 if (std::stoi(row[0]) == 0) {
488 res->status(401).send("Invalid refresh token");
489 return;
490 }
491 // Generate access token
492 std::string accessToken{getNewUUID()};
493 unsigned int accessTokenExpireSeconds{60 * 60}; // 1 hour
494 db.exec(
495 "insert into token(uuid, expire_datetime) "
496 "values('" +
497 accessToken + "', '" +
498 timeToString(std::chrono::system_clock::now() + std::chrono::seconds(accessTokenExpireSeconds)) + "')",
499 []() {
500 },
501 [res](const std::string& errorString, unsigned int errorNumber) {
502 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
503 res->sendStatus(500);
504 })
505 .query(
506 "select last_insert_id()",
507 [req, res, &db, accessToken, accessTokenExpireSeconds](const MYSQL_ROW row) {
508 if (row != nullptr) {
509 db.exec(
510 "update client "
511 "set access_token_id = '" +
512 std::string{row[0]} +
513 "' "
514 "where uuid = '" +
515 req->query("client_id") + "'",
516 [res, accessToken, accessTokenExpireSeconds]() {
517 const nlohmann::json responseJson = {{"access_token", accessToken},
518 {"expires_in", accessTokenExpireSeconds}};
519 res->send(responseJson.dump(4));
520 },
521 [res](const std::string& errorString, unsigned int errorNumber) {
522 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
523 res->sendStatus(500);
524 });
525 }
526 },
527 [res](const std::string& errorString, unsigned int errorNumber) {
528 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
529 res->sendStatus(500);
530 });
531 }
532 },
533 [res](const std::string& errorString, unsigned int errorNumber) {
534 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
535 res->sendStatus(500);
536 });
537 });
538
// POST /token/validate: expects a JSON body with "access_token" and
// "client_id" and answers 200 / 401 depending on whether that token is the
// client's current access token.
// NOTE(review): missing body fields are answered with status 500 although the
// request is at fault, the query does not check expire_datetime (the /token
// and /token/refresh queries do), and getAttribute is called with a single
// callback here, so no handler is visible for a missing JSON attribute.
539 router.post("/token/validate", [&db] APPLICATION(req, res) {
540 VLOG(1) << "POST /token/validate";
541 req->getAttribute<nlohmann::json>([res, &db](nlohmann::json& jsonBody) {
542 if (!jsonBody.contains("access_token")) {
543 VLOG(1) << "Missing 'access_token' in json";
544 res->status(500).send("Missing 'access_token' in json");
545 return;
546 }
547 const std::string jsonAccessToken{jsonBody["access_token"]};
548 if (!jsonBody.contains("client_id")) {
549 VLOG(1) << "Missing 'client_id' in json";
550 res->status(500).send("Missing 'client_id' in json");
551 return;
552 }
553 const std::string jsonClientId{jsonBody["client_id"]};
554 db.query(
555 "select count(*) "
556 "from client c "
557 "join token a "
558 "on c.access_token_id = a.id "
559 "where c.uuid = '" +
560 jsonClientId +
561 "' "
562 "and a.uuid = '" +
563 jsonAccessToken + "'",
564 [res, jsonClientId, jsonAccessToken](const MYSQL_ROW row) {
565 if (row != nullptr) {
566 if (std::stoi(row[0]) == 0) {
567 const nlohmann::json errorJson = {{"error", "Invalid access token"}};
568 VLOG(1) << "Sending 401: Invalid access token '" << jsonAccessToken << "'";
569 res->status(401).send(errorJson.dump(4));
570 } else {
571 VLOG(1) << "Sending 200: Valid access token '" << jsonAccessToken << "";
572 const nlohmann::json successJson = {{"success", "Valid access token"}};
573 res->status(200).send(successJson.dump(4));
574 }
575 }
576 },
577 [res](const std::string& errorString, unsigned int errorNumber) {
578 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
579 res->sendStatus(500);
580 });
581 });
582 });
583
// Mount the OAuth2 routes. The next line is the tail of a statement whose
// beginning (source line 585) is missing - presumably the static file
// middleware for the frontend "dist" directory; confirm in the source file.
584 app.use("/oauth2", router);
586 "/home/rathalin/projects/snode.c/src/oauth2/authorization_server/vue-frontend-oauth2-auth-server/dist/"));
587
// Listen on port 8082 and log the outcome; the `case` labels of the switch
// (source lines 590, 593, 596, 599) are missing from the listing.
588 app.listen(8082, [](const express::legacy::in::WebApp::SocketAddress& socketAddress, const core::socket::State& state) {
589 switch (state) {
591 VLOG(1) << "OAuth2AuthorizationServer: listening on '" << socketAddress.toString() << "'";
592 break;
594 VLOG(1) << "OAuth2AuthorizationServer: disabled";
595 break;
597 VLOG(1) << "OAuth2AuthorizationServer: error occurred";
598 break;
600 VLOG(1) << "OAuth2AuthorizationServer: fatal error occurred";
601 break;
602 }
603 });
604
605 return express::WebApp::start();
606}
void addQueryParamToUri(std::string &uri, const std::string &queryParamName, const std::string &queryParamValue)
#define APPLICATION(req, res)
Definition Router.h:68
#define MIDDLEWARE(req, res, next)
Definition Router.h:63
static constexpr int DISABLED
Definition State.h:56
static constexpr int ERROR
Definition State.h:57
static constexpr int FATAL
Definition State.h:58
static constexpr int OK
Definition State.h:55
MariaDBCommandSequence & exec(const std::string &sql, const std::function< void(void)> &onExec, const std::function< void(const std::string &, unsigned int)> &onError)
MariaDBCommandSequence & query(const std::string &sql, const std::function< void(const MYSQL_ROW)> &onQuery, const std::function< void(const std::string &, unsigned int)> &onError)
typename Server::SocketAddress SocketAddress
Definition WebAppT.h:70
static void init(int argc, char *argv[])
Definition WebApp.cpp:56
static int start(const utils::Timeval &timeOut={LONG_MAX, 0})
Definition WebApp.cpp:60
WebAppT< web::http::legacy::in::Server > WebApp
Definition WebApp.h:56
Router router(database::mariadb::MariaDBClient &db)
Definition testregex.cpp:68

References database::mariadb::MariaDBState::connected, database::mariadb::MariaDBConnectionDetails::connectionName, database::mariadb::MariaDBConnectionDetails::database, database::mariadb::MariaDBState::error, database::mariadb::MariaDBState::errorMessage, database::mariadb::MariaDBClientASyncAPI::exec(), database::mariadb::MariaDBConnectionDetails::flags, database::mariadb::MariaDBConnectionDetails::hostname, database::mariadb::MariaDBConnectionDetails::password, database::mariadb::MariaDBConnectionDetails::port, database::mariadb::MariaDBConnectionDetails::socket, and database::mariadb::MariaDBConnectionDetails::username.

Here is the call graph for this function:

◆ timeToString()

std::string timeToString ( std::chrono::time_point< std::chrono::system_clock > time)

Definition at line 71 of file AuthorizationServer.cpp.

/**
 * Formats a system_clock time point as "YYYY-MM-DD HH:MM:SS" in the
 * process's local time zone (the form used for SQL DATETIME literals in this
 * file).
 *
 * - Uses the re-entrant localtime_r() instead of std::localtime(), which
 *   returns a pointer to shared static storage.
 * - Spells the time of day as "%H:%M:%S" instead of the locale-dependent
 *   "%X", so the output stays a valid DATETIME literal even if the global
 *   locale is changed. In the default "C" locale the result is identical.
 *
 * @param time Point in time to format (sub-second precision is dropped).
 * @return The formatted timestamp, or an empty string if the time cannot be
 *         converted to calendar time.
 */
std::string timeToString(std::chrono::time_point<std::chrono::system_clock> time) {
    const std::time_t secondsSinceEpoch = std::chrono::system_clock::to_time_t(time);

    std::tm localCalendarTime{};
    if (localtime_r(&secondsSinceEpoch, &localCalendarTime) == nullptr) {
        return std::string{};
    }

    std::ostringstream formatted;
    formatted << std::put_time(&localCalendarTime, "%Y-%m-%d %H:%M:%S");
    return formatted.str();
}