use std::sync::Arc; use axum::{ async_trait, extract::{FromRef, FromRequestParts, Path, Query, State}, handler::HandlerWithoutStateExt, http::{request::Parts, StatusCode}, response::{IntoResponse, Redirect}, routing::{get, post}, Router, }; use maud::{html, Markup}; use sea_orm::{prelude::*, ActiveValue, DbErr, TransactionError}; use serde::Deserialize; use tower_http::services::ServeDir; use tower_sessions::{session, Session}; use uuid::Uuid; use crate::entity::{prelude::*, user}; use self::{household::CurrentHousehold, sidebar::SidebarLocation}; mod household; mod sidebar; type AppState = Arc; pub fn base_page_with_head(body: Markup, head: Option) -> Markup { html! { (maud::DOCTYPE) html lang="en" data-bs-theme="dark" { head { meta charset="utf-8"; meta name="viewport" content="width=device-width, initial-scale=1"; title { "Regalade" } link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous"; link rel="stylesheet" href="/regalade.css"; link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap-icons@1.11.3/font/bootstrap-icons.min.css" integrity="sha384-XGjxtQfXaH2tnPFa9x+ruJTuLE3Aa6LhHSWRr1XeTyhezb4abCG4ccI5AkVDxqC+" crossorigin="anonymous"; @if let Some(head) = head { (head) } } body { (body) script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/js/bootstrap.bundle.min.js" integrity="sha384-C6RzsynM9kWDrMNeT87bh95OGNyZPhcTNXj1NW7RuBCsyN/o0jlpcV8Qyq46cDfL" crossorigin="anonymous" {} script src="https://unpkg.com/htmx.org@1.9.10" integrity="sha384-D1Kt99CQMDuVetoL1lrYwg5t+9QdHe7NLX/SoJYkXDFfX37iInKRy5xLSi8nO7UC" crossorigin="anonymous" {} } } } } pub fn base_page(body: Markup) -> Markup { base_page_with_head(body, None) } pub fn error_alert(message: Option) -> Markup { html! { @if let Some(msg) = message { .alert.alert-danger { (msg) } } } } // pub fn base_page_with_error(error: Option, body: Markup) -> Markup { // base_page_with_head( // html! { // @if let Some(e) = error { // (e) // } // (body) // }, // None, // ) // } pub fn error_page(code: StatusCode, message: impl maud::Render) -> (StatusCode, Markup) { ( code, base_page(html! { div .container.text-center { h1 { "Regalade" } h2 { "Error" } hr; (message) } }), ) } #[derive(thiserror::Error, Debug)] enum RouteError { #[error("Database encountered an error")] Db(#[from] DbErr), #[error("Failure to login with OAuth2")] Oauth2Failure, #[error("Could not fetch required value from path")] PathRejection(#[from] axum::extract::rejection::PathRejection), #[error("The supplied ressource does not exist")] RessourceNotFound, #[error("The request was malformed")] InvalidRequest(String), #[error("Error in DB transaction")] TxnError(#[from] TransactionError>), #[error("Error in session management")] Session(#[from] session::Error), #[error("Could not extract session")] SessionExtract, } impl From for Box { fn from(value: DbErr) -> Self { Box::new(value.into()) } } impl From> for RouteError { fn from(value: TransactionError) -> Self { match value { TransactionError::Connection(e) => e.into(), TransactionError::Transaction(e) => e, } } } impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { match self { RouteError::TxnError(TransactionError::Transaction(e)) => e.into_response(), RouteError::PathRejection(p) => error_page(p.status(), p.body_text()).into_response(), RouteError::Oauth2Failure => { error_page(StatusCode::BAD_REQUEST, "Failure to login with OAuth2").into_response() } RouteError::RessourceNotFound => not_found().into_response(), RouteError::InvalidRequest(reason) => { error_page(StatusCode::BAD_REQUEST, reason).into_response() } e => { tracing::error!("Internal error: {e:?}"); error_page(StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() } } } } async fn oidc_login(State(state): State) -> Redirect { tracing::info!("Starting OIDC login"); let oidc = state.oidc.as_ref().unwrap(); let redirect_url = oidc.start_auth(); Redirect::to(redirect_url.as_str()) } enum RedirectOrError { Redirect(Redirect), Err(RouteError), } impl From for RedirectOrError { fn from(value: Redirect) -> Self { Self::Redirect(value) } } impl From for RedirectOrError { fn from(value: RouteError) -> Self { Self::Err(value) } } impl IntoResponse for RedirectOrError { fn into_response(self) -> axum::response::Response { match self { RedirectOrError::Redirect(r) => r.into_response(), RedirectOrError::Err(e) => e.into_response(), } } } struct AuthenticatedUser { pub model: user::Model, } #[async_trait] impl FromRequestParts for AuthenticatedUser where S: Send + Sync, AppState: FromRef, { type Rejection = RedirectOrError; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let State(app_state): State = State::from_request_parts(parts, state) .await .expect("Could not get state"); let session = Session::from_request_parts(parts, state) .await .map_err(|_| RouteError::SessionExtract)?; let id: Uuid = session .get("id") .await .map_err(RouteError::from)? .ok_or_else(|| Redirect::to("/login"))?; let model = User::find_by_id(id) .one(&app_state.db) .await .map_err(RouteError::from)? .unwrap(); Ok(Self { model }) } } #[derive(Deserialize)] struct OidcRedirectParams { state: String, code: String, } async fn oidc_login_finish( State(state): State, Path(id): Path, Query(redirect): Query, session: Session, ) -> Result { match state .oidc .as_ref() .unwrap() .redirected(id, redirect.state, redirect.code) .await { Err(e) => { tracing::error!("Error when finishing OAuth2 flow: {e:#?}"); Err(RouteError::Oauth2Failure) } Ok(account) => { let user = User::find() .filter( user::Column::Name .eq(&account.name) .or(user::Column::OpenIdSubject.eq(&account.sub)), ) .one(&state.db) .await?; let user = match user { None => { let model = user::ActiveModel { id: ActiveValue::Set(Uuid::new_v4()), name: ActiveValue::Set(account.name), password: ActiveValue::NotSet, open_id_subject: ActiveValue::Set(Some(account.sub)), }; model.insert(&state.db).await? } Some(user) => user, }; session.insert("id", user.id).await?; Ok(Redirect::to("/")) } } } fn not_found() -> (StatusCode, Markup) { error_page(StatusCode::NOT_FOUND, "Page not found") } fn confirm_danger_modal(id: &str, inner: &str, action: &str, title: &str) -> Markup { html! { .modal #(id) tabindex="-1" aria-labelledby={(id) "Label"} aria-hidden="true" { .modal-dialog.modal-dialog-centered { .modal-content { .modal-header { h1 .modal-title."fs-5" #{(id) "Label"} { (title) } button .btn-close data-bs-dismiss="modal" aria-label="Cancel" {} } .modal-body { (inner) } .modal-footer { button .btn.btn-secondary data-bs-dismiss="modal" { "Cancel" } form action=(action) method="post" .inline { button type="submit" .btn.btn-danger { "Confirm" } } } } } } } } async fn index(user: AuthenticatedUser, household: CurrentHousehold) -> Markup { sidebar::sidebar( SidebarLocation::Home, &household, &user, html! { "Hello world in " (household.0.name) "!" }, ) } async fn logout(session: Session) -> Result { session.delete().await?; Ok(Redirect::to("/login")) } // #[derive(Serialize, Deserialize, Debug)] // enum UserError { // } // // #[derive(Debug, Serialize, Deserialize)] // struct ErrorQuery { // #[serde(default)] // err: Option, // } // // impl Render for ErrorQuery { // fn render(&self) -> Markup { // let err = match &self.err { // Some(e) => e, // None => return html! {}, // }; // // match err { // _ => todo!(), // } // } // } pub(crate) fn router() -> Router { let router = Router::new(); let public = option_env!("REGALADE_PUBLIC_DIR") .map(std::path::PathBuf::from) .unwrap_or_else(|| std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("public")); tracing::debug!("Public directory: {public:?}"); router .route("/", get(index)) .route("/login", get(oidc_login)) .route("/logout", get(logout)) .route("/login/redirect/:id", get(oidc_login_finish)) .nest("/household", household::routes()) .fallback_service(ServeDir::new(public).fallback((|| async { not_found() }).into_service())) }