aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/auth.rs49
-rw-r--r--src/error.rs89
-rw-r--r--src/extract/auth.rs36
-rw-r--r--src/extract/json.rs19
-rw-r--r--src/extract/mod.rs5
-rw-r--r--src/main.rs30
-rw-r--r--src/response.rs56
-rw-r--r--src/routers/account.rs184
-rw-r--r--src/routers/mod.rs12
-rw-r--r--src/state.rs7
10 files changed, 447 insertions, 40 deletions
diff --git a/src/auth.rs b/src/auth.rs
new file mode 100644
index 0000000..418f64e
--- /dev/null
+++ b/src/auth.rs
@@ -0,0 +1,49 @@
1use argon2::password_hash::rand_core::OsRng;
2use argon2::password_hash::{PasswordHasher, SaltString};
3use argon2::{Argon2, PasswordHash, PasswordVerifier};
4use jsonwebtoken::{self as jwt, DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6
7#[derive(Serialize, Deserialize)]
8pub struct JwtClaims {
9 pub sub: i64,
10 pub iat: i64,
11 pub exp: i64,
12}
13
14pub fn create_password(password: &str) -> argon2::password_hash::Result<String> {
15 Ok(Argon2::default()
16 .hash_password(password.as_bytes(), &SaltString::generate(&mut OsRng))?
17 .to_string())
18}
19
20pub fn validate_password(
21 password: &str,
22 password_hash: &str,
23) -> argon2::password_hash::Result<bool> {
24 Ok(Argon2::default()
25 .verify_password(password.as_bytes(), &PasswordHash::new(password_hash)?)
26 .is_ok())
27}
28
29pub fn create_jwt(claims: &JwtClaims, secret: &str) -> jwt::errors::Result<String> {
30 jwt::encode(
31 &Header::default(),
32 claims,
33 &EncodingKey::from_secret(secret.as_bytes()),
34 )
35}
36
37pub fn validate_jwt(token: &str, secret: &str) -> jwt::errors::Result<JwtClaims> {
38 let mut validation = Validation::default();
39 validation.set_required_spec_claims(&["exp"]);
40 validation.validate_exp = true;
41 validation.leeway = 0;
42
43 Ok(jwt::decode(
44 token,
45 &DecodingKey::from_secret(secret.as_bytes()),
46 &validation,
47 )?
48 .claims)
49}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..2d0f911
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,89 @@
1use axum::extract::rejection::JsonRejection;
2use axum::response::{IntoResponse, Response};
3use axum_extra::typed_header::TypedHeaderRejection;
4
5use crate::response::{ErrorResponse, FailResponse, SuccessResponse};
6
7pub type ApiResult<T> = Result<SuccessResponse<T>, ApiError>;
8
9pub enum ApiError {
10 // 400
11 BadJsonBody(String),
12 BadAuthTokenHeader(String),
13 UserAlreadyExists { username: String },
14 InvalidPassword,
15 NotAuthorized,
16 // 500
17 Database(String),
18 PasswordHash(String),
19 InternalJwt(String),
20}
21
22impl From<JsonRejection> for ApiError {
23 fn from(value: JsonRejection) -> Self {
24 Self::BadJsonBody(value.body_text())
25 }
26}
27
28impl From<TypedHeaderRejection> for ApiError {
29 fn from(value: TypedHeaderRejection) -> Self {
30 Self::BadAuthTokenHeader(value.to_string())
31 }
32}
33
34impl From<sea_orm::DbErr> for ApiError {
35 fn from(value: sea_orm::DbErr) -> Self {
36 Self::Database(value.to_string())
37 }
38}
39
40impl From<argon2::password_hash::Error> for ApiError {
41 fn from(value: argon2::password_hash::Error) -> Self {
42 Self::PasswordHash(value.to_string())
43 }
44}
45
46impl ToString for ApiError {
47 fn to_string(&self) -> String {
48 match self {
49 // 400
50 ApiError::BadJsonBody(..) => "BadJsonBody",
51 ApiError::BadAuthTokenHeader(..) => "BadAuthTokenHeader",
52 ApiError::UserAlreadyExists { .. } => "UserAlreadyExists",
53 ApiError::InvalidPassword => "InvalidPassword",
54 ApiError::NotAuthorized => "NotAuthorized",
55 // 500
56 ApiError::Database(..) => "Database",
57 ApiError::PasswordHash(..) => "PasswordHash",
58 ApiError::InternalJwt(..) => "InternalJwt",
59 }
60 .to_string()
61 }
62}
63
64impl IntoResponse for ApiError {
65 fn into_response(self) -> Response {
66 let kind = self.to_string();
67 match self {
68 // 400
69 ApiError::BadJsonBody(msg) => FailResponse(kind, msg).into_response(),
70 ApiError::BadAuthTokenHeader(msg) => FailResponse(kind, msg).into_response(),
71 ApiError::UserAlreadyExists { username } => FailResponse(
72 kind,
73 format!("user with username `{}` already exists", username),
74 )
75 .into_response(),
76 ApiError::InvalidPassword => {
77 FailResponse(kind, "password is invalid".to_string()).into_response()
78 }
79 ApiError::NotAuthorized => {
80 FailResponse(kind, "user is not authorized".to_string()).into_response()
81 }
82
83 // 500
84 ApiError::Database(msg) => ErrorResponse(kind, msg).into_response(),
85 ApiError::PasswordHash(msg) => ErrorResponse(kind, msg).into_response(),
86 ApiError::InternalJwt(msg) => ErrorResponse(kind, msg).into_response(),
87 }
88 }
89}
diff --git a/src/extract/auth.rs b/src/extract/auth.rs
new file mode 100644
index 0000000..cc357fd
--- /dev/null
+++ b/src/extract/auth.rs
@@ -0,0 +1,36 @@
1use axum::extract::FromRequestParts;
2use axum::http::request::Parts;
3use axum_extra::TypedHeader;
4use entity::users;
5use headers::authorization::{Authorization, Bearer};
6use sea_orm::EntityTrait;
7
8use crate::{ApiError, AppState, validate_jwt};
9
10pub struct Auth(pub users::Model);
11
12impl FromRequestParts<AppState> for Auth {
13 type Rejection = ApiError;
14
15 async fn from_request_parts(
16 parts: &mut Parts,
17 state: &AppState,
18 ) -> Result<Self, Self::Rejection> {
19 let token_header =
20 TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await?;
21
22 let jwt_claims = validate_jwt(token_header.token(), &state.secret)
23 .map_err(|_| ApiError::NotAuthorized)?;
24
25 let user = users::Entity::find_by_id(jwt_claims.sub)
26 .one(&state.db)
27 .await?
28 .ok_or(ApiError::NotAuthorized)?;
29
30 if jwt_claims.iat < user.password_issue_date.and_utc().timestamp() {
31 return Err(ApiError::NotAuthorized);
32 }
33
34 Ok(Auth(user))
35 }
36}
diff --git a/src/extract/json.rs b/src/extract/json.rs
new file mode 100644
index 0000000..cfde15b
--- /dev/null
+++ b/src/extract/json.rs
@@ -0,0 +1,19 @@
1use axum::extract::rejection::JsonRejection;
2use axum::extract::{FromRequest, Request};
3
4use crate::error::ApiError;
5
6pub struct ApiJson<T>(pub T);
7
8impl<S, T> FromRequest<S> for ApiJson<T>
9where
10 axum::Json<T>: FromRequest<S, Rejection = JsonRejection>,
11 S: Send + Sync,
12{
13 type Rejection = ApiError;
14
15 #[inline]
16 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
17 Ok(Self(axum::Json::<T>::from_request(req, state).await?.0))
18 }
19}
diff --git a/src/extract/mod.rs b/src/extract/mod.rs
new file mode 100644
index 0000000..b46a610
--- /dev/null
+++ b/src/extract/mod.rs
@@ -0,0 +1,5 @@
1mod auth;
2mod json;
3
4pub use auth::Auth;
5pub use json::ApiJson;
diff --git a/src/main.rs b/src/main.rs
index c53664a..3b3f868 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,9 +1,30 @@
1mod auth;
1mod error; 2mod error;
3mod extract;
2mod response; 4mod response;
3mod routers; 5mod routers;
6mod state;
4 7
8pub use auth::{JwtClaims, create_jwt, create_password, validate_jwt, validate_password};
9pub use error::{ApiError, ApiResult};
10pub use response::{ErrorResponse, FailResponse, SuccessResponse};
11pub use state::AppState;
12
13use sea_orm::Database;
5use tokio::net::TcpListener; 14use tokio::net::TcpListener;
6use tracing::{Level, info}; 15use tower_http::trace::TraceLayer;
16use tracing::info;
17use tracing_subscriber::EnvFilter;
18
19async fn app_state() -> AppState {
20 let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
21 let secret = std::env::var("SECRET").expect("SECRET must be set");
22
23 AppState {
24 db: Database::connect(db_url).await.unwrap(),
25 secret: secret,
26 }
27}
7 28
8async fn listener() -> TcpListener { 29async fn listener() -> TcpListener {
9 let addr = std::env::var("SERVER_BIND").expect("SERVER_BIND must be set"); 30 let addr = std::env::var("SERVER_BIND").expect("SERVER_BIND must be set");
@@ -13,10 +34,13 @@ async fn listener() -> TcpListener {
13#[tokio::main] 34#[tokio::main]
14async fn main() { 35async fn main() {
15 tracing_subscriber::fmt() 36 tracing_subscriber::fmt()
16 .with_max_level(Level::DEBUG) 37 .with_env_filter(EnvFilter::new("info,sqlx=warn,tower_http=debug"))
17 .init(); 38 .init();
18 39
19 let router = routers::router(); 40 let state = app_state().await;
41 let router = routers::router()
42 .layer(TraceLayer::new_for_http())
43 .with_state(state);
20 let listener = listener().await; 44 let listener = listener().await;
21 45
22 info!( 46 info!(
diff --git a/src/response.rs b/src/response.rs
index 8d505a5..25c3008 100644
--- a/src/response.rs
+++ b/src/response.rs
@@ -1,32 +1,52 @@
1use axum::http::StatusCode;
1use axum::response::{IntoResponse, Response}; 2use axum::response::{IntoResponse, Response};
2use serde::Serialize; 3use serde::Serialize;
3use serde_json::json; 4use serde_json::json;
4 5
5pub enum ApiResponse<T> { 6pub struct SuccessResponse<T>(pub T);
6 Success(T), 7pub struct FailResponse(pub String, pub String);
7 Fail(T), 8pub struct ErrorResponse(pub String, pub String);
8 Error(String),
9}
10 9
11impl<T> IntoResponse for ApiResponse<T> 10impl<T> IntoResponse for SuccessResponse<T>
12where 11where
13 T: Serialize, 12 T: Serialize,
14{ 13{
15 fn into_response(self) -> Response { 14 fn into_response(self) -> Response {
16 axum::Json(match self { 15 (
17 ApiResponse::Success(data) => json!({ 16 StatusCode::OK,
17 axum::Json(json!({
18 "status": "success", 18 "status": "success",
19 "data": data 19 "data": self.0
20 }), 20 })),
21 ApiResponse::Fail(data) => json!({ 21 )
22 .into_response()
23 }
24}
25
26impl IntoResponse for FailResponse {
27 fn into_response(self) -> Response {
28 (
29 StatusCode::BAD_REQUEST,
30 axum::Json(json!({
22 "status": "fail", 31 "status": "fail",
23 "data": data 32 "kind": self.0,
24 }), 33 "message": self.1
25 ApiResponse::Error(message) => json!({ 34 })),
35 )
36 .into_response()
37 }
38}
39
40impl IntoResponse for ErrorResponse {
41 fn into_response(self) -> Response {
42 (
43 StatusCode::INTERNAL_SERVER_ERROR,
44 axum::Json(json!({
26 "status": "error", 45 "status": "error",
27 "message": message 46 "kind": self.0,
28 }), 47 "message": self.1
29 }) 48 })),
30 .into_response() 49 )
50 .into_response()
31 } 51 }
32} 52}
diff --git a/src/routers/account.rs b/src/routers/account.rs
index 8192133..98ee61d 100644
--- a/src/routers/account.rs
+++ b/src/routers/account.rs
@@ -1,24 +1,188 @@
1use axum::Router; 1use axum::Router;
2use axum::response::IntoResponse; 2use axum::extract::State;
3use axum::routing::{get, post}; 3use axum::routing::{delete, get, post, put};
4use chrono::{DateTime, Duration, Utc};
5use entity::users::{self};
6use sea_orm::ActiveValue::Set;
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, ModelTrait, QueryFilter,
9};
10use serde::{Deserialize, Serialize};
4 11
5use crate::response::ApiResponse; 12use crate::extract::{ApiJson, Auth};
13use crate::{
14 ApiError, ApiResult, AppState, JwtClaims, SuccessResponse, create_jwt, create_password,
15 validate_password,
16};
6 17
7async fn me() -> impl IntoResponse { 18#[derive(Serialize)]
8 ApiResponse::Success("Me") 19struct Account {
20 id: i64,
21 username: String,
22 first_name: String,
23 last_name: String,
9} 24}
10 25
11async fn register() -> impl IntoResponse { 26#[derive(Serialize)]
12 ApiResponse::Success("Register") 27struct Token {
28 token: String,
29 expired_at: DateTime<Utc>,
13} 30}
14 31
15async fn login() -> impl IntoResponse { 32async fn me(Auth(user): Auth) -> ApiResult<Account> {
16 ApiResponse::Success("Login") 33 return Ok(SuccessResponse(Account {
34 id: user.id,
35 username: user.username,
36 first_name: user.first_name,
37 last_name: user.last_name,
38 }));
17} 39}
18 40
19pub(crate) fn router() -> Router { 41#[derive(Deserialize)]
42struct RegisterRequest {
43 username: String,
44 password: String,
45 first_name: String,
46 last_name: String,
47}
48
49async fn register(
50 State(state): State<AppState>,
51 ApiJson(req): ApiJson<RegisterRequest>,
52) -> ApiResult<Account> {
53 let user_exists = users::Entity::find()
54 .filter(users::Column::Username.eq(&req.username))
55 .one(&state.db)
56 .await?
57 .is_some();
58
59 if user_exists {
60 return Err(ApiError::UserAlreadyExists {
61 username: req.username,
62 });
63 }
64
65 let user = users::ActiveModel {
66 username: Set(req.username),
67 password_hash: Set(create_password(&req.password)?),
68 password_issue_date: Set(Utc::now().naive_utc()),
69 first_name: Set(req.first_name),
70 last_name: Set(req.last_name),
71 ..Default::default()
72 }
73 .insert(&state.db)
74 .await?;
75
76 Ok(SuccessResponse(Account {
77 id: user.id,
78 username: user.username,
79 first_name: user.first_name,
80 last_name: user.last_name,
81 }))
82}
83
84#[derive(Deserialize, Default)]
85enum TokenLifetime {
86 Day = 1,
87 #[default]
88 Week = 7,
89 Month = 31,
90}
91
92#[derive(Deserialize)]
93struct LoginRequest {
94 username: String,
95 password: String,
96 #[serde(default)]
97 token_lifetime: TokenLifetime,
98}
99
100async fn login(
101 State(state): State<AppState>,
102 ApiJson(req): ApiJson<LoginRequest>,
103) -> ApiResult<Token> {
104 let user = users::Entity::find()
105 .filter(users::Column::Username.eq(&req.username))
106 .one(&state.db)
107 .await?
108 .ok_or(ApiError::InvalidPassword)?;
109
110 if !validate_password(&req.password, &user.password_hash)? {
111 return Err(ApiError::InvalidPassword);
112 }
113
114 let expired_at = Utc::now() + Duration::days(req.token_lifetime as i64);
115
116 let token = create_jwt(
117 &JwtClaims {
118 sub: user.id,
119 iat: user.password_issue_date.and_utc().timestamp(),
120 exp: expired_at.timestamp(),
121 },
122 &state.secret,
123 )
124 .map_err(|e| ApiError::InternalJwt(e.to_string()))?;
125
126 Ok(SuccessResponse(Token { token, expired_at }))
127}
128
129#[derive(Deserialize)]
130struct ChangePasswordRequest {
131 old_password: String,
132 new_password: String,
133}
134
135async fn change_password(
136 State(state): State<AppState>,
137 Auth(user): Auth,
138 ApiJson(req): ApiJson<ChangePasswordRequest>,
139) -> ApiResult<Account> {
140 if !validate_password(&req.old_password, &user.password_hash)? {
141 return Err(ApiError::InvalidPassword);
142 }
143
144 let mut active_user = user.into_active_model();
145 active_user.password_hash = Set(create_password(&req.new_password)?);
146 active_user.password_issue_date = Set(Utc::now().naive_utc());
147
148 let user = active_user.update(&state.db).await?;
149 Ok(SuccessResponse(Account {
150 id: user.id,
151 username: user.username,
152 first_name: user.first_name,
153 last_name: user.last_name,
154 }))
155}
156
157#[derive(Deserialize)]
158struct DeleteUserRequest {
159 password: String,
160}
161
162async fn delete_account(
163 State(state): State<AppState>,
164 Auth(user): Auth,
165 ApiJson(req): ApiJson<DeleteUserRequest>,
166) -> ApiResult<Account> {
167 if !validate_password(&req.password, &user.password_hash)? {
168 return Err(ApiError::InvalidPassword);
169 }
170
171 user.clone().delete(&state.db).await?;
172
173 Ok(SuccessResponse(Account {
174 id: user.id,
175 username: user.username,
176 first_name: user.first_name,
177 last_name: user.last_name,
178 }))
179}
180
181pub(crate) fn router() -> Router<AppState> {
20 Router::new() 182 Router::new()
21 .route("/me", get(me)) 183 .route("/me", get(me))
22 .route("/register", post(register)) 184 .route("/register", post(register))
23 .route("/login", post(login)) 185 .route("/login", post(login))
186 .route("/change_password", put(change_password))
187 .route("/delete", delete(delete_account))
24} 188}
diff --git a/src/routers/mod.rs b/src/routers/mod.rs
index ee925f0..b57f71d 100644
--- a/src/routers/mod.rs
+++ b/src/routers/mod.rs
@@ -1,15 +1,9 @@
1mod account; 1mod account;
2 2
3use axum::Router; 3use axum::Router;
4use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer};
5use tracing::Level;
6 4
7pub(crate) fn router() -> Router { 5use crate::state::AppState;
8 let trace_layer = TraceLayer::new_for_http()
9 .on_request(DefaultOnRequest::new().level(Level::INFO))
10 .on_response(DefaultOnResponse::new().level(Level::INFO));
11 6
12 Router::new() 7pub(crate) fn router() -> Router<AppState> {
13 .layer(trace_layer) 8 Router::new().nest("/account", account::router())
14 .nest("/account", account::router())
15} 9}
diff --git a/src/state.rs b/src/state.rs
new file mode 100644
index 0000000..9779b46
--- /dev/null
+++ b/src/state.rs
@@ -0,0 +1,7 @@
1use sea_orm::DatabaseConnection;
2
3#[derive(Clone)]
4pub struct AppState {
5 pub db: DatabaseConnection,
6 pub secret: String,
7}