Update login flow to have a fixed URL

This commit is contained in:
Quentin Boyer 2024-12-30 22:30:27 +01:00
parent 53ec443972
commit ec03532a88
5 changed files with 110 additions and 24 deletions

54
Cargo.lock generated
View file

@ -351,6 +351,7 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
"axum-macros",
"bytes", "bytes",
"futures-util", "futures-util",
"http 1.2.0", "http 1.2.0",
@ -398,6 +399,41 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "axum-extra"
version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c794b30c904f0a1c2fb7740f7df7f7972dfaa14ef6f57cb6178dc63e5dca2f04"
dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"fastrand",
"futures-util",
"http 1.2.0",
"http-body 1.0.1",
"http-body-util",
"mime",
"multer",
"pin-project-lite",
"serde",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-macros"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.93",
]
[[package]] [[package]]
name = "backtrace" name = "backtrace"
version = "0.3.74" version = "0.3.74"
@ -2224,6 +2260,23 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "multer"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b"
dependencies = [
"bytes",
"encoding_rs",
"futures-util",
"http 1.2.0",
"httparse",
"memchr",
"mime",
"spin",
"version_check",
]
[[package]] [[package]]
name = "new_debug_unreachable" name = "new_debug_unreachable"
version = "1.0.6" version = "1.0.6"
@ -2843,6 +2896,7 @@ dependencies = [
"anyhow", "anyhow",
"api", "api",
"axum", "axum",
"axum-extra",
"base64 0.22.1", "base64 0.22.1",
"envious", "envious",
"jwt-simple", "jwt-simple",

View file

@ -9,7 +9,8 @@ members = [".", "api", "migration"]
[dependencies] [dependencies]
anyhow = "1.0.95" anyhow = "1.0.95"
axum = { version = "0.7.9" } axum = { version = "0.7.9", features = ["macros"] }
axum-extra = { version = "0.9", features = ["cookie"] }
base64 = "0.22.1" base64 = "0.22.1"
jwt-simple = "0.12.11" jwt-simple = "0.12.11"
serde = { version = "1.0.217", features = ["derive"] } serde = { version = "1.0.217", features = ["derive"] }

View file

@ -2,13 +2,14 @@ use std::sync::Arc;
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRef, FromRequestParts, Path, Query, State}, extract::{FromRef, FromRequestParts, Query, State},
handler::HandlerWithoutStateExt, handler::HandlerWithoutStateExt,
http::{request::Parts, StatusCode}, http::{request::Parts, StatusCode},
response::{IntoResponse, Redirect}, response::{IntoResponse, Redirect},
routing::get, routing::get,
Router, Router,
}; };
use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
use maud::{html, Markup}; use maud::{html, Markup};
use sea_orm::{prelude::*, ActiveValue, DbErr, TransactionError}; use sea_orm::{prelude::*, ActiveValue, DbErr, TransactionError};
use serde::Deserialize; use serde::Deserialize;
@ -21,9 +22,9 @@ use crate::entity::{prelude::*, user};
use self::{household::CurrentHousehold, sidebar::SidebarLocation}; use self::{household::CurrentHousehold, sidebar::SidebarLocation};
mod household; mod household;
mod ingredients;
mod recipe; mod recipe;
mod sidebar; mod sidebar;
mod ingredients;
type AppState = Arc<crate::AppState>; type AppState = Arc<crate::AppState>;
@ -156,13 +157,20 @@ impl IntoResponse for RouteError {
} }
} }
async fn oidc_login(State(state): State<AppState>) -> Redirect { #[axum::debug_handler]
async fn oidc_login(State(state): State<AppState>, jar: CookieJar) -> (CookieJar, Redirect) {
tracing::info!("Starting OIDC login"); tracing::info!("Starting OIDC login");
let oidc = state.oidc.as_ref().unwrap(); let oidc = state.oidc.as_ref().unwrap();
let redirect_url = oidc.start_auth(); let (flow_id, redirect_url) = oidc.start_auth();
let jar = jar.add(
Cookie::build(("login_flow_id", flow_id.to_string()))
.secure(true)
.same_site(SameSite::Lax)
.build(),
);
Redirect::to(redirect_url.as_str()) (jar, Redirect::to(redirect_url.as_str()))
} }
enum RedirectOrError { enum RedirectOrError {
@ -236,10 +244,14 @@ struct OidcRedirectParams {
async fn oidc_login_finish( async fn oidc_login_finish(
State(state): State<AppState>, State(state): State<AppState>,
Path(id): Path<Uuid>,
Query(redirect): Query<OidcRedirectParams>, Query(redirect): Query<OidcRedirectParams>,
session: Session, session: Session,
jar: CookieJar,
) -> Result<Redirect, RouteError> { ) -> Result<Redirect, RouteError> {
let Some(Ok(id)) = jar.get("login_flow_id").map(|c| c.value().parse()) else {
return Err(RouteError::Oauth2Failure);
};
match state match state
.oidc .oidc
.as_ref() .as_ref()
@ -370,7 +382,7 @@ pub(crate) fn router() -> Router<AppState> {
.route("/", get(index)) .route("/", get(index))
.route("/login", get(oidc_login)) .route("/login", get(oidc_login))
.route("/logout", get(logout)) .route("/logout", get(logout))
.route("/login/redirect/:id", get(oidc_login_finish)) .route("/login/redirect", get(oidc_login_finish))
.nest("/household", household::routes()) .nest("/household", household::routes())
.nest("/recipe", recipe::routes()) .nest("/recipe", recipe::routes())
.nest("/ingredients", ingredients::routes()) .nest("/ingredients", ingredients::routes())

View file

@ -241,6 +241,10 @@ pub struct OpenidAccount {
} }
impl OpenidConnector { impl OpenidConnector {
fn redirect_uri(&self) -> Url {
self.domain.join("login/redirect").unwrap()
}
async fn new(settings: OpenidConnectSettings) -> anyhow::Result<Self> { async fn new(settings: OpenidConnectSettings) -> anyhow::Result<Self> {
let metadata = CoreProviderMetadata::discover_async( let metadata = CoreProviderMetadata::discover_async(
IssuerUrl::new(settings.url)?, IssuerUrl::new(settings.url)?,
@ -259,14 +263,15 @@ impl OpenidConnector {
}) })
} }
pub fn start_auth(&self) -> Url { pub fn start_auth(&self) -> (Uuid, Url) {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let mut inflight = self.inflight.lock(); let mut inflight = self.inflight.lock();
let slot = inflight.new_entry(); let slot = inflight.new_entry();
let id = slot.id;
tracing::info!("Login flow with id {}", slot.id); tracing::info!("Login flow with id {id}");
let (url, csrf_token, nonce) = self let (url, csrf_token, nonce) = self
.provider .provider
@ -278,9 +283,7 @@ impl OpenidConnector {
.add_scopes(self.scopes.iter().cloned()) .add_scopes(self.scopes.iter().cloned())
.set_pkce_challenge(pkce_challenge) .set_pkce_challenge(pkce_challenge)
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url( .set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
self.domain self.redirect_uri(),
.join(&format!("login/redirect/{}", slot.id))
.unwrap(),
))) )))
.url(); .url();
@ -290,7 +293,7 @@ impl OpenidConnector {
nonce, nonce,
}); });
url (id, url)
} }
pub async fn redirected( pub async fn redirected(
@ -317,7 +320,7 @@ impl OpenidConnector {
.exchange_code(AuthorizationCode::new(code)) .exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(state.pkce_verifier) .set_pkce_verifier(state.pkce_verifier)
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url( .set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
self.domain.join(&format!("login/redirect/{}", id)).unwrap(), self.redirect_uri(),
))) )))
.request_async(openidconnect::reqwest::async_http_client) .request_async(openidconnect::reqwest::async_http_client)
.await?; .await?;

View file

@ -15,6 +15,7 @@ use axum::{
Json, Json,
Router, Router,
}; };
use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
use jwt_simple::prelude::*; use jwt_simple::prelude::*;
use sea_orm::{prelude::*, ActiveValue, TransactionError}; use sea_orm::{prelude::*, ActiveValue, TransactionError};
use sha2::{Digest, Sha512}; use sha2::{Digest, Sha512};
@ -141,8 +142,9 @@ async fn login(
let Some(user) = User::find() let Some(user) = User::find()
.filter(user::Column::Name.eq(&req.username)) .filter(user::Column::Name.eq(&req.username))
.one(&state.db) .one(&state.db)
.await? else { .await?
return Err(RouteError::UnknownAccount) else {
return Err(RouteError::UnknownAccount);
}; };
let Some(password) = user.password.as_ref() else { let Some(password) = user.password.as_ref() else {
@ -171,13 +173,22 @@ struct OidcStartParam {
r#return: String, r#return: String,
} }
async fn oidc_login(State(state): State<AppState>) -> Result<Redirect, RouteError> { async fn oidc_login(
State(state): State<AppState>,
jar: CookieJar,
) -> Result<(CookieJar, Redirect), RouteError> {
tracing::info!("Starting OIDC login"); tracing::info!("Starting OIDC login");
let oidc = state.oidc.as_ref().unwrap(); let oidc = state.oidc.as_ref().unwrap();
let redirect_url = oidc.start_auth(); let (flow_id, redirect_url) = oidc.start_auth();
let jar = jar.add(
Cookie::build(("login_flow_id", flow_id.to_string()))
.secure(true)
.same_site(SameSite::Lax)
.build(),
);
Ok(Redirect::to(redirect_url.as_str())) Ok((jar, Redirect::to(redirect_url.as_str())))
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -188,9 +199,13 @@ struct OidcRedirectParams {
async fn oidc_login_finish( async fn oidc_login_finish(
State(state): State<AppState>, State(state): State<AppState>,
Path(id): Path<Uuid>,
Query(redirect): Query<OidcRedirectParams>, Query(redirect): Query<OidcRedirectParams>,
jar: CookieJar,
) -> Result<Redirect, RouteError> { ) -> Result<Redirect, RouteError> {
let Some(Ok(id)) = jar.get("login_flow_id").map(|c| c.value().parse()) else {
return Err(RouteError::Unauthorized);
};
match state match state
.oidc .oidc
.as_ref() .as_ref()
@ -250,8 +265,9 @@ async fn get_user_id(
let Some(user) = User::find() let Some(user) = User::find()
.filter(user::Column::Name.eq(name)) .filter(user::Column::Name.eq(name))
.one(&state.db) .one(&state.db)
.await? else { .await?
return Ok(Err(StatusCode::NOT_FOUND)) else {
return Ok(Err(StatusCode::NOT_FOUND));
}; };
Ok(Ok(Json(UserInfo { Ok(Ok(Json(UserInfo {
@ -343,7 +359,7 @@ pub(crate) fn router(api_allowed: Option<HeaderValue>, has_oidc: bool) -> Router
"/login/has_oidc", "/login/has_oidc",
get(unit).layer(mk_service(vec![Method::GET])), get(unit).layer(mk_service(vec![Method::GET])),
) )
.route("/login/redirect/:id", get(oidc_login_finish)) .route("/login/redirect", get(oidc_login_finish))
} else { } else {
router router
} }