SNode.C
Loading...
Searching...
No Matches
AuthorizationServer.cpp File Reference
#include "database/mariadb/MariaDBClient.h"
#include "database/mariadb/MariaDBCommandSequence.h"
#include "express/legacy/in/WebApp.h"
#include "express/middleware/JsonMiddleware.h"
#include "express/middleware/StaticMiddleware.h"
#include "log/Logger.h"
#include "utils/sha1.h"
#include <chrono>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <mysql.h>
#include <nlohmann/json.hpp>
#include <sstream>
#include <string>
Include dependency graph for AuthorizationServer.cpp:

Go to the source code of this file.

Functions

void addQueryParamToUri (std::string &uri, const std::string &queryParamName, const std::string &queryParamValue)
 
std::string timeToString (std::chrono::time_point< std::chrono::system_clock > time)
 
std::string getNewUUID ()
 
std::string hashSha1 (const std::string &str)
 
int main (int argc, char *argv[])
 

Function Documentation

◆ addQueryParamToUri()

void addQueryParamToUri ( std::string &  uri,
const std::string &  queryParamName,
const std::string &  queryParamValue 
)

Definition at line 62 of file AuthorizationServer.cpp.

/**
 * Appends one "name=value" pair to the query part of a URI, in place.
 *
 * The pair is introduced by '?' when the URI does not contain a '?' yet and
 * by '&' otherwise. Name and value are appended verbatim; no percent-encoding
 * is performed here.
 *
 * @param uri             URI that is extended in place.
 * @param queryParamName  Name of the query parameter.
 * @param queryParamValue Value of the query parameter.
 */
void addQueryParamToUri(std::string& uri, const std::string& queryParamName, const std::string& queryParamValue) {
    const bool alreadyHasQuery = uri.find('?') != std::string::npos;
    uri.push_back(alreadyHasQuery ? '&' : '?');

    // Assemble the pair separately before attaching it to the URI.
    std::string pair{queryParamName};
    pair += "=";
    pair += queryParamValue;

    uri.append(pair);
}

◆ getNewUUID()

std::string getNewUUID ( )

Definition at line 78 of file AuthorizationServer.cpp.

/**
 * Obtains a freshly generated UUID from the Linux kernel.
 *
 * Reads the first line of /proc/sys/kernel/random/uuid. The kernel emits the
 * canonical 36-character textual form followed by a newline. The line is read
 * into a std::string, so it is never truncated by a fixed-size buffer (a
 * 36-byte char buffer passed to istream::getline can hold only 35 characters
 * plus the terminator).
 *
 * @return The UUID without the trailing newline, or an empty string when the
 *         file cannot be opened or read.
 */
std::string getNewUUID() {
    std::string uuid;
    std::ifstream file("/proc/sys/kernel/random/uuid");
    if (!std::getline(file, uuid)) {
        // Open or read failure: report "no UUID" as an empty string.
        uuid.clear();
    }
    return uuid;
}

◆ hashSha1()

std::string hashSha1 ( const std::string &  str)

Definition at line 87 of file AuthorizationServer.cpp.

87 {
88 utils::SHA1 checksum;
89 checksum.update(str);
90 return checksum.final();
91}
std::string final()
Definition sha1.cpp:79
void update(const std::string &s)
Definition sha1.cpp:57

◆ main()

int main ( int  argc,
char *  argv[] 
)

Definition at line 93 of file AuthorizationServer.cpp.

// Entry point of the OAuth2 authorization server.
//
// Initializes the express WebApp, builds a router with the routes of an
// authorization-code style flow (/authorize, /login, /token, /token/refresh,
// /token/validate), mounts it under "/oauth2", listens on port 8082 and then
// hands control to express::WebApp::start(), whose result is returned.
//
// NOTE(review): this listing is incomplete - source lines 97, 106, 108, 576,
// 581, 584, 587 and 590 are missing. The construction of `db`, the opening of
// the second app.use(...) call and the switch case labels are therefore not
// visible here.
//
// NOTE(review): every SQL statement below is built by concatenating request
// values (query parameters / JSON body fields) into the statement text. That
// allows SQL injection; escape the values or use prepared statements.
93 {
94 express::WebApp::init(argc, argv);
95 const express::legacy::in::WebApp app("OAuth2AuthorizationServer");
96
// Database connection settings (the declaration line that opens this
// initializer is missing from the listing).
// NOTE(review): credentials are hard-coded in the source.
98 .hostname = "localhost",
99 .username = "rathalin",
100 .password = "rathalin",
101 .database = "oauth2",
102 .port = 3306,
103 .socket = "/run/mysqld/mysqld.sock",
104 .flags = 0,
105 };
107
109
110 const express::Router router{};
111
112 // Middleware to catch requests without a valid client_id
// Answers 401 when client_id is missing or not present in table `client`,
// 500 on a database error, and calls next() only for a known client_id.
113 router.use([&db] MIDDLEWARE(req, res, next) {
114 const std::string queryClientId{req->query("client_id")};
115 if (queryClientId.length() > 0) {
116 db.query(
117 "select count(*) from client where uuid = '" + queryClientId + "'",
118 [req, res, next, queryClientId](const MYSQL_ROW row) {
// Nothing is sent when row is nullptr (presumably the end-of-result
// marker - confirm against MariaDBClient).
119 if (row != nullptr) {
120 if (std::stoi(row[0]) > 0) {
121 VLOG(1) << "Valid client id '" << queryClientId << "'";
122 VLOG(1) << "Next with " << req->httpVersion << " " << req->method << " " << req->url;
123 next();
124 } else {
125 VLOG(1) << "Invalid client id '" << queryClientId << "'";
126 res->sendStatus(401);
127 }
128 }
129 },
130 [res](const std::string& errorString, unsigned int errorNumber) {
131 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
132 res->sendStatus(500);
133 });
134 } else {
135 res->status(401).send("Invalid client_id");
136 }
137 });
138
// GET /authorize: accepts only response_type=code, stores the optional
// redirect_uri / scope / state on the client row and redirects to the
// login page.
139 router.get("/authorize", [&db] APPLICATION(req, res) {
140 // REQUIRED: response_type, client_id
141 // OPTIONAL: redirect_uri, scope
142 // RECOMMENDED: state
143 const std::string paramResponseType{req->query("response_type")};
144 const std::string paramClientId{req->query("client_id")};
145 const std::string paramRedirectUri{req->query("redirect_uri")};
146 const std::string paramScope{req->query("scope")};
147 const std::string paramState{req->query("state")};
148
149 VLOG(1) << "Query params: "
150 << "response_type=" << req->query("response_type") << ", "
151 << "redirect_uri=" << req->query("redirect_uri") << ", "
152 << "scope=" << req->query("scope") << ", "
153 << "state=" << req->query("state") << "\n";
154
155 if (paramResponseType != "code") {
156 VLOG(1) << "Auth invalid, sending Bad Request";
157 res->sendStatus(400);
158 return;
159 }
160
// The three updates below only log their outcome; a database error is
// not reported to the HTTP client.
161 if (!paramRedirectUri.empty()) {
162 db.exec(
163 "update client set redirect_uri = '" + paramRedirectUri + "' where uuid = '" + paramClientId + "'",
164 [paramRedirectUri]() {
165 VLOG(1) << "Database: Set redirect_uri to " << paramRedirectUri;
166 },
167 [](const std::string& errorString, unsigned int errorNumber) {
168 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
169 });
170 }
171
172 if (!paramScope.empty()) {
173 db.exec(
174 "update client set scope = '" + paramScope + "' where uuid = '" + paramClientId + "'",
175 [paramScope]() {
176 VLOG(1) << "Database: Set scope to " << paramScope;
177 },
178 [](const std::string& errorString, unsigned int errorNumber) {
179 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
180 });
181 }
182
183 if (!paramState.empty()) {
184 db.exec(
185 "update client set state = '" + paramState + "' where uuid = '" + paramClientId + "'",
186 [paramState]() {
187 VLOG(1) << "Database: Set state to " << paramState;
188 },
189 [](const std::string& errorString, unsigned int errorNumber) {
190 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
191 });
192 }
193
// The redirect is issued here, outside the update callbacks above.
// NOTE(review): presumably the updates may still be pending at this
// point - confirm the execution model of db.exec.
194 VLOG(1) << "Auth request valid, redirecting to login";
195 std::string loginUri{"/oauth2/login"};
196 addQueryParamToUri(loginUri, "client_id", paramClientId);
197 res->redirect(loginUri);
198 });
199
// GET /login: serves the frontend's index.html.
// NOTE(review): absolute path into a developer home directory.
200 router.get("/login", [] APPLICATION(req, res) {
201 res->sendFile("/home/rathalin/projects/snode.c/src/oauth2/authorization_server/vue-frontend-oauth2-auth-server/dist/index.html",
202 [req](int ret) {
203 if (ret != 0) {
204 PLOG(ERROR) << req->url;
205 }
206 });
207 });
208
// POST /login: checks email and password from the JSON body against the
// client row; on success stores a new auth code (valid for 10 minutes) and
// answers with a JSON object holding the client redirect URI.
209 router.post("/login", [&db] APPLICATION(req, res) {
210 req->getAttribute<nlohmann::json>(
211 [req, res, &db](nlohmann::json& body) {
212 db.query(
213 "select email, password_hash, password_salt, redirect_uri, state "
214 "from client "
215 "where uuid = '" +
216 req->query("client_id") + "'",
// NOTE(review): `body` is captured by reference; if this callback runs
// after the enclosing lambda has returned, the reference may dangle -
// confirm the lifetime of the request attribute.
217 [req, res, &db, &body](const MYSQL_ROW row) {
218 if (row != nullptr) {
219 const std::string dbEmail{row[0]};
220 const std::string dbPasswordHash{row[1]};
221 const std::string dbPasswordSalt{row[2]};
222 const std::string dbRedirectUri{row[3]};
223 const std::string dbState{row[4]};
// NOTE(review): no check that "email" / "password" exist in the body
// or are strings before converting them.
224 const std::string queryEmail{body["email"]};
225 const std::string queryPassword{body["password"]};
226 // Check email and password
227 if (dbEmail != queryEmail) {
228 res->status(401).send("Invalid email address");
229 } else if (dbPasswordHash != hashSha1(dbPasswordSalt + queryPassword)) {
230 res->status(401).send("Invalid password");
231 } else {
232 // Generate auth code which expires after 10 minutes
233 const unsigned int expireMinutes{10};
234 const std::string authCode{getNewUUID()};
// Chain: insert the token row, read its id via last_insert_id(), then
// store that id as the client's auth_code_id.
235 db.exec(
236 "insert into token(uuid, expire_datetime) "
237 "values('" +
238 authCode + "', '" +
239 timeToString(std::chrono::system_clock::now() + std::chrono::minutes(expireMinutes)) + "')",
240 []() {
241 },
242 [res](const std::string& errorString, unsigned int errorNumber) {
243 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
244 res->sendStatus(500);
245 })
246 .query(
247 "select last_insert_id()",
248 [req, res, &db, dbState, dbRedirectUri, authCode](const MYSQL_ROW row) {
249 if (row != nullptr) {
250 db.exec(
251 "update client "
252 "set auth_code_id = '" +
253 std::string{row[0]} +
254 "' "
255 "where uuid = '" +
256 req->query("client_id") + "'",
257 [res, dbState, dbRedirectUri, authCode]() {
258 // Redirect back to the client app
// The URI is returned inside a JSON body rather than as an HTTP
// redirect.
259 std::string clientRedirectUri{dbRedirectUri};
260 addQueryParamToUri(clientRedirectUri, "code", authCode);
261 if (!dbState.empty()) {
262 addQueryParamToUri(clientRedirectUri, "state", dbState);
263 }
264 // Set CORS header
265 res->set("Access-Control-Allow-Origin", "*");
266 const nlohmann::json responseJson = {{"redirect_uri", clientRedirectUri}};
267 const std::string responseJsonString{responseJson.dump(4)};
268 VLOG(1) << "Sending json reponse: " << responseJsonString;
269 res->send(responseJsonString);
270 },
271 [res](const std::string& errorString, unsigned int errorNumber) {
272 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
273 res->sendStatus(500);
274 });
275 }
276 },
277 [res](const std::string& errorString, unsigned int errorNumber) {
278 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
279 res->sendStatus(500);
280 });
281 }
282 }
283 },
284 [res](const std::string& errorString, unsigned int errorNumber) {
285 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
286 res->sendStatus(500);
287 });
288 },
// Invoked when the nlohmann::json request attribute is not available.
289 [res]([[maybe_unused]] const std::string& key) {
290 res->sendStatus(500);
291 });
292 });
293
// GET /token: exchanges an auth code for an access token (1 hour) and a
// refresh token (24 hours). Requires grant_type=authorization_code, code and
// a redirect_uri equal to the one stored for the client.
294 router.get("/token", [&db] APPLICATION(req, res) {
295 res->set("Access-Control-Allow-Origin", "*");
296 auto queryGrantType = req->query("grant_type");
297 VLOG(1) << "GrandType: " << queryGrantType;
298 auto queryCode = req->query("code");
299 VLOG(1) << "Code: " << queryCode;
300 auto queryRedirectUri = req->query("redirect_uri");
301 VLOG(1) << "RedirectUri: " << queryRedirectUri;
302 if (queryGrantType != "authorization_code") {
303 res->status(400).send("Invalid query parameter 'grant_type', value must be 'authorization_code'");
304 return;
305 }
306 if (queryCode.length() == 0) {
307 res->status(400).send("Missing query parameter 'code'");
308 return;
309 }
310 if (queryRedirectUri.length() == 0) {
311 res->status(400).send("Missing query parameter 'redirect_uri'");
312 return;
313 }
// Step 1: the redirect_uri must match the one stored for this client.
314 db.query(
315 "select count(*) "
316 "from client "
317 "where uuid = '" +
318 req->query("client_id") +
319 "' "
320 "and redirect_uri = '" +
321 queryRedirectUri + "'",
322 [req, res, &db](const MYSQL_ROW row) {
323 if (row != nullptr) {
324 if (std::stoi(row[0]) == 0) {
325 res->status(400).send("Query param 'redirect_uri' must be the same as in the initial request");
326 } else {
// Step 2: the code must be this client's auth code and not yet expired.
327 db.query(
328 "select count(*) "
329 "from client c "
330 "join token a "
331 "on c.auth_code_id = a.id "
332 "where c.uuid = '" +
333 req->query("client_id") +
334 "' "
335 "and a.uuid = '" +
336 req->query("code") +
337 "' "
338 "and timestampdiff(second, current_timestamp(), a.expire_datetime) > 0",
339 [req, res, &db](const MYSQL_ROW row) {
340 if (row != nullptr) {
341 if (std::stoi(row[0]) == 0) {
342 res->status(401).send("Invalid auth token");
343 return;
344 }
345 // Generate access and refresh token
346 const std::string accessToken{getNewUUID()};
347 const unsigned int accessTokenExpireSeconds{60 * 60}; // 1 hour
348 const std::string refreshToken{getNewUUID()};
349 const unsigned int refreshTokenExpireSeconds{60 * 60 * 24}; // 24 hours
// Step 3: insert the access token, link it to the client, insert the
// refresh token, link it, and only then send the JSON response.
350 db.exec(
351 "insert into token(uuid, expire_datetime) "
352 "values('" +
353 accessToken + "', '" +
354 timeToString(std::chrono::system_clock::now() +
355 std::chrono::seconds(accessTokenExpireSeconds)) +
356 "')",
357 []() {
358 },
359 [res](const std::string& errorString, unsigned int errorNumber) {
360 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
361 res->sendStatus(500);
362 })
363 .query(
364 "select last_insert_id()",
365 [req, res, &db](const MYSQL_ROW row) {
366 if (row != nullptr) {
367 db.exec(
368 "update client "
369 "set access_token_id = '" +
370 std::string{row[0]} +
371 "' "
372 "where uuid = '" +
373 req->query("client_id") + "'",
374 []() {
375 },
376 [res](const std::string& errorString, unsigned int errorNumber) {
377 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
378 res->sendStatus(500);
379 });
380 }
381 },
382 [res](const std::string& errorString, unsigned int errorNumber) {
383 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
384 res->sendStatus(500);
385 })
386 .exec(
387 "insert into token(uuid, expire_datetime) "
388 "values('" +
389 refreshToken + "', '" +
390 timeToString(std::chrono::system_clock::now() +
391 std::chrono::seconds(refreshTokenExpireSeconds)) +
392 "')",
393 []() {
394 },
395 [res](const std::string& errorString, unsigned int errorNumber) {
396 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
397 res->sendStatus(500);
398 })
399 .query(
400 "select last_insert_id()",
401 [req, res, &db, accessToken, accessTokenExpireSeconds, refreshToken](const MYSQL_ROW row) {
402 if (row != nullptr) {
403 db.exec(
404 "update client "
405 "set refresh_token_id = '" +
406 std::string{row[0]} +
407 "' "
408 "where uuid = '" +
409 req->query("client_id") + "'",
410 [res, accessToken, accessTokenExpireSeconds, refreshToken]() {
411 // Send auth token and refresh token
412 const nlohmann::json jsonResponse = {{"access_token", accessToken},
413 {"expires_in", accessTokenExpireSeconds},
414 {"refresh_token", refreshToken}};
415 const std::string jsonResponseString{jsonResponse.dump(4)};
416 res->send(jsonResponseString);
417 },
418 [res](const std::string& errorString, unsigned int errorNumber) {
419 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
420 res->sendStatus(500);
421 });
422 }
423 },
424 [res](const std::string& errorString, unsigned int errorNumber) {
425 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
426 res->sendStatus(500);
427 });
428 }
429 },
430 [res](const std::string& errorString, unsigned int errorNumber) {
431 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
432 res->sendStatus(500);
433 });
434 }
435 }
436 },
437 [res](const std::string& errorString, unsigned int errorNumber) {
438 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
439 res->sendStatus(500);
440 });
441 });
442
// POST /token/refresh: issues a new access token (1 hour) for a valid,
// unexpired refresh token. Requires grant_type=refresh_token.
443 router.post("/token/refresh", [&db] APPLICATION(req, res) {
444 res->set("Access-Control-Allow-Origin", "*");
445 auto queryClientId = req->query("client_id");
446 VLOG(1) << "ClientId: " << queryClientId;
447 auto queryGrantType = req->query("grant_type");
448 VLOG(1) << "GrandType: " << queryGrantType;
449 auto queryRefreshToken = req->query("refresh_token");
450 VLOG(1) << "RefreshToken: " << queryRefreshToken;
451 auto queryState = req->query("state");
452 VLOG(1) << "State: " << queryState;
453 if (queryGrantType.length() == 0) {
454 res->status(400).send("Missing query parameter 'grant_type'");
455 return;
456 }
457 if (queryGrantType != "refresh_token") {
458 res->status(400).send("Invalid query parameter 'grant_type', value must be 'refresh_token'");
459 return;
460 }
// NOTE(review): unlike the two checks above, this branch has no `return`,
// so after sending the 400 the handler still runs the query below and may
// send a second response.
461 if (queryRefreshToken.empty()) {
462 res->status(400).send("Missing query parameter 'refresh_token'");
463 }
464 db.query(
465 "select count(*) "
466 "from client c "
467 "join token r "
468 "on c.refresh_token_id = r.id "
469 "where c.uuid = '" +
470 req->query("client_id") +
471 "' "
472 "and r.uuid = '" +
473 req->query("refresh_token") +
474 "' "
475 "and timestampdiff(second, current_timestamp(), r.expire_datetime) > 0",
476 [req, res, &db](const MYSQL_ROW row) {
477 if (row != nullptr) {
478 if (std::stoi(row[0]) == 0) {
479 res->status(401).send("Invalid refresh token");
480 return;
481 }
482 // Generate access token
483 std::string accessToken{getNewUUID()};
484 unsigned int accessTokenExpireSeconds{60 * 60}; // 1 hour
485 db.exec(
486 "insert into token(uuid, expire_datetime) "
487 "values('" +
488 accessToken + "', '" +
489 timeToString(std::chrono::system_clock::now() + std::chrono::seconds(accessTokenExpireSeconds)) + "')",
490 []() {
491 },
492 [res](const std::string& errorString, unsigned int errorNumber) {
493 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
494 res->sendStatus(500);
495 })
496 .query(
497 "select last_insert_id()",
498 [req, res, &db, accessToken, accessTokenExpireSeconds](const MYSQL_ROW row) {
499 if (row != nullptr) {
500 db.exec(
501 "update client "
502 "set access_token_id = '" +
503 std::string{row[0]} +
504 "' "
505 "where uuid = '" +
506 req->query("client_id") + "'",
507 [res, accessToken, accessTokenExpireSeconds]() {
508 const nlohmann::json responseJson = {{"access_token", accessToken},
509 {"expires_in", accessTokenExpireSeconds}};
510 res->send(responseJson.dump(4));
511 },
512 [res](const std::string& errorString, unsigned int errorNumber) {
513 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
514 res->sendStatus(500);
515 });
516 }
517 },
518 [res](const std::string& errorString, unsigned int errorNumber) {
519 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
520 res->sendStatus(500);
521 });
522 }
523 },
524 [res](const std::string& errorString, unsigned int errorNumber) {
525 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
526 res->sendStatus(500);
527 });
528 });
529
// POST /token/validate: checks whether the access_token in the JSON body is
// the one linked to the given client_id; answers 200 or 401 with a JSON body.
// NOTE(review): a missing body field is answered with 500, not 400.
// NOTE(review): unlike /token and /token/refresh, the query has no
// expire_datetime condition, so an expired access token still validates.
// NOTE(review): no handler for a missing JSON attribute is passed here,
// unlike in POST /login.
530 router.post("/token/validate", [&db] APPLICATION(req, res) {
531 VLOG(1) << "POST /token/validate";
532 req->getAttribute<nlohmann::json>([res, &db](nlohmann::json& jsonBody) {
533 if (!jsonBody.contains("access_token")) {
534 VLOG(1) << "Missing 'access_token' in json";
535 res->status(500).send("Missing 'access_token' in json");
536 return;
537 }
538 const std::string jsonAccessToken{jsonBody["access_token"]};
539 if (!jsonBody.contains("client_id")) {
540 VLOG(1) << "Missing 'client_id' in json";
541 res->status(500).send("Missing 'client_id' in json");
542 return;
543 }
544 const std::string jsonClientId{jsonBody["client_id"]};
545 db.query(
546 "select count(*) "
547 "from client c "
548 "join token a "
549 "on c.access_token_id = a.id "
550 "where c.uuid = '" +
551 jsonClientId +
552 "' "
553 "and a.uuid = '" +
554 jsonAccessToken + "'",
555 [res, jsonClientId, jsonAccessToken](const MYSQL_ROW row) {
556 if (row != nullptr) {
557 if (std::stoi(row[0]) == 0) {
558 const nlohmann::json errorJson = {{"error", "Invalid access token"}};
559 VLOG(1) << "Sending 401: Invalid access token '" << jsonAccessToken << "'";
560 res->status(401).send(errorJson.dump(4));
561 } else {
// NOTE(review): the log line below opens a quote around the token but
// appends an empty string instead of the closing quote.
562 VLOG(1) << "Sending 200: Valid access token '" << jsonAccessToken << "";
563 const nlohmann::json successJson = {{"success", "Valid access token"}};
564 res->status(200).send(successJson.dump(4));
565 }
566 }
567 },
568 [res](const std::string& errorString, unsigned int errorNumber) {
569 VLOG(1) << "Database error: " << errorString << " : " << errorNumber;
570 res->sendStatus(500);
571 });
572 });
573 });
574
// Mount all routes under /oauth2. The next line is the tail of a second
// app.use(...) call whose opening (source line 576) is missing; presumably
// it serves the frontend's dist directory as static files - confirm.
575 app.use("/oauth2", router);
577 "/home/rathalin/projects/snode.c/src/oauth2/authorization_server/vue-frontend-oauth2-auth-server/dist/"));
578
// Listen on port 8082 and log the reported socket state. The case labels of
// the switch (source lines 581, 584, 587, 590) are missing from the listing.
579 app.listen(8082, [](const express::legacy::in::WebApp::SocketAddress& socketAddress, const core::socket::State& state) {
580 switch (state) {
582 VLOG(1) << "OAuth2AuthorizationServer: listening on '" << socketAddress.toString() << "'";
583 break;
585 VLOG(1) << "OAuth2AuthorizationServer: disabled";
586 break;
588 VLOG(1) << "OAuth2AuthorizationServer: error occurred";
589 break;
591 VLOG(1) << "OAuth2AuthorizationServer: fatal error occurred";
592 break;
593 }
594 });
595
596 return express::WebApp::start();
597}
std::string timeToString(std::chrono::time_point< std::chrono::system_clock > time)
void addQueryParamToUri(std::string &uri, const std::string &queryParamName, const std::string &queryParamValue)
std::string getNewUUID()
std::string hashSha1(const std::string &str)
#define APPLICATION(req, res)
Definition Router.h:68
#define MIDDLEWARE(req, res, next)
Definition Router.h:63
static constexpr int DISABLED
Definition State.h:56
static constexpr int ERROR
Definition State.h:57
static constexpr int FATAL
Definition State.h:58
static constexpr int OK
Definition State.h:55
Route & use(const std::function< void(const std::shared_ptr< Request > &, const std::shared_ptr< Response > &)> &lambda) const
Definition Route.cpp:131
Route & get(const std::function< void(const std::shared_ptr< Request > &, const std::shared_ptr< Response > &)> &lambda) const
Definition Route.cpp:133
Route & post(const std::function< void(const std::shared_ptr< Request > &, const std::shared_ptr< Response > &)> &lambda) const
Definition Route.cpp:135
Route & use(const Router &router) const
Definition Router.cpp:94
typename Server::SocketAddress SocketAddress
Definition WebAppT.h:70
static void init(int argc, char *argv[])
Definition WebApp.cpp:56
static int start(const utils::Timeval &timeOut={LONG_MAX, 0})
Definition WebApp.cpp:60
Router router(database::mariadb::MariaDBClient &db)
Definition testregex.cpp:68

References database::mariadb::MariaDBConnectionDetails::database, database::mariadb::MariaDBClientASyncAPI::exec(), database::mariadb::MariaDBConnectionDetails::flags, database::mariadb::MariaDBConnectionDetails::hostname, database::mariadb::MariaDBClient::MariaDBClient(), database::mariadb::MariaDBConnectionDetails::password, database::mariadb::MariaDBConnectionDetails::port, database::mariadb::MariaDBConnectionDetails::socket, and database::mariadb::MariaDBConnectionDetails::username.

Here is the call graph for this function:

◆ timeToString()

std::string timeToString ( std::chrono::time_point< std::chrono::system_clock >  time)

Definition at line 71 of file AuthorizationServer.cpp.

/**
 * Formats a system_clock time point as "YYYY-MM-DD HH:MM:SS" in local time.
 *
 * The result is concatenated into SQL statements as a datetime literal, so
 * the layout is spelled out field by field instead of using the
 * locale-dependent "%X" conversion. The calendar conversion uses POSIX
 * localtime_r, which writes into a caller-provided buffer rather than the
 * shared static storage returned by std::localtime.
 *
 * @param time Point in time to format.
 * @return The formatted local time, or an empty string when the time cannot
 *         be converted to calendar form.
 */
std::string timeToString(std::chrono::time_point<std::chrono::system_clock> time) {
    const std::time_t timestamp = std::chrono::system_clock::to_time_t(time);

    std::tm localTime{};
    if (::localtime_r(&timestamp, &localTime) == nullptr) {
        // Conversion failed; never hand a null tm pointer to std::put_time.
        return {};
    }

    std::stringstream ss;
    ss << std::put_time(&localTime, "%Y-%m-%d %H:%M:%S");
    return ss.str();
}