From 9747260aab211febd59f7ab655f15016412deb8e Mon Sep 17 00:00:00 2001 From: traxys Date: Mon, 29 May 2023 00:15:55 +0200 Subject: [PATCH] server: Implement handling of households --- Cargo.lock | 4 ++ Cargo.toml | 2 + api/Cargo.toml | 1 + api/src/lib.rs | 32 +++++++++ src/routes/household.rs | 149 ++++++++++++++++++++++++++++++++++++++++ src/routes/mod.rs | 40 +++++++++-- 6 files changed, 222 insertions(+), 6 deletions(-) create mode 100644 src/routes/household.rs diff --git a/Cargo.lock b/Cargo.lock index 8fc9c5a..eaf054d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,6 +62,7 @@ name = "api" version = "0.1.0" dependencies = [ "serde", + "uuid", ] [[package]] @@ -2254,6 +2255,7 @@ dependencies = [ "jwt-simple", "migration", "sea-orm", + "sea-query", "serde", "sha2", "thiserror", @@ -2261,6 +2263,7 @@ dependencies = [ "tower-http", "tracing", "tracing-subscriber", + "uuid", ] [[package]] @@ -3358,6 +3361,7 @@ version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "345444e32442451b267fc254ae85a209c64be56d2890e601a0c37ff0c3c5ecd2" dependencies = [ + "getrandom", "serde", ] diff --git a/Cargo.toml b/Cargo.toml index 814c01e..56892f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,8 @@ migration = { path = "./migration" } thiserror = "1.0.40" tower-http = { version = "0.4.0", features = ["cors", "fs"] } sha2 = "0.10" +uuid = { version = "1.3", features = ["v4"] } +sea-query = "0.28" [dependencies.sea-orm] version = "0.11" diff --git a/api/Cargo.toml b/api/Cargo.toml index 7efb64a..eb542db 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -7,3 +7,4 @@ edition = "2021" [dependencies] serde = { version = "1.0.163", features = ["derive"] } +uuid = "1.3.3" diff --git a/api/src/lib.rs b/api/src/lib.rs index b68cd6a..21b1e51 100644 --- a/api/src/lib.rs +++ b/api/src/lib.rs @@ -1,4 +1,7 @@ +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; +use uuid::Uuid; #[derive(Serialize, Deserialize)] pub struct LoginRequest { @@ -10,3 +13,32 @@ pub struct LoginRequest { pub struct LoginResponse { pub token: String, } + +#[derive(Serialize, Deserialize)] +pub struct EmptyResponse {} + +#[derive(Serialize, Deserialize)] +pub struct Household { + pub name: String, + pub members: Vec, +} + +#[derive(Serialize, Deserialize)] +pub struct Households { + pub households: HashMap, +} + +#[derive(Serialize, Deserialize)] +pub struct CreateHouseholdRequest { + pub name: String, +} + +#[derive(Serialize, Deserialize)] +pub struct CreateHouseholdResponse { + pub id: Uuid, +} + +#[derive(Serialize, Deserialize)] +pub struct AddToHouseholdRequest { + pub user: Uuid, +} diff --git a/src/routes/household.rs b/src/routes/household.rs new file mode 100644 index 0000000..e8dc37b --- /dev/null +++ b/src/routes/household.rs @@ -0,0 +1,149 @@ +use std::collections::HashMap; + +use axum::{ + async_trait, + extract::{FromRef, FromRequestParts, Path, State}, + http::request::Parts, + Json, +}; +use sea_orm::{prelude::*, ActiveValue}; +use sea_query::OnConflict; + +use api::{ + AddToHouseholdRequest, CreateHouseholdRequest, CreateHouseholdResponse, EmptyResponse, + Households, +}; + +use super::{AppState, AuthenticatedUser, RouteError}; +use crate::entity::{household, household_members, prelude::*}; + +#[derive(Debug)] +pub(super) struct AuthorizedHousehold(Uuid); + +#[async_trait] +impl FromRequestParts for AuthorizedHousehold +where + S: Send + Sync, + AppState: FromRef, +{ + type Rejection = RouteError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let State(app_state): State = State::from_request_parts(parts, state) + .await + .expect("Could not get state"); + + let user = AuthenticatedUser::from_request_parts(parts, state).await?; + + let Path(household): Path = Path::from_request_parts(parts, state).await?; + + let matching_count = user + .model + .find_related(Household) + .filter(household::Column::Id.eq(household)) + .count(&app_state.db) + .await?; + + match matching_count { + 0 => Err(RouteError::Unauthorized), + _ => Ok(AuthorizedHousehold(household)), + } + } +} + +pub(super) async fn list( + user: AuthenticatedUser, + state: State, +) -> super::JsonResult { + let related_households = user.model.find_related(Household).all(&state.db).await?; + + let mut households = HashMap::new(); + + for household in related_households { + let members = household.find_related(User).all(&state.db).await?; + households.insert( + household.id, + api::Household { + name: household.name, + members: members.into_iter().map(|m| m.id).collect(), + }, + ); + } + + Ok(Json(Households { households })) +} + +pub(super) async fn create( + user: AuthenticatedUser, + state: State, + Json(request): Json, +) -> super::JsonResult { + let household = household::ActiveModel { + name: ActiveValue::Set(request.name), + id: ActiveValue::Set(Uuid::new_v4()), + }; + + let household = household.insert(&state.db).await?; + + let member = household_members::ActiveModel { + household: ActiveValue::Set(household.id), + user: ActiveValue::Set(user.model.id), + }; + + member.insert(&state.db).await?; + + Ok(Json(CreateHouseholdResponse { id: household.id })) +} + +pub(super) async fn add_member( + AuthorizedHousehold(household): AuthorizedHousehold, + state: State, + Json(request): Json, +) -> super::JsonResult { + let member = household_members::ActiveModel { + household: ActiveValue::Set(household), + user: ActiveValue::Set(request.user), + }; + + if let Err(e) = HouseholdMembers::insert(member) + .on_conflict( + OnConflict::columns([ + household_members::Column::Household, + household_members::Column::User, + ]) + .do_nothing() + .to_owned(), + ) + .exec(&state.db) + .await + { + if !matches!(e, DbErr::RecordNotInserted) { + return Err(e.into()); + } + } + + Ok(Json(EmptyResponse {})) +} + +pub(super) async fn leave( + AuthorizedHousehold(household): AuthorizedHousehold, + user: AuthenticatedUser, + state: State, +) -> super::JsonResult { + HouseholdMembers::delete_by_id((household, user.model.id)) + .exec(&state.db) + .await?; + + let Some(household) = Household::find_by_id(household) + .one(&state.db) + .await? else { + return Ok(Json(EmptyResponse {})); + }; + + let member_count = household.find_related(User).count(&state.db).await?; + if member_count == 0 { + household.delete(&state.db).await?; + } + + Ok(Json(EmptyResponse {})) +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs index c275062..e4e9022 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -7,7 +7,7 @@ use axum::{ headers::{authorization::Bearer, Authorization}, http::{header::CONTENT_TYPE, request::Parts, HeaderValue, Method, StatusCode}, response::IntoResponse, - routing::post, + routing::{get, post, put}, Json, Router, TypedHeader, }; use jwt_simple::prelude::*; @@ -17,6 +17,8 @@ use tower_http::cors::{self, AllowOrigin, CorsLayer}; use crate::entity::{prelude::*, user}; +mod household; + #[derive(thiserror::Error, Debug)] enum RouteError { #[error("This account does not exist")] @@ -29,6 +31,10 @@ enum RouteError { UserJwt(jwt_simple::Error), #[error("Request is missing the bearer token")] MissingAuthorization, + #[error("User tried to edit an unauthorized ressource")] + Unauthorized, + #[error("Could not fetch required value from path")] + PathRejection(#[from] axum::extract::rejection::PathRejection), } impl IntoResponse for RouteError { @@ -44,6 +50,12 @@ impl IntoResponse for RouteError { tracing::debug!("Invalid user JWT: {e:?}"); (StatusCode::BAD_REQUEST, "Invalid authorization header").into_response() } + RouteError::PathRejection(p) => p.into_response(), + RouteError::Unauthorized => ( + StatusCode::UNAUTHORIZED, + "Unauthorized to access this ressource", + ) + .into_response(), e => { tracing::error!("Internal error: {e:?}"); StatusCode::INTERNAL_SERVER_ERROR.into_response() @@ -58,7 +70,7 @@ type AppState = Arc; #[derive(Debug)] struct AuthenticatedUser { - pub id: Uuid, + pub model: user::Model, } #[async_trait] @@ -85,9 +97,12 @@ where .verify_token::(bearer.token(), None) .map_err(RouteError::UserJwt)?; - Ok(AuthenticatedUser { - id: claims.subject.unwrap().parse().unwrap(), - }) + let model = User::find_by_id(claims.subject.unwrap().parse::().unwrap()) + .one(&app_state.db) + .await? + .unwrap(); + + Ok(AuthenticatedUser { model }) } } @@ -131,5 +146,18 @@ pub(crate) fn router(api_allowed: Option) -> Router { let mk_service = |m: Vec| cors_base.clone().allow_methods(m); - Router::new().route("/login", post(login).layer(mk_service(vec![Method::POST]))) + Router::new() + .route("/login", post(login).layer(mk_service(vec![Method::POST]))) + .route( + "/household", + get(household::list) + .post(household::create) + .layer(mk_service(vec![Method::GET, Method::POST])), + ) + .route( + "/household/:id", + put(household::add_member) + .delete(household::leave) + .layer(mk_service(vec![Method::PUT, Method::DELETE])), + ) }