diff options
| -rw-r--r-- | .zed/tasks.json | 34 | ||||
| -rw-r--r-- | Cargo.toml | 5 | ||||
| -rw-r--r-- | entity/src/users.rs | 4 | ||||
| -rw-r--r-- | migration/src/m0_init_tables.rs | 8 | ||||
| -rw-r--r-- | src/auth.rs | 49 | ||||
| -rw-r--r-- | src/error.rs | 89 | ||||
| -rw-r--r-- | src/extract/auth.rs | 36 | ||||
| -rw-r--r-- | src/extract/json.rs | 19 | ||||
| -rw-r--r-- | src/extract/mod.rs | 5 | ||||
| -rw-r--r-- | src/main.rs | 30 | ||||
| -rw-r--r-- | src/response.rs | 56 | ||||
| -rw-r--r-- | src/routers/account.rs | 184 | ||||
| -rw-r--r-- | src/routers/mod.rs | 12 | ||||
| -rw-r--r-- | src/state.rs | 7 |
14 files changed, 490 insertions, 48 deletions
diff --git a/.zed/tasks.json b/.zed/tasks.json index e725a43..cf60465 100644 --- a/.zed/tasks.json +++ b/.zed/tasks.json | |||
| @@ -32,12 +32,40 @@ | |||
| 32 | "shell": "system" | 32 | "shell": "system" |
| 33 | }, | 33 | }, |
| 34 | { | 34 | { |
| 35 | "label": "Run debug database", | ||
| 36 | "command": "docker", | ||
| 37 | "args": [ | ||
| 38 | "run", | ||
| 39 | "--name", | ||
| 40 | "debug-postgres", | ||
| 41 | "-e", | ||
| 42 | "POSTGRES_PASSWORD=itmo_queue", | ||
| 43 | "-e", | ||
| 44 | "POSTGRES_USER=itmo_queue", | ||
| 45 | "-e", | ||
| 46 | "POSTGRES_DB=itmo_queue", | ||
| 47 | "-p", | ||
| 48 | "5432:5432", | ||
| 49 | "--rm", | ||
| 50 | "postgres" | ||
| 51 | ], | ||
| 52 | |||
| 53 | "use_new_terminal": false, | ||
| 54 | "allow_concurrent_runs": false, | ||
| 55 | "reveal": "no_focus", | ||
| 56 | "reveal_target": "dock", | ||
| 57 | "hide": "never", | ||
| 58 | "shell": "system" | ||
| 59 | }, | ||
| 60 | { | ||
| 35 | "label": "Run release server", | 61 | "label": "Run release server", |
| 36 | "command": "cargo", | 62 | "command": "cargo", |
| 37 | "args": ["run", "-r"], | 63 | "args": ["run", "-r"], |
| 38 | 64 | ||
| 39 | "env": { | 65 | "env": { |
| 40 | "SERVER_BIND": "0.0.0.0:8080" | 66 | "SECRET": "secret", |
| 67 | "SERVER_BIND": "0.0.0.0:8080", | ||
| 68 | "DATABASE_URL": "postgres://itmo_queue:itmo_queue@localhost/itmo_queue" | ||
| 41 | }, | 69 | }, |
| 42 | 70 | ||
| 43 | "use_new_terminal": false, | 71 | "use_new_terminal": false, |
| @@ -53,7 +81,9 @@ | |||
| 53 | "args": ["run"], | 81 | "args": ["run"], |
| 54 | 82 | ||
| 55 | "env": { | 83 | "env": { |
| 56 | "SERVER_BIND": "0.0.0.0:8080" | 84 | "SECRET": "secret", |
| 85 | "SERVER_BIND": "0.0.0.0:8080", | ||
| 86 | "DATABASE_URL": "postgres://itmo_queue:itmo_queue@localhost/itmo_queue" | ||
| 57 | }, | 87 | }, |
| 58 | 88 | ||
| 59 | "use_new_terminal": false, | 89 | "use_new_terminal": false, |
| @@ -8,8 +8,13 @@ publish = false | |||
| 8 | members = [".", "entity", "migration"] | 8 | members = [".", "entity", "migration"] |
| 9 | 9 | ||
| 10 | [dependencies] | 10 | [dependencies] |
| 11 | argon2 = "0.5.3" | ||
| 11 | axum = "0.8.4" | 12 | axum = "0.8.4" |
| 13 | axum-extra = { version = "0.10.1", features = ["typed-header"] } | ||
| 14 | chrono = { version = "0.4.41", features = ["serde"] } | ||
| 12 | entity = { version = "0.1.0", path = "entity" } | 15 | entity = { version = "0.1.0", path = "entity" } |
| 16 | headers = "0.4.1" | ||
| 17 | jsonwebtoken = "9.3.1" | ||
| 13 | migration = { version = "0.1.0", path = "migration" } | 18 | migration = { version = "0.1.0", path = "migration" } |
| 14 | sea-orm = { version = "1.1.14", features = ["sqlx-postgres", "runtime-tokio-rustls"] } | 19 | sea-orm = { version = "1.1.14", features = ["sqlx-postgres", "runtime-tokio-rustls"] } |
| 15 | serde = { version = "1.0.219", features = ["derive"] } | 20 | serde = { version = "1.0.219", features = ["derive"] } |
diff --git a/entity/src/users.rs b/entity/src/users.rs index b61d51b..6628c9e 100644 --- a/entity/src/users.rs +++ b/entity/src/users.rs | |||
| @@ -8,8 +8,8 @@ pub struct Model { | |||
| 8 | #[sea_orm(primary_key)] | 8 | #[sea_orm(primary_key)] |
| 9 | pub id: i64, | 9 | pub id: i64, |
| 10 | #[sea_orm(unique)] | 10 | #[sea_orm(unique)] |
| 11 | pub login: String, | 11 | pub username: String, |
| 12 | pub password: String, | 12 | pub password_hash: String, |
| 13 | pub password_issue_date: DateTime, | 13 | pub password_issue_date: DateTime, |
| 14 | pub first_name: String, | 14 | pub first_name: String, |
| 15 | pub last_name: String, | 15 | pub last_name: String, |
diff --git a/migration/src/m0_init_tables.rs b/migration/src/m0_init_tables.rs index 576f45b..536b563 100644 --- a/migration/src/m0_init_tables.rs +++ b/migration/src/m0_init_tables.rs | |||
| @@ -6,8 +6,8 @@ use sea_orm_migration::{prelude::*, schema::*}; | |||
| 6 | enum Users { | 6 | enum Users { |
| 7 | Table, | 7 | Table, |
| 8 | Id, | 8 | Id, |
| 9 | Login, | 9 | Username, |
| 10 | Password, | 10 | PasswordHash, |
| 11 | PasswordIssueDate, | 11 | PasswordIssueDate, |
| 12 | FirstName, | 12 | FirstName, |
| 13 | LastName, | 13 | LastName, |
| @@ -64,8 +64,8 @@ impl MigrationTrait for Migration { | |||
| 64 | .table(Users::Table) | 64 | .table(Users::Table) |
| 65 | .if_not_exists() | 65 | .if_not_exists() |
| 66 | .col(pk_auto(Users::Id).big_integer()) | 66 | .col(pk_auto(Users::Id).big_integer()) |
| 67 | .col(string(Users::Login).unique_key()) | 67 | .col(string(Users::Username).unique_key()) |
| 68 | .col(string(Users::Password)) | 68 | .col(string(Users::PasswordHash)) |
| 69 | .col(timestamp(Users::PasswordIssueDate)) | 69 | .col(timestamp(Users::PasswordIssueDate)) |
| 70 | .col(string(Users::FirstName)) | 70 | .col(string(Users::FirstName)) |
| 71 | .col(string(Users::LastName)) | 71 | .col(string(Users::LastName)) |
diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..418f64e --- /dev/null +++ b/src/auth.rs | |||
| @@ -0,0 +1,49 @@ | |||
| 1 | use argon2::password_hash::rand_core::OsRng; | ||
| 2 | use argon2::password_hash::{PasswordHasher, SaltString}; | ||
| 3 | use argon2::{Argon2, PasswordHash, PasswordVerifier}; | ||
| 4 | use jsonwebtoken::{self as jwt, DecodingKey, EncodingKey, Header, Validation}; | ||
| 5 | use serde::{Deserialize, Serialize}; | ||
| 6 | |||
| 7 | #[derive(Serialize, Deserialize)] | ||
| 8 | pub struct JwtClaims { | ||
| 9 | pub sub: i64, | ||
| 10 | pub iat: i64, | ||
| 11 | pub exp: i64, | ||
| 12 | } | ||
| 13 | |||
| 14 | pub fn create_password(password: &str) -> argon2::password_hash::Result<String> { | ||
| 15 | Ok(Argon2::default() | ||
| 16 | .hash_password(password.as_bytes(), &SaltString::generate(&mut OsRng))? | ||
| 17 | .to_string()) | ||
| 18 | } | ||
| 19 | |||
| 20 | pub fn validate_password( | ||
| 21 | password: &str, | ||
| 22 | password_hash: &str, | ||
| 23 | ) -> argon2::password_hash::Result<bool> { | ||
| 24 | Ok(Argon2::default() | ||
| 25 | .verify_password(password.as_bytes(), &PasswordHash::new(password_hash)?) | ||
| 26 | .is_ok()) | ||
| 27 | } | ||
| 28 | |||
| 29 | pub fn create_jwt(claims: &JwtClaims, secret: &str) -> jwt::errors::Result<String> { | ||
| 30 | jwt::encode( | ||
| 31 | &Header::default(), | ||
| 32 | claims, | ||
| 33 | &EncodingKey::from_secret(secret.as_bytes()), | ||
| 34 | ) | ||
| 35 | } | ||
| 36 | |||
| 37 | pub fn validate_jwt(token: &str, secret: &str) -> jwt::errors::Result<JwtClaims> { | ||
| 38 | let mut validation = Validation::default(); | ||
| 39 | validation.set_required_spec_claims(&["exp"]); | ||
| 40 | validation.validate_exp = true; | ||
| 41 | validation.leeway = 0; | ||
| 42 | |||
| 43 | Ok(jwt::decode( | ||
| 44 | token, | ||
| 45 | &DecodingKey::from_secret(secret.as_bytes()), | ||
| 46 | &validation, | ||
| 47 | )? | ||
| 48 | .claims) | ||
| 49 | } | ||
diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..2d0f911 --- /dev/null +++ b/src/error.rs | |||
| @@ -0,0 +1,89 @@ | |||
| 1 | use axum::extract::rejection::JsonRejection; | ||
| 2 | use axum::response::{IntoResponse, Response}; | ||
| 3 | use axum_extra::typed_header::TypedHeaderRejection; | ||
| 4 | |||
| 5 | use crate::response::{ErrorResponse, FailResponse, SuccessResponse}; | ||
| 6 | |||
| 7 | pub type ApiResult<T> = Result<SuccessResponse<T>, ApiError>; | ||
| 8 | |||
| 9 | pub enum ApiError { | ||
| 10 | // 400 | ||
| 11 | BadJsonBody(String), | ||
| 12 | BadAuthTokenHeader(String), | ||
| 13 | UserAlreadyExists { username: String }, | ||
| 14 | InvalidPassword, | ||
| 15 | NotAuthorized, | ||
| 16 | // 500 | ||
| 17 | Database(String), | ||
| 18 | PasswordHash(String), | ||
| 19 | InternalJwt(String), | ||
| 20 | } | ||
| 21 | |||
| 22 | impl From<JsonRejection> for ApiError { | ||
| 23 | fn from(value: JsonRejection) -> Self { | ||
| 24 | Self::BadJsonBody(value.body_text()) | ||
| 25 | } | ||
| 26 | } | ||
| 27 | |||
| 28 | impl From<TypedHeaderRejection> for ApiError { | ||
| 29 | fn from(value: TypedHeaderRejection) -> Self { | ||
| 30 | Self::BadAuthTokenHeader(value.to_string()) | ||
| 31 | } | ||
| 32 | } | ||
| 33 | |||
| 34 | impl From<sea_orm::DbErr> for ApiError { | ||
| 35 | fn from(value: sea_orm::DbErr) -> Self { | ||
| 36 | Self::Database(value.to_string()) | ||
| 37 | } | ||
| 38 | } | ||
| 39 | |||
| 40 | impl From<argon2::password_hash::Error> for ApiError { | ||
| 41 | fn from(value: argon2::password_hash::Error) -> Self { | ||
| 42 | Self::PasswordHash(value.to_string()) | ||
| 43 | } | ||
| 44 | } | ||
| 45 | |||
| 46 | impl ToString for ApiError { | ||
| 47 | fn to_string(&self) -> String { | ||
| 48 | match self { | ||
| 49 | // 400 | ||
| 50 | ApiError::BadJsonBody(..) => "BadJsonBody", | ||
| 51 | ApiError::BadAuthTokenHeader(..) => "BadAuthTokenHeader", | ||
| 52 | ApiError::UserAlreadyExists { .. } => "UserAlreadyExists", | ||
| 53 | ApiError::InvalidPassword => "InvalidPassword", | ||
| 54 | ApiError::NotAuthorized => "NotAuthorized", | ||
| 55 | // 500 | ||
| 56 | ApiError::Database(..) => "Database", | ||
| 57 | ApiError::PasswordHash(..) => "PasswordHash", | ||
| 58 | ApiError::InternalJwt(..) => "InternalJwt", | ||
| 59 | } | ||
| 60 | .to_string() | ||
| 61 | } | ||
| 62 | } | ||
| 63 | |||
| 64 | impl IntoResponse for ApiError { | ||
| 65 | fn into_response(self) -> Response { | ||
| 66 | let kind = self.to_string(); | ||
| 67 | match self { | ||
| 68 | // 400 | ||
| 69 | ApiError::BadJsonBody(msg) => FailResponse(kind, msg).into_response(), | ||
| 70 | ApiError::BadAuthTokenHeader(msg) => FailResponse(kind, msg).into_response(), | ||
| 71 | ApiError::UserAlreadyExists { username } => FailResponse( | ||
| 72 | kind, | ||
| 73 | format!("user with username `{}` already exists", username), | ||
| 74 | ) | ||
| 75 | .into_response(), | ||
| 76 | ApiError::InvalidPassword => { | ||
| 77 | FailResponse(kind, "password is invalid".to_string()).into_response() | ||
| 78 | } | ||
| 79 | ApiError::NotAuthorized => { | ||
| 80 | FailResponse(kind, "user is not authorized".to_string()).into_response() | ||
| 81 | } | ||
| 82 | |||
| 83 | // 500 | ||
| 84 | ApiError::Database(msg) => ErrorResponse(kind, msg).into_response(), | ||
| 85 | ApiError::PasswordHash(msg) => ErrorResponse(kind, msg).into_response(), | ||
| 86 | ApiError::InternalJwt(msg) => ErrorResponse(kind, msg).into_response(), | ||
| 87 | } | ||
| 88 | } | ||
| 89 | } | ||
diff --git a/src/extract/auth.rs b/src/extract/auth.rs new file mode 100644 index 0000000..cc357fd --- /dev/null +++ b/src/extract/auth.rs | |||
| @@ -0,0 +1,36 @@ | |||
| 1 | use axum::extract::FromRequestParts; | ||
| 2 | use axum::http::request::Parts; | ||
| 3 | use axum_extra::TypedHeader; | ||
| 4 | use entity::users; | ||
| 5 | use headers::authorization::{Authorization, Bearer}; | ||
| 6 | use sea_orm::EntityTrait; | ||
| 7 | |||
| 8 | use crate::{ApiError, AppState, validate_jwt}; | ||
| 9 | |||
| 10 | pub struct Auth(pub users::Model); | ||
| 11 | |||
| 12 | impl FromRequestParts<AppState> for Auth { | ||
| 13 | type Rejection = ApiError; | ||
| 14 | |||
| 15 | async fn from_request_parts( | ||
| 16 | parts: &mut Parts, | ||
| 17 | state: &AppState, | ||
| 18 | ) -> Result<Self, Self::Rejection> { | ||
| 19 | let token_header = | ||
| 20 | TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await?; | ||
| 21 | |||
| 22 | let jwt_claims = validate_jwt(token_header.token(), &state.secret) | ||
| 23 | .map_err(|_| ApiError::NotAuthorized)?; | ||
| 24 | |||
| 25 | let user = users::Entity::find_by_id(jwt_claims.sub) | ||
| 26 | .one(&state.db) | ||
| 27 | .await? | ||
| 28 | .ok_or(ApiError::NotAuthorized)?; | ||
| 29 | |||
| 30 | if jwt_claims.iat < user.password_issue_date.and_utc().timestamp() { | ||
| 31 | return Err(ApiError::NotAuthorized); | ||
| 32 | } | ||
| 33 | |||
| 34 | Ok(Auth(user)) | ||
| 35 | } | ||
| 36 | } | ||
diff --git a/src/extract/json.rs b/src/extract/json.rs new file mode 100644 index 0000000..cfde15b --- /dev/null +++ b/src/extract/json.rs | |||
| @@ -0,0 +1,19 @@ | |||
| 1 | use axum::extract::rejection::JsonRejection; | ||
| 2 | use axum::extract::{FromRequest, Request}; | ||
| 3 | |||
| 4 | use crate::error::ApiError; | ||
| 5 | |||
| 6 | pub struct ApiJson<T>(pub T); | ||
| 7 | |||
| 8 | impl<S, T> FromRequest<S> for ApiJson<T> | ||
| 9 | where | ||
| 10 | axum::Json<T>: FromRequest<S, Rejection = JsonRejection>, | ||
| 11 | S: Send + Sync, | ||
| 12 | { | ||
| 13 | type Rejection = ApiError; | ||
| 14 | |||
| 15 | #[inline] | ||
| 16 | async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> { | ||
| 17 | Ok(Self(axum::Json::<T>::from_request(req, state).await?.0)) | ||
| 18 | } | ||
| 19 | } | ||
diff --git a/src/extract/mod.rs b/src/extract/mod.rs new file mode 100644 index 0000000..b46a610 --- /dev/null +++ b/src/extract/mod.rs | |||
| @@ -0,0 +1,5 @@ | |||
| 1 | mod auth; | ||
| 2 | mod json; | ||
| 3 | |||
| 4 | pub use auth::Auth; | ||
| 5 | pub use json::ApiJson; | ||
diff --git a/src/main.rs b/src/main.rs index c53664a..3b3f868 100644 --- a/src/main.rs +++ b/src/main.rs | |||
| @@ -1,9 +1,30 @@ | |||
| 1 | mod auth; | ||
| 1 | mod error; | 2 | mod error; |
| 3 | mod extract; | ||
| 2 | mod response; | 4 | mod response; |
| 3 | mod routers; | 5 | mod routers; |
| 6 | mod state; | ||
| 4 | 7 | ||
| 8 | pub use auth::{JwtClaims, create_jwt, create_password, validate_jwt, validate_password}; | ||
| 9 | pub use error::{ApiError, ApiResult}; | ||
| 10 | pub use response::{ErrorResponse, FailResponse, SuccessResponse}; | ||
| 11 | pub use state::AppState; | ||
| 12 | |||
| 13 | use sea_orm::Database; | ||
| 5 | use tokio::net::TcpListener; | 14 | use tokio::net::TcpListener; |
| 6 | use tracing::{Level, info}; | 15 | use tower_http::trace::TraceLayer; |
| 16 | use tracing::info; | ||
| 17 | use tracing_subscriber::EnvFilter; | ||
| 18 | |||
| 19 | async fn app_state() -> AppState { | ||
| 20 | let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); | ||
| 21 | let secret = std::env::var("SECRET").expect("SECRET must be set"); | ||
| 22 | |||
| 23 | AppState { | ||
| 24 | db: Database::connect(db_url).await.unwrap(), | ||
| 25 | secret: secret, | ||
| 26 | } | ||
| 27 | } | ||
| 7 | 28 | ||
| 8 | async fn listener() -> TcpListener { | 29 | async fn listener() -> TcpListener { |
| 9 | let addr = std::env::var("SERVER_BIND").expect("SERVER_BIND must be set"); | 30 | let addr = std::env::var("SERVER_BIND").expect("SERVER_BIND must be set"); |
| @@ -13,10 +34,13 @@ async fn listener() -> TcpListener { | |||
| 13 | #[tokio::main] | 34 | #[tokio::main] |
| 14 | async fn main() { | 35 | async fn main() { |
| 15 | tracing_subscriber::fmt() | 36 | tracing_subscriber::fmt() |
| 16 | .with_max_level(Level::DEBUG) | 37 | .with_env_filter(EnvFilter::new("info,sqlx=warn,tower_http=debug")) |
| 17 | .init(); | 38 | .init(); |
| 18 | 39 | ||
| 19 | let router = routers::router(); | 40 | let state = app_state().await; |
| 41 | let router = routers::router() | ||
| 42 | .layer(TraceLayer::new_for_http()) | ||
| 43 | .with_state(state); | ||
| 20 | let listener = listener().await; | 44 | let listener = listener().await; |
| 21 | 45 | ||
| 22 | info!( | 46 | info!( |
diff --git a/src/response.rs b/src/response.rs index 8d505a5..25c3008 100644 --- a/src/response.rs +++ b/src/response.rs | |||
| @@ -1,32 +1,52 @@ | |||
| 1 | use axum::http::StatusCode; | ||
| 1 | use axum::response::{IntoResponse, Response}; | 2 | use axum::response::{IntoResponse, Response}; |
| 2 | use serde::Serialize; | 3 | use serde::Serialize; |
| 3 | use serde_json::json; | 4 | use serde_json::json; |
| 4 | 5 | ||
| 5 | pub enum ApiResponse<T> { | 6 | pub struct SuccessResponse<T>(pub T); |
| 6 | Success(T), | 7 | pub struct FailResponse(pub String, pub String); |
| 7 | Fail(T), | 8 | pub struct ErrorResponse(pub String, pub String); |
| 8 | Error(String), | ||
| 9 | } | ||
| 10 | 9 | ||
| 11 | impl<T> IntoResponse for ApiResponse<T> | 10 | impl<T> IntoResponse for SuccessResponse<T> |
| 12 | where | 11 | where |
| 13 | T: Serialize, | 12 | T: Serialize, |
| 14 | { | 13 | { |
| 15 | fn into_response(self) -> Response { | 14 | fn into_response(self) -> Response { |
| 16 | axum::Json(match self { | 15 | ( |
| 17 | ApiResponse::Success(data) => json!({ | 16 | StatusCode::OK, |
| 17 | axum::Json(json!({ | ||
| 18 | "status": "success", | 18 | "status": "success", |
| 19 | "data": data | 19 | "data": self.0 |
| 20 | }), | 20 | })), |
| 21 | ApiResponse::Fail(data) => json!({ | 21 | ) |
| 22 | .into_response() | ||
| 23 | } | ||
| 24 | } | ||
| 25 | |||
| 26 | impl IntoResponse for FailResponse { | ||
| 27 | fn into_response(self) -> Response { | ||
| 28 | ( | ||
| 29 | StatusCode::BAD_REQUEST, | ||
| 30 | axum::Json(json!({ | ||
| 22 | "status": "fail", | 31 | "status": "fail", |
| 23 | "data": data | 32 | "kind": self.0, |
| 24 | }), | 33 | "message": self.1 |
| 25 | ApiResponse::Error(message) => json!({ | 34 | })), |
| 35 | ) | ||
| 36 | .into_response() | ||
| 37 | } | ||
| 38 | } | ||
| 39 | |||
| 40 | impl IntoResponse for ErrorResponse { | ||
| 41 | fn into_response(self) -> Response { | ||
| 42 | ( | ||
| 43 | StatusCode::INTERNAL_SERVER_ERROR, | ||
| 44 | axum::Json(json!({ | ||
| 26 | "status": "error", | 45 | "status": "error", |
| 27 | "message": message | 46 | "kind": self.0, |
| 28 | }), | 47 | "message": self.1 |
| 29 | }) | 48 | })), |
| 30 | .into_response() | 49 | ) |
| 50 | .into_response() | ||
| 31 | } | 51 | } |
| 32 | } | 52 | } |
diff --git a/src/routers/account.rs b/src/routers/account.rs index 8192133..98ee61d 100644 --- a/src/routers/account.rs +++ b/src/routers/account.rs | |||
| @@ -1,24 +1,188 @@ | |||
| 1 | use axum::Router; | 1 | use axum::Router; |
| 2 | use axum::response::IntoResponse; | 2 | use axum::extract::State; |
| 3 | use axum::routing::{get, post}; | 3 | use axum::routing::{delete, get, post, put}; |
| 4 | use chrono::{DateTime, Duration, Utc}; | ||
| 5 | use entity::users::{self}; | ||
| 6 | use sea_orm::ActiveValue::Set; | ||
| 7 | use sea_orm::{ | ||
| 8 | ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, ModelTrait, QueryFilter, | ||
| 9 | }; | ||
| 10 | use serde::{Deserialize, Serialize}; | ||
| 4 | 11 | ||
| 5 | use crate::response::ApiResponse; | 12 | use crate::extract::{ApiJson, Auth}; |
| 13 | use crate::{ | ||
| 14 | ApiError, ApiResult, AppState, JwtClaims, SuccessResponse, create_jwt, create_password, | ||
| 15 | validate_password, | ||
| 16 | }; | ||
| 6 | 17 | ||
| 7 | async fn me() -> impl IntoResponse { | 18 | #[derive(Serialize)] |
| 8 | ApiResponse::Success("Me") | 19 | struct Account { |
| 20 | id: i64, | ||
| 21 | username: String, | ||
| 22 | first_name: String, | ||
| 23 | last_name: String, | ||
| 9 | } | 24 | } |
| 10 | 25 | ||
| 11 | async fn register() -> impl IntoResponse { | 26 | #[derive(Serialize)] |
| 12 | ApiResponse::Success("Register") | 27 | struct Token { |
| 28 | token: String, | ||
| 29 | expired_at: DateTime<Utc>, | ||
| 13 | } | 30 | } |
| 14 | 31 | ||
| 15 | async fn login() -> impl IntoResponse { | 32 | async fn me(Auth(user): Auth) -> ApiResult<Account> { |
| 16 | ApiResponse::Success("Login") | 33 | return Ok(SuccessResponse(Account { |
| 34 | id: user.id, | ||
| 35 | username: user.username, | ||
| 36 | first_name: user.first_name, | ||
| 37 | last_name: user.last_name, | ||
| 38 | })); | ||
| 17 | } | 39 | } |
| 18 | 40 | ||
| 19 | pub(crate) fn router() -> Router { | 41 | #[derive(Deserialize)] |
| 42 | struct RegisterRequest { | ||
| 43 | username: String, | ||
| 44 | password: String, | ||
| 45 | first_name: String, | ||
| 46 | last_name: String, | ||
| 47 | } | ||
| 48 | |||
| 49 | async fn register( | ||
| 50 | State(state): State<AppState>, | ||
| 51 | ApiJson(req): ApiJson<RegisterRequest>, | ||
| 52 | ) -> ApiResult<Account> { | ||
| 53 | let user_exists = users::Entity::find() | ||
| 54 | .filter(users::Column::Username.eq(&req.username)) | ||
| 55 | .one(&state.db) | ||
| 56 | .await? | ||
| 57 | .is_some(); | ||
| 58 | |||
| 59 | if user_exists { | ||
| 60 | return Err(ApiError::UserAlreadyExists { | ||
| 61 | username: req.username, | ||
| 62 | }); | ||
| 63 | } | ||
| 64 | |||
| 65 | let user = users::ActiveModel { | ||
| 66 | username: Set(req.username), | ||
| 67 | password_hash: Set(create_password(&req.password)?), | ||
| 68 | password_issue_date: Set(Utc::now().naive_utc()), | ||
| 69 | first_name: Set(req.first_name), | ||
| 70 | last_name: Set(req.last_name), | ||
| 71 | ..Default::default() | ||
| 72 | } | ||
| 73 | .insert(&state.db) | ||
| 74 | .await?; | ||
| 75 | |||
| 76 | Ok(SuccessResponse(Account { | ||
| 77 | id: user.id, | ||
| 78 | username: user.username, | ||
| 79 | first_name: user.first_name, | ||
| 80 | last_name: user.last_name, | ||
| 81 | })) | ||
| 82 | } | ||
| 83 | |||
| 84 | #[derive(Deserialize, Default)] | ||
| 85 | enum TokenLifetime { | ||
| 86 | Day = 1, | ||
| 87 | #[default] | ||
| 88 | Week = 7, | ||
| 89 | Month = 31, | ||
| 90 | } | ||
| 91 | |||
| 92 | #[derive(Deserialize)] | ||
| 93 | struct LoginRequest { | ||
| 94 | username: String, | ||
| 95 | password: String, | ||
| 96 | #[serde(default)] | ||
| 97 | token_lifetime: TokenLifetime, | ||
| 98 | } | ||
| 99 | |||
| 100 | async fn login( | ||
| 101 | State(state): State<AppState>, | ||
| 102 | ApiJson(req): ApiJson<LoginRequest>, | ||
| 103 | ) -> ApiResult<Token> { | ||
| 104 | let user = users::Entity::find() | ||
| 105 | .filter(users::Column::Username.eq(&req.username)) | ||
| 106 | .one(&state.db) | ||
| 107 | .await? | ||
| 108 | .ok_or(ApiError::InvalidPassword)?; | ||
| 109 | |||
| 110 | if !validate_password(&req.password, &user.password_hash)? { | ||
| 111 | return Err(ApiError::InvalidPassword); | ||
| 112 | } | ||
| 113 | |||
| 114 | let expired_at = Utc::now() + Duration::days(req.token_lifetime as i64); | ||
| 115 | |||
| 116 | let token = create_jwt( | ||
| 117 | &JwtClaims { | ||
| 118 | sub: user.id, | ||
| 119 | iat: user.password_issue_date.and_utc().timestamp(), | ||
| 120 | exp: expired_at.timestamp(), | ||
| 121 | }, | ||
| 122 | &state.secret, | ||
| 123 | ) | ||
| 124 | .map_err(|e| ApiError::InternalJwt(e.to_string()))?; | ||
| 125 | |||
| 126 | Ok(SuccessResponse(Token { token, expired_at })) | ||
| 127 | } | ||
| 128 | |||
| 129 | #[derive(Deserialize)] | ||
| 130 | struct ChangePasswordRequest { | ||
| 131 | old_password: String, | ||
| 132 | new_password: String, | ||
| 133 | } | ||
| 134 | |||
| 135 | async fn change_password( | ||
| 136 | State(state): State<AppState>, | ||
| 137 | Auth(user): Auth, | ||
| 138 | ApiJson(req): ApiJson<ChangePasswordRequest>, | ||
| 139 | ) -> ApiResult<Account> { | ||
| 140 | if !validate_password(&req.old_password, &user.password_hash)? { | ||
| 141 | return Err(ApiError::InvalidPassword); | ||
| 142 | } | ||
| 143 | |||
| 144 | let mut active_user = user.into_active_model(); | ||
| 145 | active_user.password_hash = Set(create_password(&req.new_password)?); | ||
| 146 | active_user.password_issue_date = Set(Utc::now().naive_utc()); | ||
| 147 | |||
| 148 | let user = active_user.update(&state.db).await?; | ||
| 149 | Ok(SuccessResponse(Account { | ||
| 150 | id: user.id, | ||
| 151 | username: user.username, | ||
| 152 | first_name: user.first_name, | ||
| 153 | last_name: user.last_name, | ||
| 154 | })) | ||
| 155 | } | ||
| 156 | |||
| 157 | #[derive(Deserialize)] | ||
| 158 | struct DeleteUserRequest { | ||
| 159 | password: String, | ||
| 160 | } | ||
| 161 | |||
| 162 | async fn delete_account( | ||
| 163 | State(state): State<AppState>, | ||
| 164 | Auth(user): Auth, | ||
| 165 | ApiJson(req): ApiJson<DeleteUserRequest>, | ||
| 166 | ) -> ApiResult<Account> { | ||
| 167 | if !validate_password(&req.password, &user.password_hash)? { | ||
| 168 | return Err(ApiError::InvalidPassword); | ||
| 169 | } | ||
| 170 | |||
| 171 | user.clone().delete(&state.db).await?; | ||
| 172 | |||
| 173 | Ok(SuccessResponse(Account { | ||
| 174 | id: user.id, | ||
| 175 | username: user.username, | ||
| 176 | first_name: user.first_name, | ||
| 177 | last_name: user.last_name, | ||
| 178 | })) | ||
| 179 | } | ||
| 180 | |||
| 181 | pub(crate) fn router() -> Router<AppState> { | ||
| 20 | Router::new() | 182 | Router::new() |
| 21 | .route("/me", get(me)) | 183 | .route("/me", get(me)) |
| 22 | .route("/register", post(register)) | 184 | .route("/register", post(register)) |
| 23 | .route("/login", post(login)) | 185 | .route("/login", post(login)) |
| 186 | .route("/change_password", put(change_password)) | ||
| 187 | .route("/delete", delete(delete_account)) | ||
| 24 | } | 188 | } |
diff --git a/src/routers/mod.rs b/src/routers/mod.rs index ee925f0..b57f71d 100644 --- a/src/routers/mod.rs +++ b/src/routers/mod.rs | |||
| @@ -1,15 +1,9 @@ | |||
| 1 | mod account; | 1 | mod account; |
| 2 | 2 | ||
| 3 | use axum::Router; | 3 | use axum::Router; |
| 4 | use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}; | ||
| 5 | use tracing::Level; | ||
| 6 | 4 | ||
| 7 | pub(crate) fn router() -> Router { | 5 | use crate::state::AppState; |
| 8 | let trace_layer = TraceLayer::new_for_http() | ||
| 9 | .on_request(DefaultOnRequest::new().level(Level::INFO)) | ||
| 10 | .on_response(DefaultOnResponse::new().level(Level::INFO)); | ||
| 11 | 6 | ||
| 12 | Router::new() | 7 | pub(crate) fn router() -> Router<AppState> { |
| 13 | .layer(trace_layer) | 8 | Router::new().nest("/account", account::router()) |
| 14 | .nest("/account", account::router()) | ||
| 15 | } | 9 | } |
diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..9779b46 --- /dev/null +++ b/src/state.rs | |||
| @@ -0,0 +1,7 @@ | |||
| 1 | use sea_orm::DatabaseConnection; | ||
| 2 | |||
| 3 | #[derive(Clone)] | ||
| 4 | pub struct AppState { | ||
| 5 | pub db: DatabaseConnection, | ||
| 6 | pub secret: String, | ||
| 7 | } | ||
