From b9d75e22db72aabf47815e381aa6432c1bff3877 Mon Sep 17 00:00:00 2001 From: Tolmachev Igor Date: Mon, 1 Sep 2025 13:32:05 +0300 Subject: Add account endpoints --- src/auth.rs | 49 +++++++++++++ src/error.rs | 89 ++++++++++++++++++++++++ src/extract/auth.rs | 36 ++++++++++ src/extract/json.rs | 19 +++++ src/extract/mod.rs | 5 ++ src/main.rs | 30 +++++++- src/response.rs | 56 ++++++++++----- src/routers/account.rs | 184 ++++++++++++++++++++++++++++++++++++++++++++++--- src/routers/mod.rs | 12 +--- src/state.rs | 7 ++ 10 files changed, 447 insertions(+), 40 deletions(-) create mode 100644 src/auth.rs create mode 100644 src/error.rs create mode 100644 src/extract/auth.rs create mode 100644 src/extract/json.rs create mode 100644 src/extract/mod.rs create mode 100644 src/state.rs (limited to 'src') diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..418f64e --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,49 @@ +use argon2::password_hash::rand_core::OsRng; +use argon2::password_hash::{PasswordHasher, SaltString}; +use argon2::{Argon2, PasswordHash, PasswordVerifier}; +use jsonwebtoken::{self as jwt, DecodingKey, EncodingKey, Header, Validation}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub struct JwtClaims { + pub sub: i64, + pub iat: i64, + pub exp: i64, +} + +pub fn create_password(password: &str) -> argon2::password_hash::Result { + Ok(Argon2::default() + .hash_password(password.as_bytes(), &SaltString::generate(&mut OsRng))? + .to_string()) +} + +pub fn validate_password( + password: &str, + password_hash: &str, +) -> argon2::password_hash::Result { + Ok(Argon2::default() + .verify_password(password.as_bytes(), &PasswordHash::new(password_hash)?) + .is_ok()) +} + +pub fn create_jwt(claims: &JwtClaims, secret: &str) -> jwt::errors::Result { + jwt::encode( + &Header::default(), + claims, + &EncodingKey::from_secret(secret.as_bytes()), + ) +} + +pub fn validate_jwt(token: &str, secret: &str) -> jwt::errors::Result { + let mut validation = Validation::default(); + validation.set_required_spec_claims(&["exp"]); + validation.validate_exp = true; + validation.leeway = 0; + + Ok(jwt::decode( + token, + &DecodingKey::from_secret(secret.as_bytes()), + &validation, + )? + .claims) +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..2d0f911 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,89 @@ +use axum::extract::rejection::JsonRejection; +use axum::response::{IntoResponse, Response}; +use axum_extra::typed_header::TypedHeaderRejection; + +use crate::response::{ErrorResponse, FailResponse, SuccessResponse}; + +pub type ApiResult = Result, ApiError>; + +pub enum ApiError { + // 400 + BadJsonBody(String), + BadAuthTokenHeader(String), + UserAlreadyExists { username: String }, + InvalidPassword, + NotAuthorized, + // 500 + Database(String), + PasswordHash(String), + InternalJwt(String), +} + +impl From for ApiError { + fn from(value: JsonRejection) -> Self { + Self::BadJsonBody(value.body_text()) + } +} + +impl From for ApiError { + fn from(value: TypedHeaderRejection) -> Self { + Self::BadAuthTokenHeader(value.to_string()) + } +} + +impl From for ApiError { + fn from(value: sea_orm::DbErr) -> Self { + Self::Database(value.to_string()) + } +} + +impl From for ApiError { + fn from(value: argon2::password_hash::Error) -> Self { + Self::PasswordHash(value.to_string()) + } +} + +impl ToString for ApiError { + fn to_string(&self) -> String { + match self { + // 400 + ApiError::BadJsonBody(..) => "BadJsonBody", + ApiError::BadAuthTokenHeader(..) => "BadAuthTokenHeader", + ApiError::UserAlreadyExists { .. } => "UserAlreadyExists", + ApiError::InvalidPassword => "InvalidPassword", + ApiError::NotAuthorized => "NotAuthorized", + // 500 + ApiError::Database(..) => "Database", + ApiError::PasswordHash(..) => "PasswordHash", + ApiError::InternalJwt(..) => "InternalJwt", + } + .to_string() + } +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + let kind = self.to_string(); + match self { + // 400 + ApiError::BadJsonBody(msg) => FailResponse(kind, msg).into_response(), + ApiError::BadAuthTokenHeader(msg) => FailResponse(kind, msg).into_response(), + ApiError::UserAlreadyExists { username } => FailResponse( + kind, + format!("user with username `{}` already exists", username), + ) + .into_response(), + ApiError::InvalidPassword => { + FailResponse(kind, "password is invalid".to_string()).into_response() + } + ApiError::NotAuthorized => { + FailResponse(kind, "user is not authorized".to_string()).into_response() + } + + // 500 + ApiError::Database(msg) => ErrorResponse(kind, msg).into_response(), + ApiError::PasswordHash(msg) => ErrorResponse(kind, msg).into_response(), + ApiError::InternalJwt(msg) => ErrorResponse(kind, msg).into_response(), + } + } +} diff --git a/src/extract/auth.rs b/src/extract/auth.rs new file mode 100644 index 0000000..cc357fd --- /dev/null +++ b/src/extract/auth.rs @@ -0,0 +1,36 @@ +use axum::extract::FromRequestParts; +use axum::http::request::Parts; +use axum_extra::TypedHeader; +use entity::users; +use headers::authorization::{Authorization, Bearer}; +use sea_orm::EntityTrait; + +use crate::{ApiError, AppState, validate_jwt}; + +pub struct Auth(pub users::Model); + +impl FromRequestParts for Auth { + type Rejection = ApiError; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let token_header = + TypedHeader::>::from_request_parts(parts, state).await?; + + let jwt_claims = validate_jwt(token_header.token(), &state.secret) + .map_err(|_| ApiError::NotAuthorized)?; + + let user = users::Entity::find_by_id(jwt_claims.sub) + .one(&state.db) + .await? + .ok_or(ApiError::NotAuthorized)?; + + if jwt_claims.iat < user.password_issue_date.and_utc().timestamp() { + return Err(ApiError::NotAuthorized); + } + + Ok(Auth(user)) + } +} diff --git a/src/extract/json.rs b/src/extract/json.rs new file mode 100644 index 0000000..cfde15b --- /dev/null +++ b/src/extract/json.rs @@ -0,0 +1,19 @@ +use axum::extract::rejection::JsonRejection; +use axum::extract::{FromRequest, Request}; + +use crate::error::ApiError; + +pub struct ApiJson(pub T); + +impl FromRequest for ApiJson +where + axum::Json: FromRequest, + S: Send + Sync, +{ + type Rejection = ApiError; + + #[inline] + async fn from_request(req: Request, state: &S) -> Result { + Ok(Self(axum::Json::::from_request(req, state).await?.0)) + } +} diff --git a/src/extract/mod.rs b/src/extract/mod.rs new file mode 100644 index 0000000..b46a610 --- /dev/null +++ b/src/extract/mod.rs @@ -0,0 +1,5 @@ +mod auth; +mod json; + +pub use auth::Auth; +pub use json::ApiJson; diff --git a/src/main.rs b/src/main.rs index c53664a..3b3f868 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,30 @@ +mod auth; mod error; +mod extract; mod response; mod routers; +mod state; +pub use auth::{JwtClaims, create_jwt, create_password, validate_jwt, validate_password}; +pub use error::{ApiError, ApiResult}; +pub use response::{ErrorResponse, FailResponse, SuccessResponse}; +pub use state::AppState; + +use sea_orm::Database; use tokio::net::TcpListener; -use tracing::{Level, info}; +use tower_http::trace::TraceLayer; +use tracing::info; +use tracing_subscriber::EnvFilter; + +async fn app_state() -> AppState { + let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + let secret = std::env::var("SECRET").expect("SECRET must be set"); + + AppState { + db: Database::connect(db_url).await.unwrap(), + secret: secret, + } +} async fn listener() -> TcpListener { let addr = std::env::var("SERVER_BIND").expect("SERVER_BIND must be set"); @@ -13,10 +34,13 @@ async fn listener() -> TcpListener { #[tokio::main] async fn main() { tracing_subscriber::fmt() - .with_max_level(Level::DEBUG) + .with_env_filter(EnvFilter::new("info,sqlx=warn,tower_http=debug")) .init(); - let router = routers::router(); + let state = app_state().await; + let router = routers::router() + .layer(TraceLayer::new_for_http()) + .with_state(state); let listener = listener().await; info!( diff --git a/src/response.rs b/src/response.rs index 8d505a5..25c3008 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,32 +1,52 @@ +use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use serde::Serialize; use serde_json::json; -pub enum ApiResponse { - Success(T), - Fail(T), - Error(String), -} +pub struct SuccessResponse(pub T); +pub struct FailResponse(pub String, pub String); +pub struct ErrorResponse(pub String, pub String); -impl IntoResponse for ApiResponse +impl IntoResponse for SuccessResponse where T: Serialize, { fn into_response(self) -> Response { - axum::Json(match self { - ApiResponse::Success(data) => json!({ + ( + StatusCode::OK, + axum::Json(json!({ "status": "success", - "data": data - }), - ApiResponse::Fail(data) => json!({ + "data": self.0 + })), + ) + .into_response() + } +} + +impl IntoResponse for FailResponse { + fn into_response(self) -> Response { + ( + StatusCode::BAD_REQUEST, + axum::Json(json!({ "status": "fail", - "data": data - }), - ApiResponse::Error(message) => json!({ + "kind": self.0, + "message": self.1 + })), + ) + .into_response() + } +} + +impl IntoResponse for ErrorResponse { + fn into_response(self) -> Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + axum::Json(json!({ "status": "error", - "message": message - }), - }) - .into_response() + "kind": self.0, + "message": self.1 + })), + ) + .into_response() } } diff --git a/src/routers/account.rs b/src/routers/account.rs index 8192133..98ee61d 100644 --- a/src/routers/account.rs +++ b/src/routers/account.rs @@ -1,24 +1,188 @@ use axum::Router; -use axum::response::IntoResponse; -use axum::routing::{get, post}; +use axum::extract::State; +use axum::routing::{delete, get, post, put}; +use chrono::{DateTime, Duration, Utc}; +use entity::users::{self}; +use sea_orm::ActiveValue::Set; +use sea_orm::{ + ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, ModelTrait, QueryFilter, +}; +use serde::{Deserialize, Serialize}; -use crate::response::ApiResponse; +use crate::extract::{ApiJson, Auth}; +use crate::{ + ApiError, ApiResult, AppState, JwtClaims, SuccessResponse, create_jwt, create_password, + validate_password, +}; -async fn me() -> impl IntoResponse { - ApiResponse::Success("Me") +#[derive(Serialize)] +struct Account { + id: i64, + username: String, + first_name: String, + last_name: String, } -async fn register() -> impl IntoResponse { - ApiResponse::Success("Register") +#[derive(Serialize)] +struct Token { + token: String, + expired_at: DateTime, } -async fn login() -> impl IntoResponse { - ApiResponse::Success("Login") +async fn me(Auth(user): Auth) -> ApiResult { + return Ok(SuccessResponse(Account { + id: user.id, + username: user.username, + first_name: user.first_name, + last_name: user.last_name, + })); } -pub(crate) fn router() -> Router { +#[derive(Deserialize)] +struct RegisterRequest { + username: String, + password: String, + first_name: String, + last_name: String, +} + +async fn register( + State(state): State, + ApiJson(req): ApiJson, +) -> ApiResult { + let user_exists = users::Entity::find() + .filter(users::Column::Username.eq(&req.username)) + .one(&state.db) + .await? + .is_some(); + + if user_exists { + return Err(ApiError::UserAlreadyExists { + username: req.username, + }); + } + + let user = users::ActiveModel { + username: Set(req.username), + password_hash: Set(create_password(&req.password)?), + password_issue_date: Set(Utc::now().naive_utc()), + first_name: Set(req.first_name), + last_name: Set(req.last_name), + ..Default::default() + } + .insert(&state.db) + .await?; + + Ok(SuccessResponse(Account { + id: user.id, + username: user.username, + first_name: user.first_name, + last_name: user.last_name, + })) +} + +#[derive(Deserialize, Default)] +enum TokenLifetime { + Day = 1, + #[default] + Week = 7, + Month = 31, +} + +#[derive(Deserialize)] +struct LoginRequest { + username: String, + password: String, + #[serde(default)] + token_lifetime: TokenLifetime, +} + +async fn login( + State(state): State, + ApiJson(req): ApiJson, +) -> ApiResult { + let user = users::Entity::find() + .filter(users::Column::Username.eq(&req.username)) + .one(&state.db) + .await? + .ok_or(ApiError::InvalidPassword)?; + + if !validate_password(&req.password, &user.password_hash)? { + return Err(ApiError::InvalidPassword); + } + + let expired_at = Utc::now() + Duration::days(req.token_lifetime as i64); + + let token = create_jwt( + &JwtClaims { + sub: user.id, + iat: user.password_issue_date.and_utc().timestamp(), + exp: expired_at.timestamp(), + }, + &state.secret, + ) + .map_err(|e| ApiError::InternalJwt(e.to_string()))?; + + Ok(SuccessResponse(Token { token, expired_at })) +} + +#[derive(Deserialize)] +struct ChangePasswordRequest { + old_password: String, + new_password: String, +} + +async fn change_password( + State(state): State, + Auth(user): Auth, + ApiJson(req): ApiJson, +) -> ApiResult { + if !validate_password(&req.old_password, &user.password_hash)? { + return Err(ApiError::InvalidPassword); + } + + let mut active_user = user.into_active_model(); + active_user.password_hash = Set(create_password(&req.new_password)?); + active_user.password_issue_date = Set(Utc::now().naive_utc()); + + let user = active_user.update(&state.db).await?; + Ok(SuccessResponse(Account { + id: user.id, + username: user.username, + first_name: user.first_name, + last_name: user.last_name, + })) +} + +#[derive(Deserialize)] +struct DeleteUserRequest { + password: String, +} + +async fn delete_account( + State(state): State, + Auth(user): Auth, + ApiJson(req): ApiJson, +) -> ApiResult { + if !validate_password(&req.password, &user.password_hash)? { + return Err(ApiError::InvalidPassword); + } + + user.clone().delete(&state.db).await?; + + Ok(SuccessResponse(Account { + id: user.id, + username: user.username, + first_name: user.first_name, + last_name: user.last_name, + })) +} + +pub(crate) fn router() -> Router { Router::new() .route("/me", get(me)) .route("/register", post(register)) .route("/login", post(login)) + .route("/change_password", put(change_password)) + .route("/delete", delete(delete_account)) } diff --git a/src/routers/mod.rs b/src/routers/mod.rs index ee925f0..b57f71d 100644 --- a/src/routers/mod.rs +++ b/src/routers/mod.rs @@ -1,15 +1,9 @@ mod account; use axum::Router; -use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}; -use tracing::Level; -pub(crate) fn router() -> Router { - let trace_layer = TraceLayer::new_for_http() - .on_request(DefaultOnRequest::new().level(Level::INFO)) - .on_response(DefaultOnResponse::new().level(Level::INFO)); +use crate::state::AppState; - Router::new() - .layer(trace_layer) - .nest("/account", account::router()) +pub(crate) fn router() -> Router { + Router::new().nest("/account", account::router()) } diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..9779b46 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,7 @@ +use sea_orm::DatabaseConnection; + +#[derive(Clone)] +pub struct AppState { + pub db: DatabaseConnection, + pub secret: String, +} -- cgit v1.2.3