Start of template refactor
This commit is contained in:
parent
6bdfe1ad0e
commit
e14c1f0918
8 changed files with 1612 additions and 531 deletions
1514
Cargo.lock
generated
1514
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
17
Cargo.toml
17
Cargo.toml
|
|
@ -9,7 +9,7 @@ members = [".", "api", "gui", "migration", "web"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.71"
|
anyhow = "1.0.71"
|
||||||
axum = { version = "0.6.18", features = ["headers"] }
|
axum = { version = "0.7.2" }
|
||||||
base64 = "0.21.0"
|
base64 = "0.21.0"
|
||||||
jwt-simple = "0.11.5"
|
jwt-simple = "0.11.5"
|
||||||
serde = { version = "1.0.163", features = ["derive"] }
|
serde = { version = "1.0.163", features = ["derive"] }
|
||||||
|
|
@ -19,15 +19,24 @@ tracing-subscriber = "0.3.17"
|
||||||
api = { path = "./api" }
|
api = { path = "./api" }
|
||||||
migration = { path = "./migration" }
|
migration = { path = "./migration" }
|
||||||
thiserror = "1.0.40"
|
thiserror = "1.0.40"
|
||||||
tower-http = { version = "0.4.0", features = ["cors", "fs"] }
|
tower-http = { version = "0.5.0", features = ["cors", "fs"] }
|
||||||
sha2 = "0.10"
|
sha2 = "0.10"
|
||||||
uuid = { version = "1.3", features = ["v4"] }
|
uuid = { version = "1.3", features = ["v4"] }
|
||||||
sea-query = "0.28"
|
sea-query = "0.30"
|
||||||
openidconnect = "3.3.0"
|
openidconnect = "3.3.0"
|
||||||
envious = "0.2.2"
|
envious = "0.2.2"
|
||||||
parking_lot = "0.12.1"
|
parking_lot = "0.12.1"
|
||||||
urlencoding = "2.1.3"
|
urlencoding = "2.1.3"
|
||||||
|
tower-sessions = { version = "0.8.2", features = [
|
||||||
|
"postgres-store",
|
||||||
|
"deletion-task",
|
||||||
|
] }
|
||||||
|
tower = "0.4.13"
|
||||||
|
time = "0.3.31"
|
||||||
|
maud = { git = "https://github.com/lambda-fairy/maud", version = "0.25.0", features = [
|
||||||
|
"axum",
|
||||||
|
] }
|
||||||
|
|
||||||
[dependencies.sea-orm]
|
[dependencies.sea-orm]
|
||||||
version = "0.11"
|
version = "0.12"
|
||||||
features = ["runtime-tokio-rustls", "sqlx-postgres", "sqlx-sqlite"]
|
features = ["runtime-tokio-rustls", "sqlx-postgres", "sqlx-sqlite"]
|
||||||
|
|
|
||||||
|
|
@ -12,5 +12,5 @@ path = "src/lib.rs"
|
||||||
async-std = { version = "1", features = ["attributes", "tokio1"] }
|
async-std = { version = "1", features = ["attributes", "tokio1"] }
|
||||||
|
|
||||||
[dependencies.sea-orm-migration]
|
[dependencies.sea-orm-migration]
|
||||||
version = "0.11.0"
|
version = "0.12"
|
||||||
features = ["runtime-tokio-rustls", "sqlx-postgres", "sqlx-sqlite"]
|
features = ["runtime-tokio-rustls", "sqlx-postgres", "sqlx-sqlite"]
|
||||||
|
|
|
||||||
11
public/household_selection.css
Normal file
11
public/household_selection.css
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
html,
|
||||||
|
body {
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
padding-top: 40px;
|
||||||
|
padding-bottom: 40px;
|
||||||
|
}
|
||||||
159
src/app/household.rs
Normal file
159
src/app/household.rs
Normal file
|
|
@ -0,0 +1,159 @@
|
||||||
|
use axum::{
|
||||||
|
async_trait,
|
||||||
|
extract::{FromRef, FromRequestParts, State},
|
||||||
|
http::request::Parts,
|
||||||
|
response::Redirect,
|
||||||
|
Form,
|
||||||
|
};
|
||||||
|
use maud::{html, Markup};
|
||||||
|
use sea_orm::{prelude::*, ActiveValue};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tower_sessions::Session;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::entity::{household, household_members, prelude::*};
|
||||||
|
|
||||||
|
use super::{base_page_with_head, AppState, AuthenticatedUser, RedirectOrError, RouteError};
|
||||||
|
|
||||||
|
pub(super) struct CurrentHousehold(pub household::Model);
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S> FromRequestParts<S> for CurrentHousehold
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
AppState: FromRef<S>,
|
||||||
|
{
|
||||||
|
type Rejection = RedirectOrError;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
let State(app_state): State<AppState> = State::from_request_parts(parts, state)
|
||||||
|
.await
|
||||||
|
.expect("Could not get state");
|
||||||
|
|
||||||
|
let session = Session::from_request_parts(parts, state)
|
||||||
|
.await
|
||||||
|
.map_err(|_| RouteError::SessionExtract)?;
|
||||||
|
|
||||||
|
let id: Uuid = session
|
||||||
|
.get("household")
|
||||||
|
.await
|
||||||
|
.map_err(RouteError::from)?
|
||||||
|
.ok_or_else(|| Redirect::to("/household/select"))?;
|
||||||
|
|
||||||
|
Ok(Self(
|
||||||
|
Household::find_by_id(id)
|
||||||
|
.one(&app_state.db)
|
||||||
|
.await
|
||||||
|
.map_err(RouteError::from)?
|
||||||
|
.unwrap(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_household_modal() -> Markup {
|
||||||
|
html! {
|
||||||
|
button .btn.btn-lg.btn-primary data-bs-toggle="modal" data-bs-target="#newHsModal" {
|
||||||
|
"New household"
|
||||||
|
}
|
||||||
|
|
||||||
|
.modal.fade
|
||||||
|
#newHsModal
|
||||||
|
tabindex="-1"
|
||||||
|
aria-labelledby="newHsModalLabel"
|
||||||
|
aria-hidden="true" {
|
||||||
|
.modal-dialog.modal-dialog-centered {
|
||||||
|
.modal-content {
|
||||||
|
.modal-header {
|
||||||
|
h1 .modal-title."fs-5" #newHsModalLabel { "New Household" }
|
||||||
|
input
|
||||||
|
type="reset"
|
||||||
|
form="newHsModalForm"
|
||||||
|
.btn-close
|
||||||
|
data-bs-dismiss="modal"
|
||||||
|
aria-label="Close"
|
||||||
|
value="";
|
||||||
|
}
|
||||||
|
.modal-body {
|
||||||
|
form #newHsModalForm method="post" action="/household/create" {
|
||||||
|
.form-floating {
|
||||||
|
input
|
||||||
|
.form-control
|
||||||
|
#newHsName
|
||||||
|
placeholder="Household name"
|
||||||
|
name="name";
|
||||||
|
label for="newHsName" { "Household name" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.modal-footer {
|
||||||
|
input
|
||||||
|
type="reset"
|
||||||
|
form="newHsModalForm"
|
||||||
|
.btn.btn-danger
|
||||||
|
data-bs-dismiss="modal"
|
||||||
|
value="Cancel";
|
||||||
|
input
|
||||||
|
type="submit"
|
||||||
|
form="newHsModalForm"
|
||||||
|
.btn.btn-primary
|
||||||
|
data-bs-dismiss="modal"
|
||||||
|
value="Create";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn household_selection(
|
||||||
|
state: State<AppState>,
|
||||||
|
user: AuthenticatedUser,
|
||||||
|
) -> Result<Markup, RouteError> {
|
||||||
|
let related_households = user.model.find_related(Household).all(&state.db).await?;
|
||||||
|
|
||||||
|
Ok(base_page_with_head(
|
||||||
|
html! {
|
||||||
|
div ."col-sm-3".m-auto."p-2".text-center.border.rounded {
|
||||||
|
h1 .h3 { "Available" }
|
||||||
|
hr;
|
||||||
|
@for hs in related_households {
|
||||||
|
a href={"/household/select/" (hs.id)} .btn.btn-secondary."m-1" {
|
||||||
|
(hs.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hr;
|
||||||
|
(new_household_modal())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Some(html! {
|
||||||
|
link href="/household_selection.css" rel="stylesheet";
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub(super) struct CreateHousehold {
|
||||||
|
name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn create_household(
|
||||||
|
user: AuthenticatedUser,
|
||||||
|
state: State<AppState>,
|
||||||
|
Form(form): Form<CreateHousehold>,
|
||||||
|
) -> Result<Redirect, RouteError> {
|
||||||
|
let household = household::ActiveModel {
|
||||||
|
name: ActiveValue::Set(form.name),
|
||||||
|
id: ActiveValue::Set(Uuid::new_v4()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let household = household.insert(&state.db).await?;
|
||||||
|
|
||||||
|
let member = household_members::ActiveModel {
|
||||||
|
household: ActiveValue::Set(household.id),
|
||||||
|
user: ActiveValue::Set(user.model.id),
|
||||||
|
};
|
||||||
|
|
||||||
|
member.insert(&state.db).await?;
|
||||||
|
|
||||||
|
Ok(Redirect::to("/household/select"))
|
||||||
|
}
|
||||||
315
src/app/mod.rs
Normal file
315
src/app/mod.rs
Normal file
|
|
@ -0,0 +1,315 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
async_trait,
|
||||||
|
extract::{FromRef, FromRequestParts, Path, Query, State},
|
||||||
|
handler::HandlerWithoutStateExt,
|
||||||
|
http::{request::Parts, StatusCode},
|
||||||
|
response::{IntoResponse, Redirect},
|
||||||
|
routing::{get, post},
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use maud::{html, Markup};
|
||||||
|
use sea_orm::{prelude::*, ActiveValue, DbErr, TransactionError};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tower_http::services::ServeDir;
|
||||||
|
use tower_sessions::{session, Session};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::entity::{prelude::*, user};
|
||||||
|
|
||||||
|
use self::household::CurrentHousehold;
|
||||||
|
|
||||||
|
mod household;
|
||||||
|
|
||||||
|
type AppState = Arc<crate::AppState>;
|
||||||
|
|
||||||
|
pub fn base_page_with_head(body: Markup, head: Option<Markup>) -> Markup {
|
||||||
|
html! {
|
||||||
|
(maud::DOCTYPE)
|
||||||
|
html lang="en" data-bs-theme="dark" {
|
||||||
|
head {
|
||||||
|
meta charset="utf-8";
|
||||||
|
meta name="viewport" content="width=device-width, initial-scale=1";
|
||||||
|
title { "Regalade" }
|
||||||
|
link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css"
|
||||||
|
rel="stylesheet"
|
||||||
|
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN"
|
||||||
|
crossorigin="anonymous";
|
||||||
|
@if let Some(head) = head {
|
||||||
|
(head)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body {
|
||||||
|
(body)
|
||||||
|
script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/js/bootstrap.bundle.min.js"
|
||||||
|
integrity="sha384-C6RzsynM9kWDrMNeT87bh95OGNyZPhcTNXj1NW7RuBCsyN/o0jlpcV8Qyq46cDfL"
|
||||||
|
crossorigin="anonymous" {}
|
||||||
|
script src="https://unpkg.com/htmx.org@1.9.10"
|
||||||
|
integrity="sha384-D1Kt99CQMDuVetoL1lrYwg5t+9QdHe7NLX/SoJYkXDFfX37iInKRy5xLSi8nO7UC"
|
||||||
|
crossorigin="anonymous" {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn base_page(body: Markup) -> Markup {
|
||||||
|
base_page_with_head(body, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn error_alert(message: Option<impl maud::Render>) -> Markup {
|
||||||
|
html! {
|
||||||
|
@if let Some(msg) = message {
|
||||||
|
.alert.alert-danger { (msg) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub fn base_page_with_error(error: Option<Markup>, body: Markup) -> Markup {
|
||||||
|
// base_page_with_head(
|
||||||
|
// html! {
|
||||||
|
// @if let Some(e) = error {
|
||||||
|
// (e)
|
||||||
|
// }
|
||||||
|
// (body)
|
||||||
|
// },
|
||||||
|
// None,
|
||||||
|
// )
|
||||||
|
// }
|
||||||
|
|
||||||
|
pub fn error_page(code: StatusCode, message: impl maud::Render) -> (StatusCode, Markup) {
|
||||||
|
(
|
||||||
|
code,
|
||||||
|
base_page(html! {
|
||||||
|
div .container.text-center {
|
||||||
|
h1 { "Regalade" }
|
||||||
|
h2 { "Error" }
|
||||||
|
hr;
|
||||||
|
(message)
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
enum RouteError {
|
||||||
|
#[error("Database encountered an error")]
|
||||||
|
Db(#[from] DbErr),
|
||||||
|
#[error("Failure to login with OAuth2")]
|
||||||
|
Oauth2Failure,
|
||||||
|
#[error("Could not fetch required value from path")]
|
||||||
|
PathRejection(#[from] axum::extract::rejection::PathRejection),
|
||||||
|
#[error("The supplied ressource does not exist")]
|
||||||
|
RessourceNotFound,
|
||||||
|
#[error("The request was malformed")]
|
||||||
|
InvalidRequest(String),
|
||||||
|
#[error("Error in DB transaction")]
|
||||||
|
TxnError(#[from] TransactionError<Box<RouteError>>),
|
||||||
|
#[error("Error in session management")]
|
||||||
|
Session(#[from] session::Error),
|
||||||
|
#[error("Could not extract session")]
|
||||||
|
SessionExtract,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<DbErr> for Box<RouteError> {
|
||||||
|
fn from(value: DbErr) -> Self {
|
||||||
|
Box::new(value.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for RouteError {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
match self {
|
||||||
|
RouteError::TxnError(TransactionError::Transaction(e)) => e.into_response(),
|
||||||
|
RouteError::PathRejection(p) => error_page(p.status(), p.body_text()).into_response(),
|
||||||
|
RouteError::Oauth2Failure => {
|
||||||
|
error_page(StatusCode::BAD_REQUEST, "Failure to login with OAuth2").into_response()
|
||||||
|
}
|
||||||
|
RouteError::RessourceNotFound => not_found().into_response(),
|
||||||
|
RouteError::InvalidRequest(reason) => {
|
||||||
|
error_page(StatusCode::BAD_REQUEST, reason).into_response()
|
||||||
|
}
|
||||||
|
e => {
|
||||||
|
tracing::error!("Internal error: {e:?}");
|
||||||
|
error_page(StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn oidc_login(State(state): State<AppState>) -> Redirect {
|
||||||
|
tracing::info!("Starting OIDC login");
|
||||||
|
let oidc = state.oidc.as_ref().unwrap();
|
||||||
|
|
||||||
|
let redirect_url = oidc.start_auth();
|
||||||
|
|
||||||
|
Redirect::to(redirect_url.as_str())
|
||||||
|
}
|
||||||
|
|
||||||
|
enum RedirectOrError {
|
||||||
|
Redirect(Redirect),
|
||||||
|
Err(RouteError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Redirect> for RedirectOrError {
|
||||||
|
fn from(value: Redirect) -> Self {
|
||||||
|
Self::Redirect(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<RouteError> for RedirectOrError {
|
||||||
|
fn from(value: RouteError) -> Self {
|
||||||
|
Self::Err(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for RedirectOrError {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
match self {
|
||||||
|
RedirectOrError::Redirect(r) => r.into_response(),
|
||||||
|
RedirectOrError::Err(e) => e.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AuthenticatedUser {
|
||||||
|
pub model: user::Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S> FromRequestParts<S> for AuthenticatedUser
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
AppState: FromRef<S>,
|
||||||
|
{
|
||||||
|
type Rejection = RedirectOrError;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
let State(app_state): State<AppState> = State::from_request_parts(parts, state)
|
||||||
|
.await
|
||||||
|
.expect("Could not get state");
|
||||||
|
|
||||||
|
let session = Session::from_request_parts(parts, state)
|
||||||
|
.await
|
||||||
|
.map_err(|_| RouteError::SessionExtract)?;
|
||||||
|
|
||||||
|
let id: Uuid = session
|
||||||
|
.get("id")
|
||||||
|
.await
|
||||||
|
.map_err(RouteError::from)?
|
||||||
|
.ok_or_else(|| Redirect::to("/login"))?;
|
||||||
|
|
||||||
|
let model = User::find_by_id(id)
|
||||||
|
.one(&app_state.db)
|
||||||
|
.await
|
||||||
|
.map_err(RouteError::from)?
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(Self { model })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OidcRedirectParams {
|
||||||
|
state: String,
|
||||||
|
code: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn oidc_login_finish(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<Uuid>,
|
||||||
|
Query(redirect): Query<OidcRedirectParams>,
|
||||||
|
session: Session,
|
||||||
|
) -> Result<Redirect, RouteError> {
|
||||||
|
match state
|
||||||
|
.oidc
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.redirected(id, redirect.state, redirect.code)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Error when finishing OAuth2 flow: {e:#?}");
|
||||||
|
Err(RouteError::Oauth2Failure)
|
||||||
|
}
|
||||||
|
Ok(account) => {
|
||||||
|
let user = User::find()
|
||||||
|
.filter(
|
||||||
|
user::Column::Name
|
||||||
|
.eq(&account.name)
|
||||||
|
.or(user::Column::OpenIdSubject.eq(&account.sub)),
|
||||||
|
)
|
||||||
|
.one(&state.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let user = match user {
|
||||||
|
None => {
|
||||||
|
let model = user::ActiveModel {
|
||||||
|
id: ActiveValue::Set(Uuid::new_v4()),
|
||||||
|
name: ActiveValue::Set(account.name),
|
||||||
|
password: ActiveValue::NotSet,
|
||||||
|
open_id_subject: ActiveValue::Set(Some(account.sub)),
|
||||||
|
};
|
||||||
|
|
||||||
|
model.insert(&state.db).await?
|
||||||
|
}
|
||||||
|
Some(user) => user,
|
||||||
|
};
|
||||||
|
|
||||||
|
session.insert("id", user.id).await?;
|
||||||
|
|
||||||
|
Ok(Redirect::to("/"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn not_found() -> (StatusCode, Markup) {
|
||||||
|
error_page(StatusCode::NOT_FOUND, "Page not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn index(_user: AuthenticatedUser, household: CurrentHousehold) -> Markup {
|
||||||
|
base_page(html! {
|
||||||
|
"Hello world in " (household.0.name) "!"
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// #[derive(Serialize, Deserialize, Debug)]
|
||||||
|
// enum UserError {
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// #[derive(Debug, Serialize, Deserialize)]
|
||||||
|
// struct ErrorQuery {
|
||||||
|
// #[serde(default)]
|
||||||
|
// err: Option<UserError>,
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// impl Render for ErrorQuery {
|
||||||
|
// fn render(&self) -> Markup {
|
||||||
|
// let err = match &self.err {
|
||||||
|
// Some(e) => e,
|
||||||
|
// None => return html! {},
|
||||||
|
// };
|
||||||
|
//
|
||||||
|
// match err {
|
||||||
|
// _ => todo!(),
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
pub(crate) fn router() -> Router<AppState> {
|
||||||
|
let router = Router::new();
|
||||||
|
|
||||||
|
let public = option_env!("REGALADE_PUBLIC_DIR")
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.unwrap_or_else(|| std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("public"));
|
||||||
|
|
||||||
|
tracing::debug!("Public directory: {public:?}");
|
||||||
|
|
||||||
|
router
|
||||||
|
.route("/", get(index))
|
||||||
|
.route("/login", get(oidc_login))
|
||||||
|
.route("/household/select", get(household::household_selection))
|
||||||
|
.route("/household/create", post(household::create_household))
|
||||||
|
.route("/login/redirect/:id", get(oidc_login_finish))
|
||||||
|
.fallback_service(ServeDir::new(public).fallback((|| async { not_found() }).into_service()))
|
||||||
|
}
|
||||||
59
src/main.rs
59
src/main.rs
|
|
@ -6,7 +6,7 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use axum::Router;
|
use axum::{error_handling::HandleErrorLayer, http::StatusCode, BoxError, Router};
|
||||||
use base64::{engine::general_purpose, Engine};
|
use base64::{engine::general_purpose, Engine};
|
||||||
use jwt_simple::prelude::HS256Key;
|
use jwt_simple::prelude::HS256Key;
|
||||||
use migration::{Migrator, MigratorTrait};
|
use migration::{Migrator, MigratorTrait};
|
||||||
|
|
@ -18,13 +18,19 @@ use openidconnect::{
|
||||||
};
|
};
|
||||||
use sea_orm::{ConnectOptions, Database, DatabaseConnection};
|
use sea_orm::{ConnectOptions, Database, DatabaseConnection};
|
||||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
use tower::ServiceBuilder;
|
||||||
use tower_http::services::{ServeDir, ServeFile};
|
use tower_http::services::{ServeDir, ServeFile};
|
||||||
|
use tower_sessions::{sqlx::PgPool, ExpiredDeletion, PostgresStore, SessionManagerLayer};
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
mod app;
|
||||||
pub(crate) mod entity;
|
pub(crate) mod entity;
|
||||||
mod routes;
|
mod routes;
|
||||||
|
|
||||||
|
const SESSION_DURATION: time::Duration = time::Duration::weeks(26);
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct Base64(pub(crate) HS256Key);
|
pub(crate) struct Base64(pub(crate) HS256Key);
|
||||||
|
|
||||||
|
|
@ -156,6 +162,7 @@ struct AppState {
|
||||||
jwt_secret: Base64,
|
jwt_secret: Base64,
|
||||||
db: DatabaseConnection,
|
db: DatabaseConnection,
|
||||||
oidc: Option<OpenidConnector>,
|
oidc: Option<OpenidConnector>,
|
||||||
|
sessions: Arc<PostgresStore>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct OpenidConnector {
|
struct OpenidConnector {
|
||||||
|
|
@ -225,13 +232,11 @@ struct OpenidAuthState {
|
||||||
pkce_verifier: PkceCodeVerifier,
|
pkce_verifier: PkceCodeVerifier,
|
||||||
csrf_token: CsrfToken,
|
csrf_token: CsrfToken,
|
||||||
nonce: Nonce,
|
nonce: Nonce,
|
||||||
source_url: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct OpenidAccount {
|
pub struct OpenidAccount {
|
||||||
pub sub: String,
|
pub sub: String,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub source_url: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenidConnector {
|
impl OpenidConnector {
|
||||||
|
|
@ -253,7 +258,7 @@ impl OpenidConnector {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn start_auth(&self, source_url: String) -> Url {
|
pub fn start_auth(&self) -> Url {
|
||||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
|
|
||||||
let mut inflight = self.inflight.lock();
|
let mut inflight = self.inflight.lock();
|
||||||
|
|
@ -273,7 +278,7 @@ impl OpenidConnector {
|
||||||
.set_pkce_challenge(pkce_challenge)
|
.set_pkce_challenge(pkce_challenge)
|
||||||
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
|
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
|
||||||
self.domain
|
self.domain
|
||||||
.join(&format!("api/login/redirect/{}", slot.id))
|
.join(&format!("login/redirect/{}", slot.id))
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
)))
|
)))
|
||||||
.url();
|
.url();
|
||||||
|
|
@ -282,7 +287,6 @@ impl OpenidConnector {
|
||||||
pkce_verifier,
|
pkce_verifier,
|
||||||
csrf_token,
|
csrf_token,
|
||||||
nonce,
|
nonce,
|
||||||
source_url,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
url
|
url
|
||||||
|
|
@ -312,9 +316,7 @@ impl OpenidConnector {
|
||||||
.exchange_code(AuthorizationCode::new(code))
|
.exchange_code(AuthorizationCode::new(code))
|
||||||
.set_pkce_verifier(state.pkce_verifier)
|
.set_pkce_verifier(state.pkce_verifier)
|
||||||
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
|
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
|
||||||
self.domain
|
self.domain.join(&format!("login/redirect/{}", id)).unwrap(),
|
||||||
.join(&format!("api/login/redirect/{}", id))
|
|
||||||
.unwrap(),
|
|
||||||
)))
|
)))
|
||||||
.request_async(openidconnect::reqwest::async_http_client)
|
.request_async(openidconnect::reqwest::async_http_client)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
@ -344,7 +346,6 @@ impl OpenidConnector {
|
||||||
.map(|v| v.to_string())
|
.map(|v| v.to_string())
|
||||||
.or(claims.preferred_username().map(|v| v.to_string()))
|
.or(claims.preferred_username().map(|v| v.to_string()))
|
||||||
.ok_or_else(|| anyhow!("No name or preferred_username"))?,
|
.ok_or_else(|| anyhow!("No name or preferred_username"))?,
|
||||||
source_url: state.source_url,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -363,7 +364,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
|
|
||||||
let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
|
let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
|
||||||
|
|
||||||
let mut opt = ConnectOptions::new(config.database_url);
|
let mut opt = ConnectOptions::new(&config.database_url);
|
||||||
opt.sqlx_logging(config.sqlx_logging);
|
opt.sqlx_logging(config.sqlx_logging);
|
||||||
|
|
||||||
let oidc = match config.oidc {
|
let oidc = match config.oidc {
|
||||||
|
|
@ -371,9 +372,32 @@ async fn main() -> anyhow::Result<()> {
|
||||||
Some(settings) => Some(OpenidConnector::new(settings).await?),
|
Some(settings) => Some(OpenidConnector::new(settings).await?),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let sessions = PostgresStore::new(PgPool::connect(&config.database_url).await?);
|
||||||
|
sessions.migrate().await?;
|
||||||
|
|
||||||
|
let deletion_task = tokio::task::spawn(
|
||||||
|
sessions
|
||||||
|
.clone()
|
||||||
|
.continuously_delete_expired(tokio::time::Duration::from_secs(60)),
|
||||||
|
);
|
||||||
|
|
||||||
|
let session_service = ServiceBuilder::new()
|
||||||
|
.layer(HandleErrorLayer::new(|_: BoxError| async {
|
||||||
|
StatusCode::BAD_REQUEST
|
||||||
|
}))
|
||||||
|
.layer(
|
||||||
|
SessionManagerLayer::new(sessions.clone())
|
||||||
|
.with_secure(false)
|
||||||
|
.with_same_site(tower_sessions::cookie::SameSite::Lax)
|
||||||
|
.with_expiry(tower_sessions::Expiry::AtDateTime(
|
||||||
|
OffsetDateTime::now_utc().saturating_add(SESSION_DURATION),
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
|
||||||
let state = Arc::new(AppState {
|
let state = Arc::new(AppState {
|
||||||
jwt_secret: config.jwt_secret,
|
jwt_secret: config.jwt_secret,
|
||||||
db: Database::connect(opt).await?,
|
db: Database::connect(opt).await?,
|
||||||
|
sessions: sessions.into(),
|
||||||
oidc,
|
oidc,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -387,7 +411,9 @@ async fn main() -> anyhow::Result<()> {
|
||||||
state.oidc.is_some(),
|
state.oidc.is_some(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_state(state);
|
.merge(app::router())
|
||||||
|
.with_state(state)
|
||||||
|
.layer(session_service);
|
||||||
|
|
||||||
let router = match config.serve_app {
|
let router = match config.serve_app {
|
||||||
None => router,
|
None => router,
|
||||||
|
|
@ -398,7 +424,10 @@ async fn main() -> anyhow::Result<()> {
|
||||||
|
|
||||||
tracing::info!("Listening on {addr}");
|
tracing::info!("Listening on {addr}");
|
||||||
|
|
||||||
Ok(axum::Server::bind(&addr)
|
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||||
.serve(router.into_make_service())
|
axum::serve(listener, router).await?;
|
||||||
.await?)
|
|
||||||
|
deletion_task.await??;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ use api::{LoginRequest, LoginResponse, UserInfo};
|
||||||
use axum::{
|
use axum::{
|
||||||
async_trait,
|
async_trait,
|
||||||
extract::{FromRef, FromRequestParts, Path, Query, State},
|
extract::{FromRef, FromRequestParts, Path, Query, State},
|
||||||
headers::{authorization::Bearer, Authorization},
|
//headers::{authorization::Bearer, Authorization},
|
||||||
http::{
|
http::{
|
||||||
header::{AUTHORIZATION, CONTENT_TYPE},
|
header::{AUTHORIZATION, CONTENT_TYPE},
|
||||||
request::Parts,
|
request::Parts,
|
||||||
|
|
@ -12,7 +12,8 @@ use axum::{
|
||||||
},
|
},
|
||||||
response::{IntoResponse, Redirect},
|
response::{IntoResponse, Redirect},
|
||||||
routing::{delete, get, patch, post, put},
|
routing::{delete, get, patch, post, put},
|
||||||
Json, Router, TypedHeader,
|
Json,
|
||||||
|
Router,
|
||||||
};
|
};
|
||||||
use jwt_simple::prelude::*;
|
use jwt_simple::prelude::*;
|
||||||
use sea_orm::{prelude::*, ActiveValue, TransactionError};
|
use sea_orm::{prelude::*, ActiveValue, TransactionError};
|
||||||
|
|
@ -108,27 +109,28 @@ where
|
||||||
type Rejection = RouteError;
|
type Rejection = RouteError;
|
||||||
|
|
||||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
let State(app_state): State<AppState> = State::from_request_parts(parts, state)
|
// let State(app_state): State<AppState> = State::from_request_parts(parts, state)
|
||||||
.await
|
// .await
|
||||||
.expect("Could not get state");
|
// .expect("Could not get state");
|
||||||
|
//
|
||||||
let TypedHeader(Authorization(bearer)) =
|
// let TypedHeader(Authorization(bearer)) =
|
||||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
|
// TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
|
||||||
.await
|
// .await
|
||||||
.map_err(|_| RouteError::MissingAuthorization)?;
|
// .map_err(|_| RouteError::MissingAuthorization)?;
|
||||||
|
//
|
||||||
let claims = app_state
|
// let claims = app_state
|
||||||
.jwt_secret
|
// .jwt_secret
|
||||||
.0
|
// .0
|
||||||
.verify_token::<NoCustomClaims>(bearer.token(), None)
|
// .verify_token::<NoCustomClaims>(bearer.token(), None)
|
||||||
.map_err(RouteError::UserJwt)?;
|
// .map_err(RouteError::UserJwt)?;
|
||||||
|
//
|
||||||
let model = User::find_by_id(claims.subject.unwrap().parse::<Uuid>().unwrap())
|
// let model = User::find_by_id(claims.subject.unwrap().parse::<Uuid>().unwrap())
|
||||||
.one(&app_state.db)
|
// .one(&app_state.db)
|
||||||
.await?
|
// .await?
|
||||||
.unwrap();
|
// .unwrap();
|
||||||
|
//
|
||||||
Ok(AuthenticatedUser { model })
|
// Ok(AuthenticatedUser { model })
|
||||||
|
todo!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -169,14 +171,11 @@ struct OidcStartParam {
|
||||||
r#return: String,
|
r#return: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn oidc_login(
|
async fn oidc_login(State(state): State<AppState>) -> Result<Redirect, RouteError> {
|
||||||
State(state): State<AppState>,
|
|
||||||
Query(start): Query<OidcStartParam>,
|
|
||||||
) -> Result<Redirect, RouteError> {
|
|
||||||
tracing::info!("Starting OIDC login");
|
tracing::info!("Starting OIDC login");
|
||||||
let oidc = state.oidc.as_ref().unwrap();
|
let oidc = state.oidc.as_ref().unwrap();
|
||||||
|
|
||||||
let redirect_url = oidc.start_auth(start.r#return);
|
let redirect_url = oidc.start_auth();
|
||||||
|
|
||||||
Ok(Redirect::to(redirect_url.as_str()))
|
Ok(Redirect::to(redirect_url.as_str()))
|
||||||
}
|
}
|
||||||
|
|
@ -236,16 +235,9 @@ async fn oidc_login_finish(
|
||||||
let mut claims = Claims::create(Duration::from_secs(3600 * 24 * 31 * 6));
|
let mut claims = Claims::create(Duration::from_secs(3600 * 24 * 31 * 6));
|
||||||
claims.subject = Some(user.id.to_string());
|
claims.subject = Some(user.id.to_string());
|
||||||
|
|
||||||
let token = state.jwt_secret.0.authenticate(claims)?;
|
let _token = state.jwt_secret.0.authenticate(claims)?;
|
||||||
|
|
||||||
let redirect = format!(
|
panic!("Oidc login app only");
|
||||||
"{}/{}---{}",
|
|
||||||
account.source_url,
|
|
||||||
urlencoding::encode(&user.name),
|
|
||||||
urlencoding::encode(&token),
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(Redirect::to(&redirect))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue