From ec03532a88ce29d8059a5ea353098b569fc2e0e5 Mon Sep 17 00:00:00 2001 From: Quentin Boyer Date: Mon, 30 Dec 2024 22:30:27 +0100 Subject: [PATCH] Update login flow to have a fixed URL --- Cargo.lock | 54 +++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 3 ++- src/app/mod.rs | 26 +++++++++++++++++------ src/main.rs | 17 +++++++++------ src/routes/mod.rs | 34 +++++++++++++++++++++-------- 5 files changed, 110 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 18392e1..5b20691 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -351,6 +351,7 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", + "axum-macros", "bytes", "futures-util", "http 1.2.0", @@ -398,6 +399,41 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c794b30c904f0a1c2fb7740f7df7f7972dfaa14ef6f57cb6178dc63e5dca2f04" +dependencies = [ + "axum", + "axum-core", + "bytes", + "cookie", + "fastrand", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "multer", + "pin-project-lite", + "serde", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.93", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -2224,6 +2260,23 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "multer" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http 1.2.0", + "httparse", + "memchr", + "mime", + "spin", + "version_check", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -2843,6 +2896,7 @@ dependencies = [ "anyhow", "api", "axum", + "axum-extra", "base64 0.22.1", "envious", "jwt-simple", diff --git a/Cargo.toml b/Cargo.toml index be2f0b8..b450b38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,8 @@ members = [".", "api", "migration"] [dependencies] anyhow = "1.0.95" -axum = { version = "0.7.9" } +axum = { version = "0.7.9", features = ["macros"] } +axum-extra = { version = "0.9", features = ["cookie"] } base64 = "0.22.1" jwt-simple = "0.12.11" serde = { version = "1.0.217", features = ["derive"] } diff --git a/src/app/mod.rs b/src/app/mod.rs index 5ef9aed..7913e3a 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -2,13 +2,14 @@ use std::sync::Arc; use axum::{ async_trait, - extract::{FromRef, FromRequestParts, Path, Query, State}, + extract::{FromRef, FromRequestParts, Query, State}, handler::HandlerWithoutStateExt, http::{request::Parts, StatusCode}, response::{IntoResponse, Redirect}, routing::get, Router, }; +use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite}; use maud::{html, Markup}; use sea_orm::{prelude::*, ActiveValue, DbErr, TransactionError}; use serde::Deserialize; @@ -21,9 +22,9 @@ use crate::entity::{prelude::*, user}; use self::{household::CurrentHousehold, sidebar::SidebarLocation}; mod household; +mod ingredients; mod recipe; mod sidebar; -mod ingredients; type AppState = Arc; @@ -156,13 +157,20 @@ impl IntoResponse for RouteError { } } -async fn oidc_login(State(state): State) -> Redirect { +#[axum::debug_handler] +async fn oidc_login(State(state): State, jar: CookieJar) -> (CookieJar, Redirect) { tracing::info!("Starting OIDC login"); let oidc = state.oidc.as_ref().unwrap(); - let redirect_url = oidc.start_auth(); + let (flow_id, redirect_url) = oidc.start_auth(); + let jar = jar.add( + Cookie::build(("login_flow_id", flow_id.to_string())) + .secure(true) + .same_site(SameSite::Lax) + .build(), + ); - Redirect::to(redirect_url.as_str()) + (jar, Redirect::to(redirect_url.as_str())) } enum RedirectOrError { @@ -236,10 +244,14 @@ struct OidcRedirectParams { async fn oidc_login_finish( State(state): State, - Path(id): Path, Query(redirect): Query, session: Session, + jar: CookieJar, ) -> Result { + let Some(Ok(id)) = jar.get("login_flow_id").map(|c| c.value().parse()) else { + return Err(RouteError::Oauth2Failure); + }; + match state .oidc .as_ref() @@ -370,7 +382,7 @@ pub(crate) fn router() -> Router { .route("/", get(index)) .route("/login", get(oidc_login)) .route("/logout", get(logout)) - .route("/login/redirect/:id", get(oidc_login_finish)) + .route("/login/redirect", get(oidc_login_finish)) .nest("/household", household::routes()) .nest("/recipe", recipe::routes()) .nest("/ingredients", ingredients::routes()) diff --git a/src/main.rs b/src/main.rs index 3c2eb1f..f24eaa5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -241,6 +241,10 @@ pub struct OpenidAccount { } impl OpenidConnector { + fn redirect_uri(&self) -> Url { + self.domain.join("login/redirect").unwrap() + } + async fn new(settings: OpenidConnectSettings) -> anyhow::Result { let metadata = CoreProviderMetadata::discover_async( IssuerUrl::new(settings.url)?, @@ -259,14 +263,15 @@ impl OpenidConnector { }) } - pub fn start_auth(&self) -> Url { + pub fn start_auth(&self) -> (Uuid, Url) { let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let mut inflight = self.inflight.lock(); let slot = inflight.new_entry(); + let id = slot.id; - tracing::info!("Login flow with id {}", slot.id); + tracing::info!("Login flow with id {id}"); let (url, csrf_token, nonce) = self .provider @@ -278,9 +283,7 @@ impl OpenidConnector { .add_scopes(self.scopes.iter().cloned()) .set_pkce_challenge(pkce_challenge) .set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url( - self.domain - .join(&format!("login/redirect/{}", slot.id)) - .unwrap(), + self.redirect_uri(), ))) .url(); @@ -290,7 +293,7 @@ impl OpenidConnector { nonce, }); - url + (id, url) } pub async fn redirected( @@ -317,7 +320,7 @@ impl OpenidConnector { .exchange_code(AuthorizationCode::new(code)) .set_pkce_verifier(state.pkce_verifier) .set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url( - self.domain.join(&format!("login/redirect/{}", id)).unwrap(), + self.redirect_uri(), ))) .request_async(openidconnect::reqwest::async_http_client) .await?; diff --git a/src/routes/mod.rs b/src/routes/mod.rs index a42f86a..09bad15 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -15,6 +15,7 @@ use axum::{ Json, Router, }; +use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite}; use jwt_simple::prelude::*; use sea_orm::{prelude::*, ActiveValue, TransactionError}; use sha2::{Digest, Sha512}; @@ -141,8 +142,9 @@ async fn login( let Some(user) = User::find() .filter(user::Column::Name.eq(&req.username)) .one(&state.db) - .await? else { - return Err(RouteError::UnknownAccount) + .await? + else { + return Err(RouteError::UnknownAccount); }; let Some(password) = user.password.as_ref() else { @@ -171,13 +173,22 @@ struct OidcStartParam { r#return: String, } -async fn oidc_login(State(state): State) -> Result { +async fn oidc_login( + State(state): State, + jar: CookieJar, +) -> Result<(CookieJar, Redirect), RouteError> { tracing::info!("Starting OIDC login"); let oidc = state.oidc.as_ref().unwrap(); - let redirect_url = oidc.start_auth(); + let (flow_id, redirect_url) = oidc.start_auth(); + let jar = jar.add( + Cookie::build(("login_flow_id", flow_id.to_string())) + .secure(true) + .same_site(SameSite::Lax) + .build(), + ); - Ok(Redirect::to(redirect_url.as_str())) + Ok((jar, Redirect::to(redirect_url.as_str()))) } #[derive(Deserialize)] @@ -188,9 +199,13 @@ struct OidcRedirectParams { async fn oidc_login_finish( State(state): State, - Path(id): Path, Query(redirect): Query, + jar: CookieJar, ) -> Result { + let Some(Ok(id)) = jar.get("login_flow_id").map(|c| c.value().parse()) else { + return Err(RouteError::Unauthorized); + }; + match state .oidc .as_ref() @@ -250,8 +265,9 @@ async fn get_user_id( let Some(user) = User::find() .filter(user::Column::Name.eq(name)) .one(&state.db) - .await? else { - return Ok(Err(StatusCode::NOT_FOUND)) + .await? + else { + return Ok(Err(StatusCode::NOT_FOUND)); }; Ok(Ok(Json(UserInfo { @@ -343,7 +359,7 @@ pub(crate) fn router(api_allowed: Option, has_oidc: bool) -> Router "/login/has_oidc", get(unit).layer(mk_service(vec![Method::GET])), ) - .route("/login/redirect/:id", get(oidc_login_finish)) + .route("/login/redirect", get(oidc_login_finish)) } else { router }