diff options
Diffstat (limited to 'src/routers')
| -rw-r--r-- | src/routers/account.rs | 184 | ||||
| -rw-r--r-- | src/routers/mod.rs | 12 |
2 files changed, 177 insertions, 19 deletions
diff --git a/src/routers/account.rs b/src/routers/account.rs index 8192133..98ee61d 100644 --- a/src/routers/account.rs +++ b/src/routers/account.rs | |||
| @@ -1,24 +1,188 @@ | |||
| 1 | use axum::Router; | 1 | use axum::Router; |
| 2 | use axum::response::IntoResponse; | 2 | use axum::extract::State; |
| 3 | use axum::routing::{get, post}; | 3 | use axum::routing::{delete, get, post, put}; |
| 4 | use chrono::{DateTime, Duration, Utc}; | ||
| 5 | use entity::users::{self}; | ||
| 6 | use sea_orm::ActiveValue::Set; | ||
| 7 | use sea_orm::{ | ||
| 8 | ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, ModelTrait, QueryFilter, | ||
| 9 | }; | ||
| 10 | use serde::{Deserialize, Serialize}; | ||
| 4 | 11 | ||
| 5 | use crate::response::ApiResponse; | 12 | use crate::extract::{ApiJson, Auth}; |
| 13 | use crate::{ | ||
| 14 | ApiError, ApiResult, AppState, JwtClaims, SuccessResponse, create_jwt, create_password, | ||
| 15 | validate_password, | ||
| 16 | }; | ||
| 6 | 17 | ||
| 7 | async fn me() -> impl IntoResponse { | 18 | #[derive(Serialize)] |
| 8 | ApiResponse::Success("Me") | 19 | struct Account { |
| 20 | id: i64, | ||
| 21 | username: String, | ||
| 22 | first_name: String, | ||
| 23 | last_name: String, | ||
| 9 | } | 24 | } |
| 10 | 25 | ||
| 11 | async fn register() -> impl IntoResponse { | 26 | #[derive(Serialize)] |
| 12 | ApiResponse::Success("Register") | 27 | struct Token { |
| 28 | token: String, | ||
| 29 | expired_at: DateTime<Utc>, | ||
| 13 | } | 30 | } |
| 14 | 31 | ||
| 15 | async fn login() -> impl IntoResponse { | 32 | async fn me(Auth(user): Auth) -> ApiResult<Account> { |
| 16 | ApiResponse::Success("Login") | 33 | return Ok(SuccessResponse(Account { |
| 34 | id: user.id, | ||
| 35 | username: user.username, | ||
| 36 | first_name: user.first_name, | ||
| 37 | last_name: user.last_name, | ||
| 38 | })); | ||
| 17 | } | 39 | } |
| 18 | 40 | ||
| 19 | pub(crate) fn router() -> Router { | 41 | #[derive(Deserialize)] |
| 42 | struct RegisterRequest { | ||
| 43 | username: String, | ||
| 44 | password: String, | ||
| 45 | first_name: String, | ||
| 46 | last_name: String, | ||
| 47 | } | ||
| 48 | |||
| 49 | async fn register( | ||
| 50 | State(state): State<AppState>, | ||
| 51 | ApiJson(req): ApiJson<RegisterRequest>, | ||
| 52 | ) -> ApiResult<Account> { | ||
| 53 | let user_exists = users::Entity::find() | ||
| 54 | .filter(users::Column::Username.eq(&req.username)) | ||
| 55 | .one(&state.db) | ||
| 56 | .await? | ||
| 57 | .is_some(); | ||
| 58 | |||
| 59 | if user_exists { | ||
| 60 | return Err(ApiError::UserAlreadyExists { | ||
| 61 | username: req.username, | ||
| 62 | }); | ||
| 63 | } | ||
| 64 | |||
| 65 | let user = users::ActiveModel { | ||
| 66 | username: Set(req.username), | ||
| 67 | password_hash: Set(create_password(&req.password)?), | ||
| 68 | password_issue_date: Set(Utc::now().naive_utc()), | ||
| 69 | first_name: Set(req.first_name), | ||
| 70 | last_name: Set(req.last_name), | ||
| 71 | ..Default::default() | ||
| 72 | } | ||
| 73 | .insert(&state.db) | ||
| 74 | .await?; | ||
| 75 | |||
| 76 | Ok(SuccessResponse(Account { | ||
| 77 | id: user.id, | ||
| 78 | username: user.username, | ||
| 79 | first_name: user.first_name, | ||
| 80 | last_name: user.last_name, | ||
| 81 | })) | ||
| 82 | } | ||
| 83 | |||
| 84 | #[derive(Deserialize, Default)] | ||
| 85 | enum TokenLifetime { | ||
| 86 | Day = 1, | ||
| 87 | #[default] | ||
| 88 | Week = 7, | ||
| 89 | Month = 31, | ||
| 90 | } | ||
| 91 | |||
| 92 | #[derive(Deserialize)] | ||
| 93 | struct LoginRequest { | ||
| 94 | username: String, | ||
| 95 | password: String, | ||
| 96 | #[serde(default)] | ||
| 97 | token_lifetime: TokenLifetime, | ||
| 98 | } | ||
| 99 | |||
| 100 | async fn login( | ||
| 101 | State(state): State<AppState>, | ||
| 102 | ApiJson(req): ApiJson<LoginRequest>, | ||
| 103 | ) -> ApiResult<Token> { | ||
| 104 | let user = users::Entity::find() | ||
| 105 | .filter(users::Column::Username.eq(&req.username)) | ||
| 106 | .one(&state.db) | ||
| 107 | .await? | ||
| 108 | .ok_or(ApiError::InvalidPassword)?; | ||
| 109 | |||
| 110 | if !validate_password(&req.password, &user.password_hash)? { | ||
| 111 | return Err(ApiError::InvalidPassword); | ||
| 112 | } | ||
| 113 | |||
| 114 | let expired_at = Utc::now() + Duration::days(req.token_lifetime as i64); | ||
| 115 | |||
| 116 | let token = create_jwt( | ||
| 117 | &JwtClaims { | ||
| 118 | sub: user.id, | ||
| 119 | iat: user.password_issue_date.and_utc().timestamp(), | ||
| 120 | exp: expired_at.timestamp(), | ||
| 121 | }, | ||
| 122 | &state.secret, | ||
| 123 | ) | ||
| 124 | .map_err(|e| ApiError::InternalJwt(e.to_string()))?; | ||
| 125 | |||
| 126 | Ok(SuccessResponse(Token { token, expired_at })) | ||
| 127 | } | ||
| 128 | |||
| 129 | #[derive(Deserialize)] | ||
| 130 | struct ChangePasswordRequest { | ||
| 131 | old_password: String, | ||
| 132 | new_password: String, | ||
| 133 | } | ||
| 134 | |||
| 135 | async fn change_password( | ||
| 136 | State(state): State<AppState>, | ||
| 137 | Auth(user): Auth, | ||
| 138 | ApiJson(req): ApiJson<ChangePasswordRequest>, | ||
| 139 | ) -> ApiResult<Account> { | ||
| 140 | if !validate_password(&req.old_password, &user.password_hash)? { | ||
| 141 | return Err(ApiError::InvalidPassword); | ||
| 142 | } | ||
| 143 | |||
| 144 | let mut active_user = user.into_active_model(); | ||
| 145 | active_user.password_hash = Set(create_password(&req.new_password)?); | ||
| 146 | active_user.password_issue_date = Set(Utc::now().naive_utc()); | ||
| 147 | |||
| 148 | let user = active_user.update(&state.db).await?; | ||
| 149 | Ok(SuccessResponse(Account { | ||
| 150 | id: user.id, | ||
| 151 | username: user.username, | ||
| 152 | first_name: user.first_name, | ||
| 153 | last_name: user.last_name, | ||
| 154 | })) | ||
| 155 | } | ||
| 156 | |||
| 157 | #[derive(Deserialize)] | ||
| 158 | struct DeleteUserRequest { | ||
| 159 | password: String, | ||
| 160 | } | ||
| 161 | |||
| 162 | async fn delete_account( | ||
| 163 | State(state): State<AppState>, | ||
| 164 | Auth(user): Auth, | ||
| 165 | ApiJson(req): ApiJson<DeleteUserRequest>, | ||
| 166 | ) -> ApiResult<Account> { | ||
| 167 | if !validate_password(&req.password, &user.password_hash)? { | ||
| 168 | return Err(ApiError::InvalidPassword); | ||
| 169 | } | ||
| 170 | |||
| 171 | user.clone().delete(&state.db).await?; | ||
| 172 | |||
| 173 | Ok(SuccessResponse(Account { | ||
| 174 | id: user.id, | ||
| 175 | username: user.username, | ||
| 176 | first_name: user.first_name, | ||
| 177 | last_name: user.last_name, | ||
| 178 | })) | ||
| 179 | } | ||
| 180 | |||
| 181 | pub(crate) fn router() -> Router<AppState> { | ||
| 20 | Router::new() | 182 | Router::new() |
| 21 | .route("/me", get(me)) | 183 | .route("/me", get(me)) |
| 22 | .route("/register", post(register)) | 184 | .route("/register", post(register)) |
| 23 | .route("/login", post(login)) | 185 | .route("/login", post(login)) |
| 186 | .route("/change_password", put(change_password)) | ||
| 187 | .route("/delete", delete(delete_account)) | ||
| 24 | } | 188 | } |
diff --git a/src/routers/mod.rs b/src/routers/mod.rs index ee925f0..b57f71d 100644 --- a/src/routers/mod.rs +++ b/src/routers/mod.rs | |||
| @@ -1,15 +1,9 @@ | |||
| 1 | mod account; | 1 | mod account; |
| 2 | 2 | ||
| 3 | use axum::Router; | 3 | use axum::Router; |
| 4 | use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}; | ||
| 5 | use tracing::Level; | ||
| 6 | 4 | ||
| 7 | pub(crate) fn router() -> Router { | 5 | use crate::state::AppState; |
| 8 | let trace_layer = TraceLayer::new_for_http() | ||
| 9 | .on_request(DefaultOnRequest::new().level(Level::INFO)) | ||
| 10 | .on_response(DefaultOnResponse::new().level(Level::INFO)); | ||
| 11 | 6 | ||
| 12 | Router::new() | 7 | pub(crate) fn router() -> Router<AppState> { |
| 13 | .layer(trace_layer) | 8 | Router::new().nest("/account", account::router()) |
| 14 | .nest("/account", account::router()) | ||
| 15 | } | 9 | } |
